diff options
-rw-r--r-- | lab_control/connection/__init__.py | 0 | ||||
-rw-r--r-- | lab_control/connection/direct_connection.py | 20 | ||||
-rw-r--r-- | lab_control/connection/serial_connection.py | 26 | ||||
-rw-r--r-- | lab_control/jds6600.py | 38 | ||||
-rw-r--r-- | lab_control/test/jds6600_unittest.py (renamed from lab_control/test/jds6600_test.py) | 25 | ||||
-rw-r--r-- | lab_control/test/mock_jds6600_device.py | 44 | ||||
-rw-r--r-- | lab_control/test/virtual_serial_port.py | 34 |
7 files changed, 118 insertions, 69 deletions
diff --git a/lab_control/connection/__init__.py b/lab_control/connection/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lab_control/connection/__init__.py diff --git a/lab_control/connection/direct_connection.py b/lab_control/connection/direct_connection.py new file mode 100644 index 0000000..56a5cf3 --- /dev/null +++ b/lab_control/connection/direct_connection.py @@ -0,0 +1,20 @@ +class DirectConnection: + def __init__(self, requestHandler): + self.requestHandler = requestHandler + self.open = True + self.config = {} + + def configure(self, config: dict) -> None: + self.config = config + + def close(self) -> None: + self.open = False + + def send(self, request: str) -> str: + return self.requestHandler(request) + + def checkConfiguration(self) -> None: + assert self.config.get("baudrate") == 115200 + assert self.config.get("bytesize") == 8 + assert self.config.get("stopbits") == 1 + assert self.config.get("parity") == "N" diff --git a/lab_control/connection/serial_connection.py b/lab_control/connection/serial_connection.py new file mode 100644 index 0000000..282494b --- /dev/null +++ b/lab_control/connection/serial_connection.py @@ -0,0 +1,26 @@ +import termios +import serial + +class SerialConnection: + def __init__(self, portName): + self._port = serial.Serial(portName) + + def configure(self, config: dict) -> None: + self._port.baudrate = parameters["baudrate"] + self._port.bytesize = parameters["bytesize"] + self._port.stopbits = parameters["stopbits"] + self._port.parity = parameters["parity"] + + def send(self, request): + self._port.write(request.encode()) + return self._port.readline().decode() + + def checkConfiguration(self) -> None: + iflag, oflag, cflag, lflag, ispeed, ospeed, cc = termios.tcgetattr(self._port) + + # JDS6600 configuration taken from manual + assert ispeed == termios.B115200 + assert ospeed == termios.B115200 + assert (cflag & termios.CSIZE) == termios.CS8 + assert (cflag & termios.CSTOPB) == 0 + assert (cflag & (termios.PARENB | termios.PARODD)) == 0 diff --git a/lab_control/jds6600.py b/lab_control/jds6600.py index 32f12af..5f072af 100644 --- a/lab_control/jds6600.py +++ b/lab_control/jds6600.py @@ -2,9 +2,8 @@ Implements partial support for Joy-IT JDS6600 function generator. """ -import serial - from lab_control.function_generator import FunctionGenerator +from lab_control.connection.serial_connection import SerialConnection def _checkChannel(channel: int): assert channel in JDS6600.AVAILABLE_CHANNELS, f"JDS6600: Invalid channel {channel}" @@ -23,32 +22,41 @@ class JDS6600(FunctionGenerator): """ AVAILABLE_CHANNELS = [1, 2] - def __init__(self, portName): + # TODO type hints + def __init__(self, portName: str, overrideConnection=None): super().__init__() - self._port = serial.Serial(portName) - self._port.baudrate = 115200 - self._port.bytesize = serial.EIGHTBITS - self._port.stopbits = serial.STOPBITS_ONE - self._port.parity = serial.PARITY_NONE + + config = { + "baudrate" : 115200, + "bytesize" : 8, + "stopbits" : 1, + "parity" : "N" + } + + if overrideConnection is not None: + self._connection = overrideConnection + else: + self._connection = SerialConnection(portName) + + self._connection.configure(config) def closePort(self) -> None: """ Close the serial port. Instances of this class are no longer usable after this is called. """ - self._port.close() + self._connection.close() + + def _sendRequest(self, opcode: str, args: str="") -> str: + request = f":{opcode}={args}.\r\n" + response = self._connection.send(request) + return response.strip() def _query(self, opcode: str) -> list[str]: # response format: ":{opcode}={v1},{v2}." response = self._sendRequest(opcode) return response[5:-1].split(",") - def _sendRequest(self, opcode: str, args: str="") -> str: - request = f":{opcode}={args}.\r\n" - self._port.write(request.encode()) - responseRaw = self._port.readline() - return responseRaw.decode().strip() - def setOn(self, channel: int) -> None: _checkChannel(channel) diff --git a/lab_control/test/jds6600_test.py b/lab_control/test/jds6600_unittest.py index 59e2d33..c3b283b 100644 --- a/lab_control/test/jds6600_test.py +++ b/lab_control/test/jds6600_unittest.py @@ -2,18 +2,19 @@ import pytest from lab_control.jds6600 import JDS6600 from lab_control.test.mock_jds6600_device import MockJDS6600Device +from lab_control.connection.direct_connection import DirectConnection as MockConnection @pytest.fixture def mockDevice(): - d = MockJDS6600Device() - yield d - d.stop() + return MockJDS6600Device() @pytest.fixture -def uut(mockDevice): - uut = JDS6600(mockDevice.getPortName()) - yield uut - uut.closePort() +def mockConnection(mockDevice): + return MockConnection(mockDevice._handleRequest) + +@pytest.fixture +def uut(mockConnection): + return JDS6600("", mockConnection) def checkNumericalParameter(testValues, writeValue, valueInMock): for ch in JDS6600.AVAILABLE_CHANNELS: @@ -29,12 +30,14 @@ def checkInvalidNumericalParameter(testValues, writeValue, valueInMock): with pytest.raises(AssertionError): writeValue(ch, value) -def test_serialConfiguration(mockDevice): +def test_serialPortConfiguration(mockConnection): with pytest.raises(AssertionError): - mockDevice.checkPortConfiguration() + mockConnection.checkConfiguration() - uut = JDS6600(mockDevice.getPortName()) - mockDevice.checkPortConfiguration() + uut = JDS6600("", mockConnection) + mockConnection.checkConfiguration() + + uut.closePort() def test_channelOnAndOff(uut, mockDevice): for ch in JDS6600.AVAILABLE_CHANNELS: diff --git a/lab_control/test/mock_jds6600_device.py b/lab_control/test/mock_jds6600_device.py index c027573..8b7b440 100644 --- a/lab_control/test/mock_jds6600_device.py +++ b/lab_control/test/mock_jds6600_device.py @@ -1,10 +1,6 @@ -import os -import pty -import termios -import threading import re -class MockJDS6600Device(): +class MockJDS6600Device: class ChannelState: def __init__(self): self.on = False @@ -13,27 +9,8 @@ class MockJDS6600Device(): self.function = None def __init__(self): - self._master, self._slave = pty.openpty() - self._masterFile = os.fdopen(self._master, mode="r+b", closefd=False, buffering=0) - - self._portName = os.ttyname(self._slave) self._channels = [MockJDS6600Device.ChannelState() for i in [1, 2]] - self._injectedFailureCounter = 0 - - self._mainThread = threading.Thread(target=self._mainLoop) - self._mainThread.start() - - def _mainLoop(self) -> None: - while True: - try: - request = self._masterFile.readline().decode().strip() - response = self._handleRequest(request) - - if response is not None: - self._masterFile.write(response.encode()) - except OSError as e: - break def _handleRequest(self, request: str) -> str: pattern = r":(?P<opcode>[wrab])(?P<function>\d+)=(?P<args>.*)\." @@ -91,25 +68,6 @@ class MockJDS6600Device(): # Unknown request format, no response return None - def stop(self) -> None: - self._masterFile.close() - os.close(self._master) - os.close(self._slave) - self._mainThread.join() - - def checkPortConfiguration(self) -> None: - iflag, oflag, cflag, lflag, ispeed, ospeed, cc = termios.tcgetattr(self._slave) - - # JDS6600 configuration taken from manual - assert ispeed == termios.B115200 - assert ospeed == termios.B115200 - assert (cflag & termios.CSIZE) == termios.CS8 - assert (cflag & termios.CSTOPB) == 0 - assert (cflag & (termios.PARENB | termios.PARODD)) == 0 - - def getPortName(self) -> str: - return self._portName - def isOn(self, ch: int) -> bool: return self._channels[ch - 1].on diff --git a/lab_control/test/virtual_serial_port.py b/lab_control/test/virtual_serial_port.py new file mode 100644 index 0000000..f46e29c --- /dev/null +++ b/lab_control/test/virtual_serial_port.py @@ -0,0 +1,34 @@ +import os +import pty +import termios +import threading + +class VirtualSerialPort: + def __init__(self, requestHandler): + self._master, self._slave = pty.openpty() + self._masterFile = os.fdopen(self._master, mode="r+b", closefd=False, buffering=0) + self._portName = os.ttyname(self._slave) + self._requestHandler = requestHandler + + self._mainThread = threading.Thread(target=self._mainLoop) + self._mainThread.start() + + def stop(self) -> None: + self._masterFile.close() + os.close(self._master) + os.close(self._slave) + self._mainThread.join() + + def _mainLoop(self) -> None: + while True: + try: + request = self._masterFile.readline().decode().strip() + response = self._requestHandler(request) + + if response is not None: + self._masterFile.write(response.encode()) + except OSError as e: + break + + def getPortName(self) -> str: + return self._portName |