From ed42c1a47bc96f6453bb50008481d3467e9254e6 Mon Sep 17 00:00:00 2001 From: Eddy Pedroni Date: Sun, 5 Jun 2022 16:45:32 +0200 Subject: Refactored SDS1000X-E unit tests --- lab_control/connection/direct_connection.py | 2 +- lab_control/connection/tcp_connection.py | 22 ++++++++++ lab_control/sds1000xe.py | 29 +++++++------ lab_control/test/mock_sds1000xe_device.py | 44 +------------------- lab_control/test/sds1000xe_test.py | 64 ----------------------------- lab_control/test/sds1000xe_unittest.py | 61 +++++++++++++++++++++++++++ lab_control/test/virtual_tcp_server.py | 44 ++++++++++++++++++++ 7 files changed, 144 insertions(+), 122 deletions(-) create mode 100644 lab_control/connection/tcp_connection.py delete mode 100644 lab_control/test/sds1000xe_test.py create mode 100644 lab_control/test/sds1000xe_unittest.py create mode 100644 lab_control/test/virtual_tcp_server.py diff --git a/lab_control/connection/direct_connection.py b/lab_control/connection/direct_connection.py index 56a5cf3..8df1e42 100644 --- a/lab_control/connection/direct_connection.py +++ b/lab_control/connection/direct_connection.py @@ -10,7 +10,7 @@ class DirectConnection: def close(self) -> None: self.open = False - def send(self, request: str) -> str: + def send(self, request: str, responseExpected=True) -> str: return self.requestHandler(request) def checkConfiguration(self) -> None: diff --git a/lab_control/connection/tcp_connection.py b/lab_control/connection/tcp_connection.py new file mode 100644 index 0000000..6af710c --- /dev/null +++ b/lab_control/connection/tcp_connection.py @@ -0,0 +1,22 @@ +import socket + +class TCPConnection: + def __init__(self, ip: str, port: int): + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket.connect((address, port)) + + def configure(self, config: dict) -> None: + self._socket.settimeout(parameters["timeout"]) + + def send(self, request, responseExpected=True): + self._socket.sendall(request.encode()) + + if responseExpected: + try: + response = self._socket.recv(4096).decode() + except TimeoutError: + response = None + return response + + def checkConfiguration(self) -> None: + pass diff --git a/lab_control/sds1000xe.py b/lab_control/sds1000xe.py index 61fa86d..caa6a9c 100644 --- a/lab_control/sds1000xe.py +++ b/lab_control/sds1000xe.py @@ -1,11 +1,10 @@ """ Implements partial support for Siglent SDS1000X-E series oscilloscopes. """ - -import socket import re from lab_control.oscilloscope import Oscilloscope +from lab_control.connection.tcp_connection import TCPConnection def _checkChannel(channel): assert channel in SDS1000XE.AVAILABLE_CHANNELS, "SDS1000X-E: Invalid channel {channel}" @@ -19,24 +18,25 @@ class SDS1000XE(Oscilloscope): TIMEOUT = 5.0 AVAILABLE_CHANNELS = range(1, 5) - def __init__(self, address): + def __init__(self, address, overrideConnection=None): super().__init__() - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._socket.connect((address, SDS1000XE.PORT)) - self._socket.settimeout(SDS1000XE.TIMEOUT) + if overrideConnection is not None: + self._connection = overrideConnection + else: + self._connection = TCPConnection(address) + + self._connection.configure({"timeout" : SDS1000XE.TIMEOUT}) def _measure(self, channel: int, code: str) -> float: _checkChannel(channel) - pattern = r"C(?P\d):PAVA .+,(?P[\d.E+-]+)\w+" - query = f"C{channel}:PAVA? {code}\r\n" - self._socket.sendall(query.encode()) + request = f"C{channel}:PAVA? {code}\r\n" + response = self._connection.send(request) - try: - # TODO add code to regex - response = self._socket.recv(4096).decode() + if response is not None: + pattern = r"C(?P\d):PAVA .+,(?P[\d.E+-]+)\w+" matches = re.search(pattern, response) measurement = float(matches.group("rawMeasurement")) - except TimeoutError: + else: measurement = None return measurement @@ -56,8 +56,7 @@ class SDS1000XE(Oscilloscope): def setVoltsPerDivision(self, channel: int, volts: float) -> None: _checkChannel(channel) query = f"C{channel}:VDIV {volts:.2E}V\r\n" - self._socket.sendall(query.encode()) - # no response expected + self._connection.send(query, responseExpected=False) def getDivisionsDisplayed(self) -> int: return 8 diff --git a/lab_control/test/mock_sds1000xe_device.py b/lab_control/test/mock_sds1000xe_device.py index 68c2471..04ec07a 100644 --- a/lab_control/test/mock_sds1000xe_device.py +++ b/lab_control/test/mock_sds1000xe_device.py @@ -1,50 +1,12 @@ -import socket -import threading -import atexit import re -IP = "0.0.0.0" -PORT = 5025 - -# Bind server socket when this module is included -_serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) -_serverSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) -_serverSocket.bind((IP, PORT)) -_serverSocket.listen(1) - -# Close it when the program exits -def _cleanUp(): - _serverSocket.close() -atexit.register(_cleanUp) - class MockSDS1000XEDevice: def __init__(self): - self._stopFlag = False - self._clientSocket = None - self._mainThread = threading.Thread(target=self._mainLoop) - self._mainThread.start() - # Mock internal values self._channels = [{"AMPL" : None, "VDIV" : None} for i in range(0, 4)] - def _mainLoop(self) -> None: - self._clientSocket, _ = _serverSocket.accept() - self._clientSocket.settimeout(0.1) - - try: - while not self._stopFlag: - try: - request = self._clientSocket.recv(4096).decode() - response = self._handleRequest(request.strip()) - if response is not None: - self._clientSocket.send(response.encode()) - except TimeoutError as e: - pass - finally: - self._clientSocket.close() - def _handleRequest(self, request: str) -> str: - m = re.search(r"C(?P\d):(?P\w+)\??\s(?P.+)", request) + m = re.search(r"C(?P\d):(?P\w+)\??\s(?P.+)", request.strip()) if not m: return None @@ -66,9 +28,7 @@ class MockSDS1000XEDevice: self._channels[channelIndex]["VDIV"] = arg return None - def stop(self) -> None: - self._stopFlag = True - self._mainThread.join() + return None def setAmplitude(self, channel: int, value: float) -> None: self._channels[channel - 1]["AMPL"] = value diff --git a/lab_control/test/sds1000xe_test.py b/lab_control/test/sds1000xe_test.py deleted file mode 100644 index 3774d52..0000000 --- a/lab_control/test/sds1000xe_test.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytest -import time - -from lab_control.sds1000xe import SDS1000XE -from lab_control.test.mock_sds1000xe_device import MockSDS1000XEDevice - -MOCK_DEVICE_IP = "127.0.0.1" - -@pytest.fixture -def mockDevice(): - d = MockSDS1000XEDevice() - yield d - d.stop() - -@pytest.fixture -def uut(mockDevice): - return SDS1000XE(MOCK_DEVICE_IP) - -def checkFloatMeasurement(testValues, setValue, measureValue): - for channel in SDS1000XE.AVAILABLE_CHANNELS: - for value in testValues: - setValue(channel, value) - measuredValue = measureValue(channel) - assert measuredValue == value - -def test_amplitudeMeasurement(uut, mockDevice): - testValues = [16.23987, 0.0, -0.0164, 10.1] - checkFloatMeasurement(testValues, mockDevice.setAmplitude, uut.measureAmplitude) - -def test_peakToPeakMeasurement(uut, mockDevice): - testValues = [16.23987, 0.0, -0.0164, 10.1] - checkFloatMeasurement(testValues, mockDevice.setPeakToPeak, uut.measurePeakToPeak) - -def test_RMSMeasurement(uut, mockDevice): - testValues = [16.23987, 0.0, -0.0164, 10.1] - checkFloatMeasurement(testValues, mockDevice.setRMS, uut.measureRMS) - -def test_frequencyMeasurement(uut, mockDevice): - testValues = [16.23987, 0.0, -0.0164, 93489.15] - checkFloatMeasurement(testValues, mockDevice.setFrequency, uut.measureFrequency) - -def test_invalidChannel(uut, mockDevice): - # Channel is checked by the UUT before the request is sent - testCases = [-1, 0, 5, None] - testMethods = [uut.measureAmplitude, uut.measurePeakToPeak, uut.measureRMS, uut.measureFrequency] - - for t in testCases: - for m in testMethods: - with pytest.raises(AssertionError): - m(t) - -def test_setVoltsPerDivision(uut, mockDevice): - testValues = [5e-3, 50e-3, 1e0, 5e0, 10e0, 100e0] - - for channel in SDS1000XE.AVAILABLE_CHANNELS: - assert mockDevice.getVoltsPerDivision(channel) is None - - for value in testValues: - uut.setVoltsPerDivision(channel, value) - - time.sleep(0.1) # Allow time for the mock to receive and process the request - - assert mockDevice.getVoltsPerDivision(channel) == value - diff --git a/lab_control/test/sds1000xe_unittest.py b/lab_control/test/sds1000xe_unittest.py new file mode 100644 index 0000000..ce87a6e --- /dev/null +++ b/lab_control/test/sds1000xe_unittest.py @@ -0,0 +1,61 @@ +import pytest + +from lab_control.sds1000xe import SDS1000XE +from lab_control.connection.direct_connection import DirectConnection as MockConnection +from lab_control.test.mock_sds1000xe_device import MockSDS1000XEDevice + +@pytest.fixture +def mockDevice(): + return MockSDS1000XEDevice() + +@pytest.fixture +def mockConnection(mockDevice): + return MockConnection(mockDevice._handleRequest) + +@pytest.fixture +def uut(mockConnection): + return SDS1000XE("", overrideConnection=mockConnection) + +def checkFloatMeasurement(testValues, setValue, measureValue): + for channel in SDS1000XE.AVAILABLE_CHANNELS: + for value in testValues: + setValue(channel, value) + measuredValue = measureValue(channel) + assert measuredValue == value + +def test_amplitudeMeasurement(uut, mockDevice): + testValues = [16.23987, 0.0, -0.0164, 10.1] + checkFloatMeasurement(testValues, mockDevice.setAmplitude, uut.measureAmplitude) + +def test_peakToPeakMeasurement(uut, mockDevice): + testValues = [16.23987, 0.0, -0.0164, 10.1] + checkFloatMeasurement(testValues, mockDevice.setPeakToPeak, uut.measurePeakToPeak) + +def test_RMSMeasurement(uut, mockDevice): + testValues = [16.23987, 0.0, -0.0164, 10.1] + checkFloatMeasurement(testValues, mockDevice.setRMS, uut.measureRMS) + +def test_frequencyMeasurement(uut, mockDevice): + testValues = [16.23987, 0.0, -0.0164, 93489.15] + checkFloatMeasurement(testValues, mockDevice.setFrequency, uut.measureFrequency) + +def test_invalidChannel(uut, mockDevice): + # Channel is checked by the UUT before the request is sent + testCases = [-1, 0, 5, None] + testMethods = [uut.measureAmplitude, uut.measurePeakToPeak, uut.measureRMS, uut.measureFrequency] + + for t in testCases: + for m in testMethods: + with pytest.raises(AssertionError): + m(t) + +def test_setVoltsPerDivision(uut, mockDevice): + testValues = [5e-3, 50e-3, 1e0, 5e0, 10e0, 100e0] + + for channel in SDS1000XE.AVAILABLE_CHANNELS: + assert mockDevice.getVoltsPerDivision(channel) is None + + for value in testValues: + uut.setVoltsPerDivision(channel, value) + assert mockDevice.getVoltsPerDivision(channel) == value + diff --git a/lab_control/test/virtual_tcp_server.py b/lab_control/test/virtual_tcp_server.py new file mode 100644 index 0000000..07a4345 --- /dev/null +++ b/lab_control/test/virtual_tcp_server.py @@ -0,0 +1,44 @@ +import socket +import threading +import atexit + +IP = "0.0.0.0" +PORT = 5025 + +# Bind server socket when this module is included +_serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +_serverSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +_serverSocket.bind((IP, PORT)) +_serverSocket.listen(1) + +# Close it when the program exits +def _cleanUp(): + _serverSocket.close() +atexit.register(_cleanUp) + +class VirtualTCPServer: + def __init__(self): + self._stopFlag = False + self._clientSocket = None + self._mainThread = threading.Thread(target=self._mainLoop) + self._mainThread.start() + + def _mainLoop(self) -> None: + self._clientSocket, _ = _serverSocket.accept() + self._clientSocket.settimeout(0.1) + + try: + while not self._stopFlag: + try: + request = self._clientSocket.recv(4096).decode() + response = self._handleRequest(request.strip()) + if response is not None: + self._clientSocket.send(response.encode()) + except TimeoutError as e: + pass + finally: + self._clientSocket.close() + + def stop(self) -> None: + self._stopFlag = True + self._mainThread.join() -- cgit v1.2.3