summaryrefslogtreecommitdiffstats
path: root/lab_control/test
diff options
context:
space:
mode:
Diffstat (limited to 'lab_control/test')
-rw-r--r--lab_control/test/jds6600_unittest.py (renamed from lab_control/test/jds6600_test.py)25
-rw-r--r--lab_control/test/mock_jds6600_device.py44
-rw-r--r--lab_control/test/virtual_serial_port.py34
3 files changed, 49 insertions, 54 deletions
diff --git a/lab_control/test/jds6600_test.py b/lab_control/test/jds6600_unittest.py
index 59e2d33..c3b283b 100644
--- a/lab_control/test/jds6600_test.py
+++ b/lab_control/test/jds6600_unittest.py
@@ -2,18 +2,19 @@ import pytest
from lab_control.jds6600 import JDS6600
from lab_control.test.mock_jds6600_device import MockJDS6600Device
+from lab_control.connection.direct_connection import DirectConnection as MockConnection
@pytest.fixture
def mockDevice():
- d = MockJDS6600Device()
- yield d
- d.stop()
+ return MockJDS6600Device()
@pytest.fixture
-def uut(mockDevice):
- uut = JDS6600(mockDevice.getPortName())
- yield uut
- uut.closePort()
+def mockConnection(mockDevice):
+ return MockConnection(mockDevice._handleRequest)
+
+@pytest.fixture
+def uut(mockConnection):
+ return JDS6600("", mockConnection)
def checkNumericalParameter(testValues, writeValue, valueInMock):
for ch in JDS6600.AVAILABLE_CHANNELS:
@@ -29,12 +30,14 @@ def checkInvalidNumericalParameter(testValues, writeValue, valueInMock):
with pytest.raises(AssertionError):
writeValue(ch, value)
-def test_serialConfiguration(mockDevice):
+def test_serialPortConfiguration(mockConnection):
with pytest.raises(AssertionError):
- mockDevice.checkPortConfiguration()
+ mockConnection.checkConfiguration()
- uut = JDS6600(mockDevice.getPortName())
- mockDevice.checkPortConfiguration()
+ uut = JDS6600("", mockConnection)
+ mockConnection.checkConfiguration()
+
+ uut.closePort()
def test_channelOnAndOff(uut, mockDevice):
for ch in JDS6600.AVAILABLE_CHANNELS:
diff --git a/lab_control/test/mock_jds6600_device.py b/lab_control/test/mock_jds6600_device.py
index c027573..8b7b440 100644
--- a/lab_control/test/mock_jds6600_device.py
+++ b/lab_control/test/mock_jds6600_device.py
@@ -1,10 +1,6 @@
-import os
-import pty
-import termios
-import threading
import re
-class MockJDS6600Device():
+class MockJDS6600Device:
class ChannelState:
def __init__(self):
self.on = False
@@ -13,27 +9,8 @@ class MockJDS6600Device():
self.function = None
def __init__(self):
- self._master, self._slave = pty.openpty()
- self._masterFile = os.fdopen(self._master, mode="r+b", closefd=False, buffering=0)
-
- self._portName = os.ttyname(self._slave)
self._channels = [MockJDS6600Device.ChannelState() for i in [1, 2]]
-
self._injectedFailureCounter = 0
-
- self._mainThread = threading.Thread(target=self._mainLoop)
- self._mainThread.start()
-
- def _mainLoop(self) -> None:
- while True:
- try:
- request = self._masterFile.readline().decode().strip()
- response = self._handleRequest(request)
-
- if response is not None:
- self._masterFile.write(response.encode())
- except OSError as e:
- break
def _handleRequest(self, request: str) -> str:
pattern = r":(?P<opcode>[wrab])(?P<function>\d+)=(?P<args>.*)\."
@@ -91,25 +68,6 @@ class MockJDS6600Device():
# Unknown request format, no response
return None
- def stop(self) -> None:
- self._masterFile.close()
- os.close(self._master)
- os.close(self._slave)
- self._mainThread.join()
-
- def checkPortConfiguration(self) -> None:
- iflag, oflag, cflag, lflag, ispeed, ospeed, cc = termios.tcgetattr(self._slave)
-
- # JDS6600 configuration taken from manual
- assert ispeed == termios.B115200
- assert ospeed == termios.B115200
- assert (cflag & termios.CSIZE) == termios.CS8
- assert (cflag & termios.CSTOPB) == 0
- assert (cflag & (termios.PARENB | termios.PARODD)) == 0
-
- def getPortName(self) -> str:
- return self._portName
-
def isOn(self, ch: int) -> bool:
return self._channels[ch - 1].on
diff --git a/lab_control/test/virtual_serial_port.py b/lab_control/test/virtual_serial_port.py
new file mode 100644
index 0000000..f46e29c
--- /dev/null
+++ b/lab_control/test/virtual_serial_port.py
@@ -0,0 +1,34 @@
+import os
+import pty
+import termios
+import threading
+
+class VirtualSerialPort:
+ def __init__(self, requestHandler):
+ self._master, self._slave = pty.openpty()
+ self._masterFile = os.fdopen(self._master, mode="r+b", closefd=False, buffering=0)
+ self._portName = os.ttyname(self._slave)
+ self._requestHandler = requestHandler
+
+ self._mainThread = threading.Thread(target=self._mainLoop)
+ self._mainThread.start()
+
+ def stop(self) -> None:
+ self._masterFile.close()
+ os.close(self._master)
+ os.close(self._slave)
+ self._mainThread.join()
+
+ def _mainLoop(self) -> None:
+ while True:
+ try:
+ request = self._masterFile.readline().decode().strip()
+ response = self._requestHandler(request)
+
+ if response is not None:
+ self._masterFile.write(response.encode())
+ except OSError as e:
+ break
+
+ def getPortName(self) -> str:
+ return self._portName