From 498b2543dc4336962c8c235ac7a0b41e175fde07 Mon Sep 17 00:00:00 2001 From: Eddy Pedroni Date: Sun, 5 Jun 2022 16:15:15 +0200 Subject: Refactored JDS6600 tests, serial connection is no longer used --- lab_control/connection/__init__.py | 0 lab_control/connection/direct_connection.py | 20 ++++++ lab_control/connection/serial_connection.py | 26 ++++++++ lab_control/jds6600.py | 38 ++++++----- lab_control/test/jds6600_test.py | 96 ---------------------------- lab_control/test/jds6600_unittest.py | 99 +++++++++++++++++++++++++++++ lab_control/test/mock_jds6600_device.py | 44 +------------ lab_control/test/virtual_serial_port.py | 34 ++++++++++ 8 files changed, 203 insertions(+), 154 deletions(-) create mode 100644 lab_control/connection/__init__.py create mode 100644 lab_control/connection/direct_connection.py create mode 100644 lab_control/connection/serial_connection.py delete mode 100644 lab_control/test/jds6600_test.py create mode 100644 lab_control/test/jds6600_unittest.py create mode 100644 lab_control/test/virtual_serial_port.py (limited to 'lab_control') diff --git a/lab_control/connection/__init__.py b/lab_control/connection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lab_control/connection/direct_connection.py b/lab_control/connection/direct_connection.py new file mode 100644 index 0000000..56a5cf3 --- /dev/null +++ b/lab_control/connection/direct_connection.py @@ -0,0 +1,20 @@ +class DirectConnection: + def __init__(self, requestHandler): + self.requestHandler = requestHandler + self.open = True + self.config = {} + + def configure(self, config: dict) -> None: + self.config = config + + def close(self) -> None: + self.open = False + + def send(self, request: str) -> str: + return self.requestHandler(request) + + def checkConfiguration(self) -> None: + assert self.config.get("baudrate") == 115200 + assert self.config.get("bytesize") == 8 + assert self.config.get("stopbits") == 1 + assert self.config.get("parity") == "N" diff --git a/lab_control/connection/serial_connection.py b/lab_control/connection/serial_connection.py new file mode 100644 index 0000000..282494b --- /dev/null +++ b/lab_control/connection/serial_connection.py @@ -0,0 +1,26 @@ +import termios +import serial + +class SerialConnection: + def __init__(self, portName): + self._port = serial.Serial(portName) + + def configure(self, config: dict) -> None: + self._port.baudrate = parameters["baudrate"] + self._port.bytesize = parameters["bytesize"] + self._port.stopbits = parameters["stopbits"] + self._port.parity = parameters["parity"] + + def send(self, request): + self._port.write(request.encode()) + return self._port.readline().decode() + + def checkConfiguration(self) -> None: + iflag, oflag, cflag, lflag, ispeed, ospeed, cc = termios.tcgetattr(self._port) + + # JDS6600 configuration taken from manual + assert ispeed == termios.B115200 + assert ospeed == termios.B115200 + assert (cflag & termios.CSIZE) == termios.CS8 + assert (cflag & termios.CSTOPB) == 0 + assert (cflag & (termios.PARENB | termios.PARODD)) == 0 diff --git a/lab_control/jds6600.py b/lab_control/jds6600.py index 32f12af..5f072af 100644 --- a/lab_control/jds6600.py +++ b/lab_control/jds6600.py @@ -2,9 +2,8 @@ Implements partial support for Joy-IT JDS6600 function generator. """ -import serial - from lab_control.function_generator import FunctionGenerator +from lab_control.connection.serial_connection import SerialConnection def _checkChannel(channel: int): assert channel in JDS6600.AVAILABLE_CHANNELS, f"JDS6600: Invalid channel {channel}" @@ -23,32 +22,41 @@ class JDS6600(FunctionGenerator): """ AVAILABLE_CHANNELS = [1, 2] - def __init__(self, portName): + # TODO type hints + def __init__(self, portName: str, overrideConnection=None): super().__init__() - self._port = serial.Serial(portName) - self._port.baudrate = 115200 - self._port.bytesize = serial.EIGHTBITS - self._port.stopbits = serial.STOPBITS_ONE - self._port.parity = serial.PARITY_NONE + + config = { + "baudrate" : 115200, + "bytesize" : 8, + "stopbits" : 1, + "parity" : "N" + } + + if overrideConnection is not None: + self._connection = overrideConnection + else: + self._connection = SerialConnection(portName) + + self._connection.configure(config) def closePort(self) -> None: """ Close the serial port. Instances of this class are no longer usable after this is called. """ - self._port.close() + self._connection.close() + + def _sendRequest(self, opcode: str, args: str="") -> str: + request = f":{opcode}={args}.\r\n" + response = self._connection.send(request) + return response.strip() def _query(self, opcode: str) -> list[str]: # response format: ":{opcode}={v1},{v2}." response = self._sendRequest(opcode) return response[5:-1].split(",") - def _sendRequest(self, opcode: str, args: str="") -> str: - request = f":{opcode}={args}.\r\n" - self._port.write(request.encode()) - responseRaw = self._port.readline() - return responseRaw.decode().strip() - def setOn(self, channel: int) -> None: _checkChannel(channel) diff --git a/lab_control/test/jds6600_test.py b/lab_control/test/jds6600_test.py deleted file mode 100644 index 59e2d33..0000000 --- a/lab_control/test/jds6600_test.py +++ /dev/null @@ -1,96 +0,0 @@ -import pytest - -from lab_control.jds6600 import JDS6600 -from lab_control.test.mock_jds6600_device import MockJDS6600Device - -@pytest.fixture -def mockDevice(): - d = MockJDS6600Device() - yield d - d.stop() - -@pytest.fixture -def uut(mockDevice): - uut = JDS6600(mockDevice.getPortName()) - yield uut - uut.closePort() - -def checkNumericalParameter(testValues, writeValue, valueInMock): - for ch in JDS6600.AVAILABLE_CHANNELS: - assert valueInMock(ch) is None - - for value in testValues: - writeValue(ch, value) - assert valueInMock(ch) == value - -def checkInvalidNumericalParameter(testValues, writeValue, valueInMock): - for ch in JDS6600.AVAILABLE_CHANNELS: - for value in testValues: - with pytest.raises(AssertionError): - writeValue(ch, value) - -def test_serialConfiguration(mockDevice): - with pytest.raises(AssertionError): - mockDevice.checkPortConfiguration() - - uut = JDS6600(mockDevice.getPortName()) - mockDevice.checkPortConfiguration() - -def test_channelOnAndOff(uut, mockDevice): - for ch in JDS6600.AVAILABLE_CHANNELS: - assert not mockDevice.isOn(ch) - uut.setOn(ch) - assert mockDevice.isOn(ch) - uut.setOff(ch) - assert not mockDevice.isOn(ch) - -def test_setFrequency(uut, mockDevice): - checkNumericalParameter([0.0, 100.0, 100000.0, 60000000.0], uut.setFrequency, mockDevice.getFrequency) - -def test_setInvalidFrequency(uut, mockDevice): - checkInvalidNumericalParameter([-10.0, 60000000.1, None], uut.setFrequency, mockDevice.getFrequency) - -def test_setAmplitude(uut, mockDevice): - checkNumericalParameter([0.0, 0.1, 1.0, 10.0, 20.0], uut.setAmplitude, mockDevice.getAmplitude) - -def test_setInvalidAmplitude(uut, mockDevice): - checkInvalidNumericalParameter([-0.1, -10.0, 20.1, None], uut.setAmplitude, mockDevice.getAmplitude) - -def test_setFunction(uut, mockDevice): - checkNumericalParameter(range(0, 17), uut.setFunction, mockDevice.getFunction) - -def test_setInvalidFunction(uut, mockDevice): - checkInvalidNumericalParameter([-1, -10, 17, 20, None], uut.setFunction, mockDevice.getFunction) - -def test_invalidChannel(uut): - testMethods = [uut.setFrequency, uut.setAmplitude, uut.setFunction] - for ch in [-1, 0, 3, None]: - for method in testMethods: - with pytest.raises(AssertionError): - method(ch, 0) - - with pytest.raises(AssertionError): - uut.setOn(ch) - - with pytest.raises(AssertionError): - uut.setOff(ch) - -def test_setFrequencySingleFailure(uut, mockDevice): - testFrequency = 1000.0 - testChannel = 1 - assert mockDevice.getFrequency(testChannel) is None - - mockDevice.injectFailures(1) - uut.setFrequency(testChannel, testFrequency) - - assert mockDevice.getFrequency(testChannel) == testFrequency - -def test_setFrequencyMultipleFailures(uut, mockDevice): - testFrequency = 1000.0 - testChannel = 1 - assert mockDevice.getFrequency(testChannel) is None - - mockDevice.injectFailures(2) - uut.setFrequency(testChannel, testFrequency) - - assert mockDevice.getFrequency(testChannel) == 0.0 diff --git a/lab_control/test/jds6600_unittest.py b/lab_control/test/jds6600_unittest.py new file mode 100644 index 0000000..c3b283b --- /dev/null +++ b/lab_control/test/jds6600_unittest.py @@ -0,0 +1,99 @@ +import pytest + +from lab_control.jds6600 import JDS6600 +from lab_control.test.mock_jds6600_device import MockJDS6600Device +from lab_control.connection.direct_connection import DirectConnection as MockConnection + +@pytest.fixture +def mockDevice(): + return MockJDS6600Device() + +@pytest.fixture +def mockConnection(mockDevice): + return MockConnection(mockDevice._handleRequest) + +@pytest.fixture +def uut(mockConnection): + return JDS6600("", mockConnection) + +def checkNumericalParameter(testValues, writeValue, valueInMock): + for ch in JDS6600.AVAILABLE_CHANNELS: + assert valueInMock(ch) is None + + for value in testValues: + writeValue(ch, value) + assert valueInMock(ch) == value + +def checkInvalidNumericalParameter(testValues, writeValue, valueInMock): + for ch in JDS6600.AVAILABLE_CHANNELS: + for value in testValues: + with pytest.raises(AssertionError): + writeValue(ch, value) + +def test_serialPortConfiguration(mockConnection): + with pytest.raises(AssertionError): + mockConnection.checkConfiguration() + + uut = JDS6600("", mockConnection) + mockConnection.checkConfiguration() + + uut.closePort() + +def test_channelOnAndOff(uut, mockDevice): + for ch in JDS6600.AVAILABLE_CHANNELS: + assert not mockDevice.isOn(ch) + uut.setOn(ch) + assert mockDevice.isOn(ch) + uut.setOff(ch) + assert not mockDevice.isOn(ch) + +def test_setFrequency(uut, mockDevice): + checkNumericalParameter([0.0, 100.0, 100000.0, 60000000.0], uut.setFrequency, mockDevice.getFrequency) + +def test_setInvalidFrequency(uut, mockDevice): + checkInvalidNumericalParameter([-10.0, 60000000.1, None], uut.setFrequency, mockDevice.getFrequency) + +def test_setAmplitude(uut, mockDevice): + checkNumericalParameter([0.0, 0.1, 1.0, 10.0, 20.0], uut.setAmplitude, mockDevice.getAmplitude) + +def test_setInvalidAmplitude(uut, mockDevice): + checkInvalidNumericalParameter([-0.1, -10.0, 20.1, None], uut.setAmplitude, mockDevice.getAmplitude) + +def test_setFunction(uut, mockDevice): + checkNumericalParameter(range(0, 17), uut.setFunction, mockDevice.getFunction) + +def test_setInvalidFunction(uut, mockDevice): + checkInvalidNumericalParameter([-1, -10, 17, 20, None], uut.setFunction, mockDevice.getFunction) + +def test_invalidChannel(uut): + testMethods = [uut.setFrequency, uut.setAmplitude, uut.setFunction] + for ch in [-1, 0, 3, None]: + for method in testMethods: + with pytest.raises(AssertionError): + method(ch, 0) + + with pytest.raises(AssertionError): + uut.setOn(ch) + + with pytest.raises(AssertionError): + uut.setOff(ch) + +def test_setFrequencySingleFailure(uut, mockDevice): + testFrequency = 1000.0 + testChannel = 1 + assert mockDevice.getFrequency(testChannel) is None + + mockDevice.injectFailures(1) + uut.setFrequency(testChannel, testFrequency) + + assert mockDevice.getFrequency(testChannel) == testFrequency + +def test_setFrequencyMultipleFailures(uut, mockDevice): + testFrequency = 1000.0 + testChannel = 1 + assert mockDevice.getFrequency(testChannel) is None + + mockDevice.injectFailures(2) + uut.setFrequency(testChannel, testFrequency) + + assert mockDevice.getFrequency(testChannel) == 0.0 diff --git a/lab_control/test/mock_jds6600_device.py b/lab_control/test/mock_jds6600_device.py index c027573..8b7b440 100644 --- a/lab_control/test/mock_jds6600_device.py +++ b/lab_control/test/mock_jds6600_device.py @@ -1,10 +1,6 @@ -import os -import pty -import termios -import threading import re -class MockJDS6600Device(): +class MockJDS6600Device: class ChannelState: def __init__(self): self.on = False @@ -13,27 +9,8 @@ class MockJDS6600Device(): self.function = None def __init__(self): - self._master, self._slave = pty.openpty() - self._masterFile = os.fdopen(self._master, mode="r+b", closefd=False, buffering=0) - - self._portName = os.ttyname(self._slave) self._channels = [MockJDS6600Device.ChannelState() for i in [1, 2]] - self._injectedFailureCounter = 0 - - self._mainThread = threading.Thread(target=self._mainLoop) - self._mainThread.start() - - def _mainLoop(self) -> None: - while True: - try: - request = self._masterFile.readline().decode().strip() - response = self._handleRequest(request) - - if response is not None: - self._masterFile.write(response.encode()) - except OSError as e: - break def _handleRequest(self, request: str) -> str: pattern = r":(?P[wrab])(?P\d+)=(?P.*)\." @@ -91,25 +68,6 @@ class MockJDS6600Device(): # Unknown request format, no response return None - def stop(self) -> None: - self._masterFile.close() - os.close(self._master) - os.close(self._slave) - self._mainThread.join() - - def checkPortConfiguration(self) -> None: - iflag, oflag, cflag, lflag, ispeed, ospeed, cc = termios.tcgetattr(self._slave) - - # JDS6600 configuration taken from manual - assert ispeed == termios.B115200 - assert ospeed == termios.B115200 - assert (cflag & termios.CSIZE) == termios.CS8 - assert (cflag & termios.CSTOPB) == 0 - assert (cflag & (termios.PARENB | termios.PARODD)) == 0 - - def getPortName(self) -> str: - return self._portName - def isOn(self, ch: int) -> bool: return self._channels[ch - 1].on diff --git a/lab_control/test/virtual_serial_port.py b/lab_control/test/virtual_serial_port.py new file mode 100644 index 0000000..f46e29c --- /dev/null +++ b/lab_control/test/virtual_serial_port.py @@ -0,0 +1,34 @@ +import os +import pty +import termios +import threading + +class VirtualSerialPort: + def __init__(self, requestHandler): + self._master, self._slave = pty.openpty() + self._masterFile = os.fdopen(self._master, mode="r+b", closefd=False, buffering=0) + self._portName = os.ttyname(self._slave) + self._requestHandler = requestHandler + + self._mainThread = threading.Thread(target=self._mainLoop) + self._mainThread.start() + + def stop(self) -> None: + self._masterFile.close() + os.close(self._master) + os.close(self._slave) + self._mainThread.join() + + def _mainLoop(self) -> None: + while True: + try: + request = self._masterFile.readline().decode().strip() + response = self._requestHandler(request) + + if response is not None: + self._masterFile.write(response.encode()) + except OSError as e: + break + + def getPortName(self) -> str: + return self._portName -- cgit v1.2.3