diff --git a/tests/echo.py b/tests/echo.py index 780891c..1356400 100755 --- a/tests/echo.py +++ b/tests/echo.py @@ -14,6 +14,7 @@ import os, sys, select, optparse, logging sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler + class WebSocketEcho(WebSockifyRequestHandler): """ WebSockets server that echos back whatever is received from the @@ -50,6 +51,7 @@ class WebSocketEcho(WebSockifyRequestHandler): if closed: break + if __name__ == '__main__': parser = optparse.OptionParser(usage="%prog [options] listen_port") parser.add_option("--verbose", "-v", action="store_true", @@ -73,4 +75,3 @@ if __name__ == '__main__': opts.web = "." server = WebSockifyServer(WebSocketEcho, **opts.__dict__) server.start_server() - diff --git a/tests/echo_client.py b/tests/echo_client.py index 4f238f6..6421ec3 100755 --- a/tests/echo_client.py +++ b/tests/echo_client.py @@ -22,6 +22,7 @@ print("Connecting to %s..." % URL) sock.connect(URL) print("Connected.") + def send(msg): while True: try: @@ -36,6 +37,7 @@ def send(msg): ins, outs, excepts = select.select([], [sock], []) if excepts: raise Exception("Socket exception") + def read(): while True: try: @@ -47,6 +49,7 @@ def read(): ins, outs, excepts = select.select([], [sock], []) if excepts: raise Exception("Socket exception") + counter = 1 while True: msg = "Message #%d" % counter diff --git a/tests/load.py b/tests/load.py index 548d47f..1ace2f8 100755 --- a/tests/load.py +++ b/tests/load.py @@ -10,6 +10,7 @@ import sys, os, select, random, time, optparse, logging sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler + class WebSocketLoadServer(WebSockifyServer): recv_cnt = 0 @@ -85,7 +86,6 @@ class WebSocketLoad(WebSockifyRequestHandler): return data - def check(self, frames): err = "" @@ -165,4 +165,3 @@ if __name__ == '__main__': opts.web = "." server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__) server.start_server() - diff --git a/tests/test_token_plugins.py b/tests/test_token_plugins.py index e1b967b..e4ec7dd 100644 --- a/tests/test_token_plugins.py +++ b/tests/test_token_plugins.py @@ -9,6 +9,7 @@ from jwcrypto import jwt, jwk from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis + class ParseSourceArgumentsTestCase(unittest.TestCase): def test_parameterized(self): params = [ @@ -31,6 +32,7 @@ class ParseSourceArgumentsTestCase(unittest.TestCase): for src, args in params: self.assertEqual(args, parse_source_args(src)) + class ReadOnlyTokenFileTestCase(unittest.TestCase): def test_empty(self): mock_source_file = MagicMock() @@ -201,6 +203,7 @@ class JWSTokenTestCase(unittest.TestCase): self.assertEqual(result[0], "remote_host") self.assertEqual(result[1], "remote_port") + class TokenRedisTestCase(unittest.TestCase): def setUp(self): try: diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 42d658d..7f89312 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -18,6 +18,7 @@ import unittest from websockify import websocket + class FakeSocket: def __init__(self): self.data = b'' @@ -26,6 +27,7 @@ class FakeSocket: self.data += buf return len(buf) + class AcceptTestCase(unittest.TestCase): def test_success(self): ws = websocket.WebSocket() @@ -116,6 +118,7 @@ class AcceptTestCase(unittest.TestCase): 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', 'Sec-WebSocket-Protocol': 'foobar,gazonk'}) + class PingPongTest(unittest.TestCase): def setUp(self): self.ws = websocket.WebSocket() @@ -142,6 +145,7 @@ class PingPongTest(unittest.TestCase): self.ws.pong(b'foo') self.assertEqual(self.sock.data, b'\x8a\x03foo') + class HyBiEncodeDecodeTestCase(unittest.TestCase): def test_decode_hybi_text(self): buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58' diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py index 27c9317..2da8b08 100644 --- a/tests/test_websocketproxy.py +++ b/tests/test_websocketproxy.py @@ -56,6 +56,7 @@ class FakeServer: self.ssl_target = None self.unix_target = None + class ProxyRequestHandlerTestCase(unittest.TestCase): def setUp(self): super().setUp() @@ -126,4 +127,3 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): self.handler.server.target_host = "someotherhost" self.handler.auth_connection() - diff --git a/tests/test_websocketserver.py b/tests/test_websocketserver.py index 0e37e3d..8822572 100644 --- a/tests/test_websocketserver.py +++ b/tests/test_websocketserver.py @@ -66,4 +66,3 @@ class HttpWebSocketTest(unittest.TestCase): # Then req_obj.end_headers.assert_called_once_with() - diff --git a/tests/test_websockifyserver.py b/tests/test_websockifyserver.py index c267330..91ca0e7 100644 --- a/tests/test_websockifyserver.py +++ b/tests/test_websockifyserver.py @@ -231,6 +231,7 @@ class WebSockifyServerTestCase(unittest.TestCase): def test_do_handshake_no_ssl(self): class FakeHandler: CALLED = False + def __init__(self, *args, **kwargs): type(self).CALLED = True @@ -286,12 +287,16 @@ class WebSockifyServerTestCase(unittest.TestCase): def __init__(self, purpose): self.verify_mode = None self.options = 0 + def load_cert_chain(self, certfile, keyfile, password): pass + def set_default_verify_paths(self): pass + def load_verify_locations(self, cafile): pass + def wrap_socket(self, *args, **kwargs): raise ssl.SSLError(ssl.SSL_ERROR_EOF) @@ -317,17 +322,23 @@ class WebSockifyServerTestCase(unittest.TestCase): class fake_create_default_context(): CIPHERS = '' + def __init__(self, purpose): self.verify_mode = None self.options = 0 + def load_cert_chain(self, certfile, keyfile, password): pass + def set_default_verify_paths(self): pass + def load_verify_locations(self, cafile): pass + def wrap_socket(self, *args, **kwargs): pass + def set_ciphers(self, ciphers_to_set): fake_create_default_context.CIPHERS = ciphers_to_set @@ -352,19 +363,26 @@ class WebSockifyServerTestCase(unittest.TestCase): class fake_create_default_context: OPTIONS = 0 + def __init__(self, purpose): self.verify_mode = None self._options = 0 + def load_cert_chain(self, certfile, keyfile, password): pass + def set_default_verify_paths(self): pass + def load_verify_locations(self, cafile): pass + def wrap_socket(self, *args, **kwargs): pass + def get_options(self): return self._options + def set_options(self, val): fake_create_default_context.OPTIONS = val options = property(get_options, set_options) diff --git a/websockify/auth_plugins.py b/websockify/auth_plugins.py index 36fac52..dda188e 100644 --- a/websockify/auth_plugins.py +++ b/websockify/auth_plugins.py @@ -76,6 +76,7 @@ class BasicHTTPAuth(): raise AuthenticationError(response_code=401, response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'}) + class ExpectOrigin(): def __init__(self, src=None): if src is None: @@ -88,6 +89,7 @@ class ExpectOrigin(): if origin is None or origin not in self.source: raise InvalidOriginError(expected=self.source, actual=origin) + class ClientCertCNAuth(): """Verifies client by SSL certificate. Specify src as whitespace separated list of common names.""" diff --git a/websockify/sysloghandler.py b/websockify/sysloghandler.py index 2b5f32f..cfd995f 100644 --- a/websockify/sysloghandler.py +++ b/websockify/sysloghandler.py @@ -16,11 +16,8 @@ class WebsockifySysLogHandler(handlers.SysLogHandler): _max_ident = 24 #safer for old daemons _send_length = False _tail = '\n' - - ident = None - def __init__(self, address=('localhost', handlers.SYSLOG_UDP_PORT), facility=handlers.SysLogHandler.LOG_USER, socktype=None, ident=None, legacy=False): @@ -46,7 +43,6 @@ class WebsockifySysLogHandler(handlers.SysLogHandler): super().__init__(address, facility, socktype) - def emit(self, record): """ Emit a record. diff --git a/websockify/token_plugins.py b/websockify/token_plugins.py index d582032..bc27960 100644 --- a/websockify/token_plugins.py +++ b/websockify/token_plugins.py @@ -84,6 +84,7 @@ class TokenFile(ReadOnlyTokenFile): return super().lookup(token) + class TokenFileName(BasePlugin): # source is a directory # token is filename diff --git a/websockify/websocket.py b/websockify/websocket.py index 3bfb594..1dbda7d 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -31,11 +31,15 @@ except ImportError: warnings.warn("no 'numpy' module, HyBi protocol will be slower") numpy = None + class WebSocketWantReadError(ssl.SSLWantReadError): pass + + class WebSocketWantWriteError(ssl.SSLWantWriteError): pass + class WebSocket: """WebSocket protocol socket like class. @@ -873,4 +877,3 @@ class WebSocket: f['payload'] = buf[hlen:(hlen+length)] return f - diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index 70da60d..de75a34 100644 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -20,6 +20,7 @@ from websockify import websockifyserver from websockify import auth_plugins as auth from urllib.parse import parse_qs, urlparse + class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler): buffer_size = 65536 @@ -248,7 +249,6 @@ Traffic Legend: self.server.target_host, self.server.target_port) raise self.CClose(closed['code'], closed['reason']) - if target in outs: # Send queued client data to the target dat = tqueue.pop(0) @@ -260,7 +260,6 @@ Traffic Legend: tqueue.insert(0, dat[sent:]) self.print_traffic(".>") - if target in ins: # Receive target data, encode it and queue for client buf = target.recv(self.buffer_size) @@ -283,6 +282,7 @@ Traffic Legend: cqueue.append(buf) self.print_traffic("{") + class WebSocketProxy(websockifyserver.WebSockifyServer): """ Proxy traffic to and from a WebSockets client to a normal TCP @@ -427,6 +427,7 @@ SSL_OPTIONS = { ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2, } + def select_ssl_version(version): """Returns SSL options for the most secure TSL version available on this Python version""" @@ -444,6 +445,7 @@ def select_ssl_version(version): return SSL_OPTIONS[fallback] + def websockify_init(): # Setup basic logging to stderr. stderr_handler = logging.StreamHandler() @@ -562,9 +564,7 @@ def websockify_init(): (opts, args) = parser.parse_args() - # Validate options. - if opts.token_source and not opts.token_plugin: parser.error("You must use --token-plugin to use --token-source") @@ -583,11 +583,9 @@ def websockify_init(): if opts.legacy_syslog and not opts.syslog: parser.error("You must use --syslog to use --legacy-syslog") - opts.ssl_options = select_ssl_version(opts.ssl_version) del opts.ssl_version - if opts.log_file: # Setup logging to user-specified file. opts.log_file = os.path.abspath(opts.log_file) @@ -638,7 +636,6 @@ def websockify_init(): root = logging.getLogger() root.setLevel(logging.DEBUG) - # Transform to absolute path as daemon may chdir if opts.target_cfg: opts.target_cfg = os.path.abspath(opts.target_cfg) @@ -795,7 +792,6 @@ class LibProxyServer(ThreadingMixIn, HTTPServer): super().__init__((listen_host, listen_port), RequestHandlerClass) - def process_request(self, request, client_address): """Override process_request to implement a counter""" self.handler_id += 1 diff --git a/websockify/websocketserver.py b/websockify/websocketserver.py index c5bab61..146aa81 100644 --- a/websockify/websocketserver.py +++ b/websockify/websocketserver.py @@ -12,6 +12,7 @@ from http.server import BaseHTTPRequestHandler, HTTPServer from websockify.websocket import WebSocket + class HttpWebSocket(WebSocket): """Class to glue websocket and http request functionality together""" def __init__(self, request_handler): @@ -100,11 +101,14 @@ class WebSocketRequestHandlerMixIn: """ pass + # Convenient ready made classes + class WebSocketRequestHandler(WebSocketRequestHandlerMixIn, BaseHTTPRequestHandler): pass + class WebSocketServer(HTTPServer): pass diff --git a/websockify/websockifyserver.py b/websockify/websockifyserver.py index 95cb451..b656cc1 100644 --- a/websockify/websockifyserver.py +++ b/websockify/websockifyserver.py @@ -32,6 +32,7 @@ if sys.platform == 'win32': from websockify.websocket import WebSocketWantReadError, WebSocketWantWriteError from websockify.websocketserver import WebSocketRequestHandlerMixIn + class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass): def select_subprotocol(self, protocols): # Handle old websockify clients that still specify a sub-protocol @@ -40,6 +41,7 @@ class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass): else: return '' + # HTTP handler with WebSocket upgrade support class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHandler): """ @@ -513,7 +515,7 @@ class WebSockifyServer(): @staticmethod def daemonize(keepfd=None, chdir='/'): - + if keepfd is None: keepfd = [] @@ -644,10 +646,10 @@ class WebSockifyServer(): """ Same as msg() but as warning. """ self.logger.log(logging.WARN, *args, **kwargs) - # # Events that can/should be overridden in sub-classes # + def started(self): """ Called after WebSockets startup """ self.vmsg("WebSockets server started") @@ -879,5 +881,3 @@ class WebSockifyServer(): # Restore signals for sig, func in original_signals.items(): signal.signal(sig, func) - -