# Source code for pyads.testserver.testserver

"""Extended ADS TCP/IP server implementation.

Extended ADS TCP/IP server implementation to allow for functional testing of
the ADS protocol without connection to a physical device.

Consists of a server thread which will listen for connections and delegate
each new connection to a separate client thread, allowing for multiple clients
to connect at once.

Each client connection thread listens for incoming data, and delegates parsing
and response construction to the handler. A handler function is injectable at
server level by specifying the `handler` kwarg in the server constructor.

:author: David Browne <davidabrowne@gmail.com>
:license: MIT, see license file or https://opensource.org/licenses/MIT
:created on: 2018-06-11 18:15:53

"""
from __future__ import absolute_import
from typing import Any, List, Type, Optional
from types import TracebackType
import atexit
import select
import socket
import struct
import threading

from .handler import AbstractHandler, logger, null_logger, AmsPacket, AmsResponseData, AmsHeader, AmsTcpHeader
from .basic_handler import BasicHandler
from .advanced_handler import AdvancedHandler

# Default TCP port the test server listens on (0xBF02 == 48898).
ADS_PORT = 0xBF02


class AdsTestServer(threading.Thread):
    """Simple ADS testing server.

    The server runs as a daemon thread. It accepts TCP connections and hands
    each one to its own :class:`AdsClientConnection` thread. It can be used
    as a context manager: entering starts the thread, leaving closes it.

    :ivar handler: Request handler used for every client connection
    :ivar str ip_address: Host address for server. Defaults to '127.0.0.1'
    :ivar int port: Host port to listen on, defaults to 48898
    :ivar list request_history: Every AMS packet received by any client
        connection of this server, in order of arrival
    :ivar list clients: Client connection threads spawned by this server

    """

    def __init__(
        self,
        handler: Optional["AbstractHandler"] = None,
        ip_address: str = "127.0.0.1",
        port: int = ADS_PORT,
        logging: bool = True,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create the listening socket and bind it to the given address.

        :param handler: Request handler, a `BasicHandler` is used if omitted
        :param str ip_address: Host address to bind to
        :param int port: Host port to bind to
        :param bool logging: Pass False to silence the module logger
        :param args: Forwarded to `threading.Thread`
        :param kwargs: Forwarded to `threading.Thread`
        :raises OSError: If the address cannot be bound

        """
        super().__init__(*args, **kwargs)
        # Daemonize the server thread so it never blocks interpreter exit
        self.daemon = True

        self.handler = handler if handler is not None else BasicHandler()
        self.ip_address = ip_address
        self.port = port

        # Server loop execution flag; only ever cleared by `stop`
        self._run = True

        # NOTE: this switch is module wide. Once a server was created with
        # `logging=False` the module logger stays the null logger, even for
        # servers created later with `logging=True`.
        global logger
        logger = logger if logging else null_logger

        # Keep track of all received AMS packets
        self.request_history: List[AmsPacket] = []

        # Container for client connection threads
        self.clients: List[AdsClientConnection] = []

        # Initialize TCP/IP socket server
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Set option to allow instant socket reuse after shutdown
            self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.server.bind((self.ip_address, self.port))
        except OSError:
            # Do not leak the descriptor when the address is unavailable
            self.server.close()
            raise

        # Make sure we clean up on exit
        atexit.register(self.close)

    def __enter__(self) -> "AdsTestServer":
        """Enter context: start the server thread."""
        self.start()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Exit context: close the server."""
        self.close()

    def stop(self) -> None:
        """Close client connections and stop main server loop."""
        was_running = self._run
        # Stop server loop execution first so no new clients are spawned
        # while the existing ones are being closed
        self._run = False

        # Detach the list before iterating: the server thread may still be
        # appending to `self.clients`
        clients, self.clients = self.clients, []
        for client in clients:
            client.close()

        if was_running:
            logger.info("Stopping server thread.")

        self.server.close()

    def close(self) -> None:
        """Close the server thread."""
        self.stop()

    def run(self) -> None:
        """Listen for incoming connections from clients."""
        # `_run` is deliberately not re-armed here: doing so would undo a
        # `stop` that was issued before this thread got scheduled
        if not self._run:
            return

        # Start server listening
        try:
            self.server.listen(5)
        except OSError:
            if self._run:
                raise
            # Socket was closed by a concurrent `stop`
            return

        logger.info(
            "Server listening on {0}:{1}".format(
                self.ip_address or "localhost", self.port
            )
        )

        # Server loop
        while self._run:
            # Check for new connections at server socket
            try:
                ready, _, _ = select.select([self.server], [], [], 0.1)
            except (OSError, ValueError):
                # `stop` closed the listening socket while we were polling
                break

            if not ready:
                continue

            # Accept connection from client
            try:
                client, address = self.server.accept()
            except OSError:
                continue

            if not self._run:
                # Server was stopped while accepting; drop the connection
                client.close()
                break

            logger.info("New connection from {0}:{1}".format(*address))

            # Delegate handling of connection to client thread
            client_thread = AdsClientConnection(
                handler=self.handler, client=client, address=address, server=self
            )
            client_thread.daemon = True
            client_thread.start()
            self.clients.append(client_thread)
class AdsClientConnection(threading.Thread):
    """Connection thread to an ADS client."""

    # 6 bytes AMS/TCP header + 32 bytes AMS header, see `construct_request`
    _MIN_PACKET_LENGTH = 38

    def __init__(
        self,
        handler: "AbstractHandler",
        client: socket.socket,
        address: tuple,
        server: "AdsTestServer",
        *args: Any,
        **kwargs: Any
    ) -> None:
        """Store the connection details; the thread is not started here.

        :param handler: Handler whose `handle_request` builds the responses
        :param socket.socket client: Connected client socket
        :param tuple address: Client address as (host, port)
        :param server: Owning server, receives the request history
        :param args: Forwarded to `threading.Thread`
        :param kwargs: Forwarded to `threading.Thread`

        """
        super().__init__(*args, **kwargs)

        self.handler = handler
        self.server = server
        self.client = client
        self.client_address = address

        # Listening loop execution flag; only ever cleared, never re-armed
        self._run = True

        # Make sure we clean up on exit
        atexit.register(self.close)

    def stop(self) -> None:
        """Stop the client thread and wait for it to finish."""
        if self._run:
            logger.info(
                "Closing client connection {0}:{1}.".format(
                    *self.client_address)
            )
            self._run = False

        # Joining an unstarted thread or the current thread raises
        # RuntimeError, so only wait when there is something to wait for
        if self.is_alive() and threading.current_thread() is not self:
            self.join()

    def close(self) -> None:
        """Close the client connection."""
        if self.is_alive():
            self.stop()
        self.client.close()

    def run(self) -> None:
        """Listen for data on client connection and delegate requests."""
        # Main listening loop
        while self._run:
            try:
                ready, _, _ = select.select([self.client], [], [], 0.1)
                if not ready:
                    continue
                data = self.client.recv(4096)
            except (OSError, ValueError):
                # Connection reset or socket already closed: treat it the
                # same way as an orderly shutdown by the peer
                data = b""

            if not data:
                self.client.close()
                self._run = False
                continue

            # Basic data validation
            if len(data) < self._MIN_PACKET_LENGTH:
                logger.warning(
                    "Malformed packet discarded from {0}:{1}:\n\t{data}".format(
                        *self.client_address, data=data
                    )
                )
                continue

            # Construct AmsPacket tuple containing request data
            request_packet = self.construct_request(data)
            self.server.request_history.append(request_packet)

            # Delegate request handling and get response data
            response = self.handler.handle_request(request_packet)

            if not isinstance(response, AmsResponseData):
                logger.error("Request handler failed to return a valid response.")
                continue

            # Convert request, response data (tuples) to a valid ADS
            # response (bytes) to return to the client
            response_bytes = self.construct_response(response, request_packet)

            try:
                # `send` may write only part of the buffer, `sendall` loops
                self.client.sendall(response_bytes)
            except OSError:
                # Peer went away while we were answering
                self.client.close()
                self._run = False

    @staticmethod
    def construct_response(
        response_data: "AmsResponseData", request: "AmsPacket"
    ) -> bytes:
        """Construct binary AMS response to return to the client.

        :param AmsResponseData response_data: Data to include in the response
        :param AmsPacket request: The originating request for the response
        :rtype: bytes
        :return: AMS/TCP header followed by the AMS header and the payload

        """
        request_header = request.ams_header
        payload = response_data.data

        ams_packet = b"".join(
            (
                # Response gets returned to the source, so flip source and
                # target
                request_header.source_net_id,
                request_header.source_port,
                request_header.target_net_id,
                request_header.target_port,
                # Command ID is the same as in the request
                request_header.command_id,
                # State flags as specified in response data
                response_data.state_flags,
                # Payload length as little endian unsigned 32 bit integer
                struct.pack("<I", len(payload)),
                # Error code as specified in response data
                response_data.error_code,
                # Invoke ID is the same as in the request
                request_header.invoke_id,
                payload,
            )
        )

        # AMS/TCP header: two zero bytes, then the length of what follows
        ams_tcp_header = b"\x00\x00" + struct.pack("<I", len(ams_packet))

        return ams_tcp_header + ams_packet

    @staticmethod
    def construct_request(request_bytes: bytes) -> "AmsPacket":
        """Unpack an AMS packet from binary data.

        :param bytes request_bytes: The raw request data
        :rtype AmsPacket:
        :return: AmsPacket with fields populated from the binary data

        """
        data = request_bytes  # Use a shorter name for brevity

        # Bytes 0-1 of the AMS/TCP header are skipped, bytes 2-5 are kept
        tcp_header = AmsTcpHeader(data[2:6])

        ams_header = AmsHeader(
            # Extract target/source net ID's and ports
            data[6:12],
            data[12:14],
            data[14:20],
            data[20:22],
            # Extract command ID, state flags, and data length
            data[22:24],
            data[24:26],
            data[26:30],
            # Extract error code, invoke ID, and data
            data[30:34],
            data[34:38],
            data[38:],
        )

        return AmsPacket(tcp_header, ams_header)