From 832ae23f00ae94584da0a1929c9feb63b34badbc Mon Sep 17 00:00:00 2001 From: Linn Mattsson Date: Tue, 11 Oct 2022 07:31:46 +0000 Subject: [PATCH 1/3] Make websocket's API more intuitive Functions connect() and accept() are using http functionality, like sending requests and headers. Let's create separate functions with more intuitive names for these calls. This allows subclasses to override these functions, as well as makes the code easier to understand at a glance. --- websockify/websocket.py | 42 ++++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/websockify/websocket.py b/websockify/websocket.py index f0db3ba..5a819f7 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -158,19 +158,19 @@ class WebSocket(object): if not path: path = "/" - self._queue_str("GET %s HTTP/1.1\r\n" % path) - self._queue_str("Host: %s\r\n" % uri.hostname) - self._queue_str("Upgrade: websocket\r\n") - self._queue_str("Connection: upgrade\r\n") - self._queue_str("Sec-WebSocket-Key: %s\r\n" % self._key) - self._queue_str("Sec-WebSocket-Version: 13\r\n") + self.send_request("GET", path) + self.send_header("Host", uri.hostname) + self.send_header("Upgrade", "websocket") + self.send_header("Connection", "upgrade") + self.send_header("Sec-WebSocket-Key", self._key) + self.send_header("Sec-WebSocket-Version", 13) if origin is not None: - self._queue_str("Origin: %s\r\n" % origin) + self.send_header("Origin", origin) if len(protocols) > 0: - self._queue_str("Sec-WebSocket-Protocol: %s\r\n" % ", ".join(protocols)) + self.send_header("Sec-WebSocket-Protocol", ", ".join(protocols)) - self._queue_str("\r\n") + self.end_headers() self._state = "send_headers" @@ -283,15 +283,15 @@ class WebSocket(object): if self.protocol not in protocols: raise Exception('Invalid protocol selected') - self._queue_str("HTTP/1.1 101 Switching Protocols\r\n") - self._queue_str("Upgrade: websocket\r\n") - self._queue_str("Connection: Upgrade\r\n") - self._queue_str("Sec-WebSocket-Accept: %s\r\n" % accept) + self.send_response(101, "Switching Protocols") + self.send_header("Upgrade", "websocket") + self.send_header("Connection", "Upgrade") + self.send_header("Sec-WebSocket-Accept", accept) if self.protocol: - self._queue_str("Sec-WebSocket-Protocol: %s\r\n" % self.protocol) + self.send_header("Sec-WebSocket-Protocol", self.protocol) - self._queue_str("\r\n") + self.end_headers() self._state = "flush" @@ -447,6 +447,18 @@ class WebSocket(object): return len(msg) + def send_response(self, code, message): + self._queue_str("HTTP/1.1 %d %s\r\n" % (code, message)) + + def send_header(self, keyword, value): + self._queue_str("%s: %s\r\n" % (keyword, value)) + + def end_headers(self): + self._queue_str("\r\n") + + def send_request(self, type, path): + self._queue_str("%s %s HTTP/1.1\r\n" % (type.upper(), path)) + def ping(self, data=b''): """Write a ping message to the WebSocket From 4695f967288a52a480e42c5267ed512f18ed5856 Mon Sep 17 00:00:00 2001 From: Linn Mattsson Date: Tue, 11 Oct 2022 07:31:48 +0000 Subject: [PATCH 2/3] Add new websocket class HttpWebSocket This class acts as a glue between websocket and http functionality by taking a 'request_handler' and using its functions for send_response(), send_header() and end_headers(). --- tests/test_websocketserver.py | 69 ++++++++++++++++++++++++++++++++++ websockify/websocketserver.py | 21 ++++++++++- websockify/websockifyserver.py | 4 +- 3 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 tests/test_websocketserver.py diff --git a/tests/test_websocketserver.py b/tests/test_websocketserver.py new file mode 100644 index 0000000..0e37e3d --- /dev/null +++ b/tests/test_websocketserver.py @@ -0,0 +1,69 @@ + +""" Unit tests for websocketserver """ +import unittest +from unittest.mock import patch, MagicMock + +from websockify.websocketserver import HttpWebSocket + + +class HttpWebSocketTest(unittest.TestCase): + @patch("websockify.websocketserver.WebSocket.__init__", autospec=True) + def test_constructor(self, websock): + # Given + req_obj = MagicMock() + + # When + sock = HttpWebSocket(req_obj) + + # Then + websock.assert_called_once_with(sock) + self.assertEqual(sock.request_handler, req_obj) + + @patch("websockify.websocketserver.WebSocket.__init__", MagicMock(autospec=True)) + def test_send_response(self): + # Given + req_obj = MagicMock() + sock = HttpWebSocket(req_obj) + + # When + sock.send_response(200, "message") + + # Then + req_obj.send_response.assert_called_once_with(200, "message") + + @patch("websockify.websocketserver.WebSocket.__init__", MagicMock(autospec=True)) + def test_send_response_default_message(self): + # Given + req_obj = MagicMock() + sock = HttpWebSocket(req_obj) + + # When + sock.send_response(200) + + # Then + req_obj.send_response.assert_called_once_with(200, None) + + @patch("websockify.websocketserver.WebSocket.__init__", MagicMock(autospec=True)) + def test_send_header(self): + # Given + req_obj = MagicMock() + sock = HttpWebSocket(req_obj) + + # When + sock.send_header("keyword", "value") + + # Then + req_obj.send_header.assert_called_once_with("keyword", "value") + + @patch("websockify.websocketserver.WebSocket.__init__", MagicMock(autospec=True)) + def test_end_headers(self): + # Given + req_obj = MagicMock() + sock = HttpWebSocket(req_obj) + + # When + sock.end_headers() + + # Then + req_obj.end_headers.assert_called_once_with() + diff --git a/websockify/websocketserver.py b/websockify/websocketserver.py index 9088fe9..287f520 100644 --- a/websockify/websocketserver.py +++ b/websockify/websocketserver.py @@ -12,6 +12,23 @@ from http.server import BaseHTTPRequestHandler, HTTPServer from websockify.websocket import WebSocket, WebSocketWantReadError, WebSocketWantWriteError +class HttpWebSocket(WebSocket): + """Class to glue websocket and http request functionality together""" + def __init__(self, request_handler): + super().__init__() + + self.request_handler = request_handler + + def send_response(self, code, message=None): + self.request_handler.send_response(code, message) + + def send_header(self, keyword, value): + self.request_handler.send_header(keyword, value) + + def end_headers(self): + self.request_handler.end_headers() + + class WebSocketRequestHandlerMixIn: """WebSocket request handler mix-in class @@ -25,7 +42,7 @@ class WebSocketRequestHandlerMixIn: use for the WebSocket connection. """ - SocketClass = WebSocket + SocketClass = HttpWebSocket def handle_one_request(self): """Extended request handler @@ -59,7 +76,7 @@ class WebSocketRequestHandlerMixIn: The WebSocket object will then replace the request object and handle_websocket() will be called. """ - websocket = self.SocketClass() + websocket = self.SocketClass(self) try: websocket.accept(self.request, self.headers) except Exception: diff --git a/websockify/websockifyserver.py b/websockify/websockifyserver.py index 0199e42..3e13fa1 100644 --- a/websockify/websockifyserver.py +++ b/websockify/websockifyserver.py @@ -29,10 +29,10 @@ if sys.platform == 'win32': # make sockets pickle-able/inheritable import multiprocessing.reduction -from websockify.websocket import WebSocket, WebSocketWantReadError, WebSocketWantWriteError +from websockify.websocket import WebSocketWantReadError, WebSocketWantWriteError from websockify.websocketserver import WebSocketRequestHandlerMixIn -class CompatibleWebSocket(WebSocket): +class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass): def select_subprotocol(self, protocols): # Handle old websockify clients that still specify a sub-protocol if 'binary' in protocols: From be7b868518801c376168cfb6c40bacdc8c49b3b2 Mon Sep 17 00:00:00 2001 From: Linn Mattsson Date: Tue, 11 Oct 2022 07:31:56 +0000 Subject: [PATCH 3/3] Remove logging from handle_upgrade() The logging should be handled directly in send_response() instead, which is the default of Python's built-in send_response(). Remove this manual logging to avoid logging the same call twice. --- websockify/websocketserver.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/websockify/websocketserver.py b/websockify/websocketserver.py index 287f520..4e62f2e 100644 --- a/websockify/websocketserver.py +++ b/websockify/websocketserver.py @@ -84,8 +84,6 @@ class WebSocketRequestHandlerMixIn: self.send_error(400, str(exc)) return - self.log_request(101) - self.request = websocket # Other requests cannot follow Websocket data