From 3eca8ad195299eab0cb40712f9b6872b6497c0ca Mon Sep 17 00:00:00 2001 From: Doctor Date: Wed, 12 Mar 2025 20:55:57 +0300 Subject: [PATCH] run `ruff format .` --- setup.py | 80 ++--- tests/echo.py | 45 ++- tests/echo_client.py | 31 +- tests/latency.py | 2 +- tests/load.py | 74 +++-- tests/test_auth_plugins.py | 25 +- tests/test_token_plugins.py | 283 +++++++++-------- tests/test_websocket.py | 267 ++++++++++------ tests/test_websocketproxy.py | 29 +- tests/test_websocketserver.py | 3 +- tests/test_websockifyserver.py | 300 +++++++++++------- websockify.py | 2 +- websockify/__main__.py | 2 +- websockify/auth_plugins.py | 42 ++- websockify/sysloghandler.py | 62 ++-- websockify/token_plugins.py | 101 +++--- websockify/websocket.py | 203 +++++++----- websockify/websocketproxy.py | 563 ++++++++++++++++++++------------- websockify/websocketserver.py | 26 +- websockify/websockifyserver.py | 375 ++++++++++++---------- 20 files changed, 1488 insertions(+), 1027 deletions(-) diff --git a/setup.py b/setup.py index 5f8118f..294fec2 100644 --- a/setup.py +++ b/setup.py @@ -1,44 +1,44 @@ from setuptools import setup, find_packages -version = '0.13.0' -name = 'websockify' -long_description = open("README.md").read() + "\n" + \ - open("CHANGES.txt").read() + "\n" +version = "0.13.0" +name = "websockify" +long_description = open("README.md").read() + "\n" + open("CHANGES.txt").read() + "\n" -setup(name=name, - version=version, - description="Websockify.", - long_description=long_description, - long_description_content_type="text/markdown", - classifiers=[ - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - ], - keywords='noVNC websockify', - license='LGPLv3', - url="https://github.com/novnc/websockify", - author="Joel Martin", - author_email="github@martintribe.org", - - packages=['websockify'], - include_package_data=True, - install_requires=[ - 'numpy', 'requests', - 'jwcrypto', - 'redis', - ], - zip_safe=False, - entry_points={ - 'console_scripts': [ - 'websockify = websockify.websocketproxy:websockify_init', +setup( + name=name, + version=version, + description="Websockify.", + long_description=long_description, + long_description_content_type="text/markdown", + classifiers=[ + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + ], + keywords="noVNC websockify", + license="LGPLv3", + url="https://github.com/novnc/websockify", + author="Joel Martin", + author_email="github@martintribe.org", + packages=["websockify"], + include_package_data=True, + install_requires=[ + "numpy", + "requests", + "jwcrypto", + "redis", + ], + zip_safe=False, + entry_points={ + "console_scripts": [ + "websockify = websockify.websocketproxy:websockify_init", ] - }, - ) + }, +) diff --git a/tests/echo.py b/tests/echo.py index 780891c..adcff5a 100755 --- a/tests/echo.py +++ b/tests/echo.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -''' +""" A WebSocket server that echos back whatever it receives from the client. Copyright 2010 Joel Martin Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) @@ -8,16 +8,19 @@ Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) You can make a cert/key with openssl using: openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem as taken from http://docs.python.org/dev/library/ssl.html#certificates -''' +""" import os, sys, select, optparse, logging -sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler + class WebSocketEcho(WebSockifyRequestHandler): """ WebSockets server that echos back whatever is received from the - client. """ + client.""" + buffer_size = 8096 def new_websocket_client(self): @@ -33,9 +36,11 @@ class WebSocketEcho(WebSockifyRequestHandler): while True: wlist = [] - if cqueue or c_pend: wlist.append(self.request) + if cqueue or c_pend: + wlist.append(self.request) ins, outs, excepts = select.select(rlist, wlist, [], 1) - if excepts: raise Exception("Socket exception") + if excepts: + raise Exception("Socket exception") if self.request in outs: # Send queued target data to the client @@ -50,20 +55,27 @@ class WebSocketEcho(WebSockifyRequestHandler): if closed: break -if __name__ == '__main__': + +if __name__ == "__main__": parser = optparse.OptionParser(usage="%prog [options] listen_port") - parser.add_option("--verbose", "-v", action="store_true", - help="verbose messages and per frame traffic") - parser.add_option("--cert", default="self.pem", - help="SSL certificate file") - parser.add_option("--key", default=None, - help="SSL key file (if separate from cert)") - parser.add_option("--ssl-only", action="store_true", - help="disallow non-encrypted connections") + parser.add_option( + "--verbose", + "-v", + action="store_true", + help="verbose messages and per frame traffic", + ) + parser.add_option("--cert", default="self.pem", help="SSL certificate file") + parser.add_option( + "--key", default=None, help="SSL key file (if separate from cert)" + ) + parser.add_option( + "--ssl-only", action="store_true", help="disallow non-encrypted connections" + ) (opts, args) = parser.parse_args() try: - if len(args) != 1: raise ValueError + if len(args) != 1: + raise ValueError opts.listen_port = int(args[0]) except ValueError: parser.error("Invalid arguments") @@ -73,4 +85,3 @@ if __name__ == '__main__': opts.web = "." server = WebSockifyServer(WebSocketEcho, **opts.__dict__) server.start_server() - diff --git a/tests/echo_client.py b/tests/echo_client.py index 4f238f6..f5fa962 100755 --- a/tests/echo_client.py +++ b/tests/echo_client.py @@ -5,9 +5,12 @@ import sys import optparse import select -sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) -from websockify.websocket import WebSocket, \ - WebSocketWantReadError, WebSocketWantWriteError +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +from websockify.websocket import ( + WebSocket, + WebSocketWantReadError, + WebSocketWantWriteError, +) parser = optparse.OptionParser(usage="%prog URL") (opts, args) = parser.parse_args() @@ -22,19 +25,23 @@ print("Connecting to %s..." % URL) sock.connect(URL) print("Connected.") + def send(msg): while True: try: sock.sendmsg(msg) break except WebSocketWantReadError: - msg = '' + msg = "" ins, outs, excepts = select.select([sock], [], []) - if excepts: raise Exception("Socket exception") + if excepts: + raise Exception("Socket exception") except WebSocketWantWriteError: - msg = '' + msg = "" ins, outs, excepts = select.select([], [sock], []) - if excepts: raise Exception("Socket exception") + if excepts: + raise Exception("Socket exception") + def read(): while True: @@ -42,10 +49,13 @@ def read(): return sock.recvmsg() except WebSocketWantReadError: ins, outs, excepts = select.select([sock], [], []) - if excepts: raise Exception("Socket exception") + if excepts: + raise Exception("Socket exception") except WebSocketWantWriteError: ins, outs, excepts = select.select([], [sock], []) - if excepts: raise Exception("Socket exception") + if excepts: + raise Exception("Socket exception") + counter = 1 while True: @@ -56,7 +66,8 @@ while True: while True: ins, outs, excepts = select.select([sock], [], [], 1.0) - if excepts: raise Exception("Socket exception") + if excepts: + raise Exception("Socket exception") if ins == []: break diff --git a/tests/latency.py b/tests/latency.py index 3ae4d96..9372c2a 120000 --- a/tests/latency.py +++ b/tests/latency.py @@ -1 +1 @@ -echo.py \ No newline at end of file +echo.py diff --git a/tests/load.py b/tests/load.py index 710b593..60dad78 100755 --- a/tests/load.py +++ b/tests/load.py @@ -1,32 +1,32 @@ #!/usr/bin/env python -''' +""" WebSocket server-side load test program. Sends and receives traffic that has a random payload (length and content) that is checksummed and given a sequence number. Any errors are reported and counted. -''' +""" import sys, os, select, random, time, optparse, logging -sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler -class WebSocketLoadServer(WebSockifyServer): +class WebSocketLoadServer(WebSockifyServer): recv_cnt = 0 send_cnt = 0 def __init__(self, *args, **kwargs): - self.delay = kwargs.pop('delay') + self.delay = kwargs.pop("delay") WebSockifyServer.__init__(self, *args, **kwargs) class WebSocketLoad(WebSockifyRequestHandler): - max_packet_size = 10000 def new_websocket_client(self): - print "Prepopulating random array" + print("Prepopulating random array") self.rand_array = [] for i in range(0, self.max_packet_size): self.rand_array.append(random.randint(0, 9)) @@ -37,7 +37,7 @@ class WebSocketLoad(WebSockifyRequestHandler): self.responder(self.request) - print "accumulated errors:", self.errors + print("accumulated errors:", self.errors) self.errors = 0 def responder(self, client): @@ -49,7 +49,8 @@ class WebSocketLoad(WebSockifyRequestHandler): while True: ins, outs, excepts = select.select(socks, socks, socks, 1) - if excepts: raise Exception("Socket exception") + if excepts: + raise Exception("Socket exception") if client in ins: frames, closed = self.recv_frames() @@ -57,7 +58,7 @@ class WebSocketLoad(WebSockifyRequestHandler): err = self.check(frames) if err: self.errors = self.errors + 1 - print err + print(err) if closed: break @@ -73,24 +74,22 @@ class WebSocketLoad(WebSockifyRequestHandler): def generate(self): length = random.randint(10, self.max_packet_size) - numlist = self.rand_array[self.max_packet_size-length:] + numlist = self.rand_array[self.max_packet_size - length :] # Error in length - #numlist.append(5) + # numlist.append(5) chksum = sum(numlist) # Error in checksum - #numlist[0] = 5 - nums = "".join( [str(n) for n in numlist] ) + # numlist[0] = 5 + nums = "".join([str(n) for n in numlist]) data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums) self.send_cnt += 1 return data - def check(self, frames): - err = "" for data in frames: - if data.count('$') > 1: + if data.count("$") > 1: raise Exception("Multiple parts within single packet") if len(data) == 0: self.traffic("_") @@ -101,12 +100,12 @@ class WebSocketLoad(WebSockifyRequestHandler): continue try: - cnt, length, chksum, nums = data[1:-1].split(':') - cnt = int(cnt) + cnt, length, chksum, nums = data[1:-1].split(":") + cnt = int(cnt) length = int(length) chksum = int(chksum) except ValueError: - print "\n" + repr(data) + "" + print("\n" + repr(data) + "") err += "Invalid data format\n" continue @@ -131,27 +130,37 @@ class WebSocketLoad(WebSockifyRequestHandler): real_chksum += int(num) if real_chksum != chksum: - err += "Expected checksum %d but real chksum is %d\n" % (chksum, real_chksum) + err += "Expected checksum %d but real chksum is %d\n" % ( + chksum, + real_chksum, + ) return err -if __name__ == '__main__': +if __name__ == "__main__": parser = optparse.OptionParser(usage="%prog [options] listen_port") - parser.add_option("--verbose", "-v", action="store_true", - help="verbose messages and per frame traffic") - parser.add_option("--cert", default="self.pem", - help="SSL certificate file") - parser.add_option("--key", default=None, - help="SSL key file (if separate from cert)") - parser.add_option("--ssl-only", action="store_true", - help="disallow non-encrypted connections") + parser.add_option( + "--verbose", + "-v", + action="store_true", + help="verbose messages and per frame traffic", + ) + parser.add_option("--cert", default="self.pem", help="SSL certificate file") + parser.add_option( + "--key", default=None, help="SSL key file (if separate from cert)" + ) + parser.add_option( + "--ssl-only", action="store_true", help="disallow non-encrypted connections" + ) (opts, args) = parser.parse_args() try: - if len(args) != 1: raise ValueError + if len(args) != 1: + raise ValueError opts.listen_port = int(args[0]) - if len(args) not in [1,2]: raise ValueError + if len(args) not in [1, 2]: + raise ValueError opts.listen_port = int(args[0]) if len(args) == 2: opts.delay = int(args[1]) @@ -165,4 +174,3 @@ if __name__ == '__main__': opts.web = "." server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__) server.start_server() - diff --git a/tests/test_auth_plugins.py b/tests/test_auth_plugins.py index 4b3bfb5..e3d348a 100644 --- a/tests/test_auth_plugins.py +++ b/tests/test_auth_plugins.py @@ -1,28 +1,33 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 -""" Unit tests for Authentication plugins""" +"""Unit tests for Authentication plugins""" from websockify.auth_plugins import BasicHTTPAuth, AuthenticationError import unittest class BasicHTTPAuthTestCase(unittest.TestCase): - def setUp(self): - self.plugin = BasicHTTPAuth('Aladdin:open sesame') + self.plugin = BasicHTTPAuth("Aladdin:open sesame") def test_no_auth(self): headers = {} - self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') + self.assertRaises( + AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234" + ) def test_invalid_password(self): - headers = {'Authorization': 'Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0'} - self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') + headers = {"Authorization": "Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0"} + self.assertRaises( + AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234" + ) def test_valid_password(self): - headers = {'Authorization': 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='} - self.plugin.authenticate(headers, 'localhost', '1234') + headers = {"Authorization": "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="} + self.plugin.authenticate(headers, "localhost", "1234") def test_garbage_auth(self): - headers = {'Authorization': 'Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx'} - self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') + headers = {"Authorization": "Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx"} + self.assertRaises( + AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234" + ) diff --git a/tests/test_token_plugins.py b/tests/test_token_plugins.py index a9fd256..fae783b 100644 --- a/tests/test_token_plugins.py +++ b/tests/test_token_plugins.py @@ -1,80 +1,93 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 -""" Unit tests for Token plugins""" +"""Unit tests for Token plugins""" import sys import unittest from unittest.mock import patch, mock_open, MagicMock from jwcrypto import jwt, jwk -from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis +from websockify.token_plugins import ( + parse_source_args, + ReadOnlyTokenFile, + JWTTokenApi, + TokenRedis, +) + class ParseSourceArgumentsTestCase(unittest.TestCase): def test_parameterized(self): params = [ - ('', ['']), - (':', ['', '']), - ('::', ['', '', '']), + ("", [""]), + (":", ["", ""]), + ("::", ["", "", ""]), ('"', ['"']), ('""', ['""']), ('"""', ['"""']), - ('"localhost"', ['localhost']), - ('"localhost":', ['localhost', '']), - ('"localhost"::', ['localhost', '', '']), - ('"local:host"', ['local:host']), - ('"local:host:"pass"', ['"local', 'host', "pass"]), - ('"local":"host"', ['local', 'host']), - ('"local":host"', ['local', 'host"']), - ('localhost:6379:1:pass"word:"my-app-namespace:dev"', - ['localhost', '6379', '1', 'pass"word', 'my-app-namespace:dev']), + ('"localhost"', ["localhost"]), + ('"localhost":', ["localhost", ""]), + ('"localhost"::', ["localhost", "", ""]), + ('"local:host"', ["local:host"]), + ('"local:host:"pass"', ['"local', "host", "pass"]), + ('"local":"host"', ["local", "host"]), + ('"local":host"', ["local", 'host"']), + ( + 'localhost:6379:1:pass"word:"my-app-namespace:dev"', + ["localhost", "6379", "1", 'pass"word', "my-app-namespace:dev"], + ), ] for src, args in params: self.assertEqual(args, parse_source_args(src)) + class ReadOnlyTokenFileTestCase(unittest.TestCase): - patch('os.path.isdir', MagicMock(return_value=False)) + patch("os.path.isdir", MagicMock(return_value=False)) + def test_empty(self): - plugin = ReadOnlyTokenFile('configfile') + plugin = ReadOnlyTokenFile("configfile") config = "" pyopen = mock_open(read_data=config) with patch("websockify.token_plugins.open", pyopen, create=True): - result = plugin.lookup('testhost') + result = plugin.lookup("testhost") - pyopen.assert_called_once_with('configfile') + pyopen.assert_called_once_with("configfile") self.assertIsNone(result) - patch('os.path.isdir', MagicMock(return_value=False)) + patch("os.path.isdir", MagicMock(return_value=False)) + def test_simple(self): - plugin = ReadOnlyTokenFile('configfile') + plugin = ReadOnlyTokenFile("configfile") config = "testhost: remote_host:remote_port" pyopen = mock_open(read_data=config) with patch("websockify.token_plugins.open", pyopen, create=True): - result = plugin.lookup('testhost') + result = plugin.lookup("testhost") - pyopen.assert_called_once_with('configfile') + pyopen.assert_called_once_with("configfile") self.assertIsNotNone(result) self.assertEqual(result[0], "remote_host") self.assertEqual(result[1], "remote_port") - patch('os.path.isdir', MagicMock(return_value=False)) + patch("os.path.isdir", MagicMock(return_value=False)) + def test_tabs(self): - plugin = ReadOnlyTokenFile('configfile') + plugin = ReadOnlyTokenFile("configfile") config = "testhost:\tremote_host:remote_port" pyopen = mock_open(read_data=config) with patch("websockify.token_plugins.open", pyopen, create=True): - result = plugin.lookup('testhost') + result = plugin.lookup("testhost") - pyopen.assert_called_once_with('configfile') + pyopen.assert_called_once_with("configfile") self.assertIsNotNone(result) self.assertEqual(result[0], "remote_host") self.assertEqual(result[1], "remote_port") + class JWSTokenTestCase(unittest.TestCase): def test_asymmetric_jws_token_plugin(self): plugin = JWTTokenApi("./tests/fixtures/public.pem") @@ -82,7 +95,9 @@ class JWSTokenTestCase(unittest.TestCase): key = jwk.JWK() private_key = open("./tests/fixtures/private.pem", "rb").read() key.import_from_pem(private_key) - jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"}) + jwt_token = jwt.JWT( + {"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"} + ) jwt_token.make_signed_token(key) result = plugin.lookup(jwt_token.serialize()) @@ -97,21 +112,26 @@ class JWSTokenTestCase(unittest.TestCase): key = jwk.JWK() private_key = open("./tests/fixtures/private.pem", "rb").read() key.import_from_pem(private_key) - jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"}) + jwt_token = jwt.JWT( + {"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"} + ) jwt_token.make_signed_token(key) result = plugin.lookup(jwt_token.serialize()) self.assertIsNone(result) - @patch('time.time') + @patch("time.time") def test_jwt_valid_time(self, mock_time): plugin = JWTTokenApi("./tests/fixtures/public.pem") key = jwk.JWK() private_key = open("./tests/fixtures/private.pem", "rb").read() key.import_from_pem(private_key) - jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 }) + jwt_token = jwt.JWT( + {"alg": "RS256"}, + {"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200}, + ) jwt_token.make_signed_token(key) mock_time.return_value = 150 @@ -121,14 +141,17 @@ class JWSTokenTestCase(unittest.TestCase): self.assertEqual(result[0], "remote_host") self.assertEqual(result[1], "remote_port") - @patch('time.time') + @patch("time.time") def test_jwt_early_time(self, mock_time): plugin = JWTTokenApi("./tests/fixtures/public.pem") key = jwk.JWK() private_key = open("./tests/fixtures/private.pem", "rb").read() key.import_from_pem(private_key) - jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 }) + jwt_token = jwt.JWT( + {"alg": "RS256"}, + {"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200}, + ) jwt_token.make_signed_token(key) mock_time.return_value = 50 @@ -136,14 +159,17 @@ class JWSTokenTestCase(unittest.TestCase): self.assertIsNone(result) - @patch('time.time') + @patch("time.time") def test_jwt_late_time(self, mock_time): plugin = JWTTokenApi("./tests/fixtures/public.pem") key = jwk.JWK() private_key = open("./tests/fixtures/private.pem", "rb").read() key.import_from_pem(private_key) - jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 }) + jwt_token = jwt.JWT( + {"alg": "RS256"}, + {"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200}, + ) jwt_token.make_signed_token(key) mock_time.return_value = 250 @@ -156,8 +182,10 @@ class JWSTokenTestCase(unittest.TestCase): secret = open("./tests/fixtures/symmetric.key").read() key = jwk.JWK() - key.import_key(kty="oct",k=secret) - jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"}) + key.import_key(kty="oct", k=secret) + jwt_token = jwt.JWT( + {"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"} + ) jwt_token.make_signed_token(key) result = plugin.lookup(jwt_token.serialize()) @@ -171,8 +199,10 @@ class JWSTokenTestCase(unittest.TestCase): secret = open("./tests/fixtures/symmetric.key").read() key = jwk.JWK() - key.import_key(kty="oct",k=secret) - jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"}) + key.import_key(kty="oct", k=secret) + jwt_token = jwt.JWT( + {"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"} + ) jwt_token.make_signed_token(key) result = plugin.lookup(jwt_token.serialize()) @@ -188,10 +218,14 @@ class JWSTokenTestCase(unittest.TestCase): public_key_data = open("./tests/fixtures/public.pem", "rb").read() private_key.import_from_pem(private_key_data) public_key.import_from_pem(public_key_data) - jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"}) + jwt_token = jwt.JWT( + {"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"} + ) jwt_token.make_signed_token(private_key) - jwe_token = jwt.JWT(header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"}, - claims=jwt_token.serialize()) + jwe_token = jwt.JWT( + header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"}, + claims=jwt_token.serialize(), + ) jwe_token.make_encrypted_token(public_key) result = plugin.lookup(jwt_token.serialize()) @@ -200,103 +234,104 @@ class JWSTokenTestCase(unittest.TestCase): self.assertEqual(result[0], "remote_host") self.assertEqual(result[1], "remote_port") + class TokenRedisTestCase(unittest.TestCase): def setUp(self): try: import redis except ImportError: - patcher = patch.dict(sys.modules, {'redis': MagicMock()}) + patcher = patch.dict(sys.modules, {"redis": MagicMock()}) patcher.start() self.addCleanup(patcher.stop) - @patch('redis.Redis') + @patch("redis.Redis") def test_empty(self, mock_redis): - plugin = TokenRedis('127.0.0.1:1234') + plugin = TokenRedis("127.0.0.1:1234") instance = mock_redis.return_value instance.get.return_value = None - result = plugin.lookup('testhost') + result = plugin.lookup("testhost") - instance.get.assert_called_once_with('testhost') + instance.get.assert_called_once_with("testhost") self.assertIsNone(result) - @patch('redis.Redis') + @patch("redis.Redis") def test_simple(self, mock_redis): - plugin = TokenRedis('127.0.0.1:1234') + plugin = TokenRedis("127.0.0.1:1234") instance = mock_redis.return_value instance.get.return_value = b'{"host": "remote_host:remote_port"}' - result = plugin.lookup('testhost') + result = plugin.lookup("testhost") - instance.get.assert_called_once_with('testhost') + instance.get.assert_called_once_with("testhost") self.assertIsNotNone(result) - self.assertEqual(result[0], 'remote_host') - self.assertEqual(result[1], 'remote_port') + self.assertEqual(result[0], "remote_host") + self.assertEqual(result[1], "remote_port") - @patch('redis.Redis') + @patch("redis.Redis") def test_json_token_with_spaces(self, mock_redis): - plugin = TokenRedis('127.0.0.1:1234') + plugin = TokenRedis("127.0.0.1:1234") instance = mock_redis.return_value instance.get.return_value = b' {"host": "remote_host:remote_port"} ' - result = plugin.lookup('testhost') + result = plugin.lookup("testhost") - instance.get.assert_called_once_with('testhost') + instance.get.assert_called_once_with("testhost") self.assertIsNotNone(result) - self.assertEqual(result[0], 'remote_host') - self.assertEqual(result[1], 'remote_port') + self.assertEqual(result[0], "remote_host") + self.assertEqual(result[1], "remote_port") - @patch('redis.Redis') + @patch("redis.Redis") def test_text_token(self, mock_redis): - plugin = TokenRedis('127.0.0.1:1234') + plugin = TokenRedis("127.0.0.1:1234") instance = mock_redis.return_value - instance.get.return_value = b'remote_host:remote_port' + instance.get.return_value = b"remote_host:remote_port" - result = plugin.lookup('testhost') + result = plugin.lookup("testhost") - instance.get.assert_called_once_with('testhost') + instance.get.assert_called_once_with("testhost") self.assertIsNotNone(result) - self.assertEqual(result[0], 'remote_host') - self.assertEqual(result[1], 'remote_port') + self.assertEqual(result[0], "remote_host") + self.assertEqual(result[1], "remote_port") - @patch('redis.Redis') + @patch("redis.Redis") def test_text_token_with_spaces(self, mock_redis): - plugin = TokenRedis('127.0.0.1:1234') + plugin = TokenRedis("127.0.0.1:1234") instance = mock_redis.return_value - instance.get.return_value = b' remote_host:remote_port ' + instance.get.return_value = b" remote_host:remote_port " - result = plugin.lookup('testhost') + result = plugin.lookup("testhost") - instance.get.assert_called_once_with('testhost') + instance.get.assert_called_once_with("testhost") self.assertIsNotNone(result) - self.assertEqual(result[0], 'remote_host') - self.assertEqual(result[1], 'remote_port') + self.assertEqual(result[0], "remote_host") + self.assertEqual(result[1], "remote_port") - @patch('redis.Redis') + @patch("redis.Redis") def test_invalid_token(self, mock_redis): - plugin = TokenRedis('127.0.0.1:1234') + plugin = TokenRedis("127.0.0.1:1234") instance = mock_redis.return_value instance.get.return_value = b'{"host": "remote_host:remote_port" ' - result = plugin.lookup('testhost') + result = plugin.lookup("testhost") - instance.get.assert_called_once_with('testhost') + instance.get.assert_called_once_with("testhost") self.assertIsNone(result) - @patch('redis.Redis') + @patch("redis.Redis") def test_token_without_namespace(self, mock_redis): - plugin = TokenRedis('127.0.0.1:1234') - token = 'testhost' + plugin = TokenRedis("127.0.0.1:1234") + token = "testhost" def mock_redis_get(key): self.assertEqual(key, token) - return b'remote_host:remote_port' + return b"remote_host:remote_port" instance = mock_redis.return_value instance.get = mock_redis_get @@ -304,17 +339,17 @@ class TokenRedisTestCase(unittest.TestCase): result = plugin.lookup(token) self.assertIsNotNone(result) - self.assertEqual(result[0], 'remote_host') - self.assertEqual(result[1], 'remote_port') + self.assertEqual(result[0], "remote_host") + self.assertEqual(result[1], "remote_port") - @patch('redis.Redis') + @patch("redis.Redis") def test_token_with_namespace(self, mock_redis): - plugin = TokenRedis('127.0.0.1:1234:::namespace') - token = 'testhost' + plugin = TokenRedis("127.0.0.1:1234:::namespace") + token = "testhost" def mock_redis_get(key): self.assertEqual(key, "namespace:" + token) - return b'remote_host:remote_port' + return b"remote_host:remote_port" instance = mock_redis.return_value instance.get = mock_redis_get @@ -322,103 +357,103 @@ class TokenRedisTestCase(unittest.TestCase): result = plugin.lookup(token) self.assertIsNotNone(result) - self.assertEqual(result[0], 'remote_host') - self.assertEqual(result[1], 'remote_port') + self.assertEqual(result[0], "remote_host") + self.assertEqual(result[1], "remote_port") def test_src_only_host(self): - plugin = TokenRedis('127.0.0.1') + plugin = TokenRedis("127.0.0.1") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) self.assertEqual(plugin._namespace, "") def test_src_with_host_port(self): - plugin = TokenRedis('127.0.0.1:1234') + plugin = TokenRedis("127.0.0.1:1234") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) self.assertEqual(plugin._namespace, "") def test_src_with_host_port_db(self): - plugin = TokenRedis('127.0.0.1:1234:2') + plugin = TokenRedis("127.0.0.1:1234:2") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._db, 2) self.assertEqual(plugin._password, None) self.assertEqual(plugin._namespace, "") def test_src_with_host_port_db_pass(self): - plugin = TokenRedis('127.0.0.1:1234:2:verysecret') + plugin = TokenRedis("127.0.0.1:1234:2:verysecret") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._db, 2) - self.assertEqual(plugin._password, 'verysecret') + self.assertEqual(plugin._password, "verysecret") self.assertEqual(plugin._namespace, "") def test_src_with_host_port_db_pass_namespace(self): - plugin = TokenRedis('127.0.0.1:1234:2:verysecret:namespace') + plugin = TokenRedis("127.0.0.1:1234:2:verysecret:namespace") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._db, 2) - self.assertEqual(plugin._password, 'verysecret') + self.assertEqual(plugin._password, "verysecret") self.assertEqual(plugin._namespace, "namespace:") def test_src_with_host_empty_port_empty_db_pass_no_namespace(self): - plugin = TokenRedis('127.0.0.1:::verysecret') + plugin = TokenRedis("127.0.0.1:::verysecret") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) - self.assertEqual(plugin._password, 'verysecret') + self.assertEqual(plugin._password, "verysecret") self.assertEqual(plugin._namespace, "") def test_src_with_host_empty_port_empty_db_empty_pass_empty_namespace(self): - plugin = TokenRedis('127.0.0.1::::') + plugin = TokenRedis("127.0.0.1::::") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) self.assertEqual(plugin._namespace, "") def test_src_with_host_empty_port_empty_db_empty_pass_no_namespace(self): - plugin = TokenRedis('127.0.0.1:::') + plugin = TokenRedis("127.0.0.1:::") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) self.assertEqual(plugin._namespace, "") def test_src_with_host_empty_port_empty_db_no_pass_no_namespace(self): - plugin = TokenRedis('127.0.0.1::') + plugin = TokenRedis("127.0.0.1::") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) self.assertEqual(plugin._namespace, "") def test_src_with_host_empty_port_no_db_no_pass_no_namespace(self): - plugin = TokenRedis('127.0.0.1:') + plugin = TokenRedis("127.0.0.1:") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) self.assertEqual(plugin._namespace, "") def test_src_with_host_empty_port_empty_db_empty_pass_namespace(self): - plugin = TokenRedis('127.0.0.1::::namespace') + plugin = TokenRedis("127.0.0.1::::namespace") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) @@ -427,43 +462,43 @@ class TokenRedisTestCase(unittest.TestCase): def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self): plugin = TokenRedis('127.0.0.1::::"ns1:ns2"') - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) self.assertEqual(plugin._namespace, "ns1:ns2:") def test_src_with_host_empty_port_db_no_pass_no_namespace(self): - plugin = TokenRedis('127.0.0.1::2') + plugin = TokenRedis("127.0.0.1::2") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 2) self.assertEqual(plugin._password, None) self.assertEqual(plugin._namespace, "") def test_src_with_host_port_empty_db_pass_no_namespace(self): - plugin = TokenRedis('127.0.0.1:1234::verysecret') + plugin = TokenRedis("127.0.0.1:1234::verysecret") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._db, 0) - self.assertEqual(plugin._password, 'verysecret') + self.assertEqual(plugin._password, "verysecret") self.assertEqual(plugin._namespace, "") def test_src_with_host_empty_port_db_pass_no_namespace(self): - plugin = TokenRedis('127.0.0.1::2:verysecret') + plugin = TokenRedis("127.0.0.1::2:verysecret") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 2) - self.assertEqual(plugin._password, 'verysecret') + self.assertEqual(plugin._password, "verysecret") self.assertEqual(plugin._namespace, "") def test_src_with_host_empty_port_db_empty_pass_no_namespace(self): - plugin = TokenRedis('127.0.0.1::2:') + plugin = TokenRedis("127.0.0.1::2:") - self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._server, "127.0.0.1") self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 2) self.assertEqual(plugin._password, None) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 8ee44f9..4570646 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -14,199 +14,268 @@ # License for the specific language governing permissions and limitations # under the License. -""" Unit tests for websocket """ +"""Unit tests for websocket""" + import unittest from websockify import websocket + class FakeSocket: def __init__(self): - self.data = b'' + self.data = b"" def send(self, buf): self.data += buf return len(buf) + class AcceptTestCase(unittest.TestCase): def test_success(self): ws = websocket.WebSocket() sock = FakeSocket() - ws.accept(sock, {'upgrade': 'websocket', - 'Sec-WebSocket-Version': '13', - 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) - self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ') - self.assertTrue(b'\r\nUpgrade: websocket\r\n' in sock.data) - self.assertTrue(b'\r\nConnection: Upgrade\r\n' in sock.data) - self.assertTrue(b'\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n' in sock.data) + ws.accept( + sock, + { + "upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==", + }, + ) + self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ") + self.assertTrue(b"\r\nUpgrade: websocket\r\n" in sock.data) + self.assertTrue(b"\r\nConnection: Upgrade\r\n" in sock.data) + self.assertTrue( + b"\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n" in sock.data + ) def test_bad_version(self): ws = websocket.WebSocket() sock = FakeSocket() - self.assertRaises(Exception, ws.accept, - sock, {'upgrade': 'websocket', - 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) - self.assertRaises(Exception, ws.accept, - sock, {'upgrade': 'websocket', - 'Sec-WebSocket-Version': '5', - 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) - self.assertRaises(Exception, ws.accept, - sock, {'upgrade': 'websocket', - 'Sec-WebSocket-Version': '20', - 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) + self.assertRaises( + Exception, + ws.accept, + sock, + {"upgrade": "websocket", "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q=="}, + ) + self.assertRaises( + Exception, + ws.accept, + sock, + { + "upgrade": "websocket", + "Sec-WebSocket-Version": "5", + "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==", + }, + ) + self.assertRaises( + Exception, + ws.accept, + sock, + { + "upgrade": "websocket", + "Sec-WebSocket-Version": "20", + "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==", + }, + ) def test_bad_upgrade(self): ws = websocket.WebSocket() sock = FakeSocket() - self.assertRaises(Exception, ws.accept, - sock, {'Sec-WebSocket-Version': '13', - 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) - self.assertRaises(Exception, ws.accept, - sock, {'upgrade': 'websocket2', - 'Sec-WebSocket-Version': '13', - 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) + self.assertRaises( + Exception, + ws.accept, + sock, + { + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==", + }, + ) + self.assertRaises( + Exception, + ws.accept, + sock, + { + "upgrade": "websocket2", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==", + }, + ) def test_missing_key(self): ws = websocket.WebSocket() sock = FakeSocket() - self.assertRaises(Exception, ws.accept, - sock, {'upgrade': 'websocket', - 'Sec-WebSocket-Version': '13'}) + self.assertRaises( + Exception, + ws.accept, + sock, + {"upgrade": "websocket", "Sec-WebSocket-Version": "13"}, + ) def test_protocol(self): class ProtoSocket(websocket.WebSocket): def select_subprotocol(self, protocol): - return 'gazonk' + return "gazonk" ws = ProtoSocket() sock = FakeSocket() - ws.accept(sock, {'upgrade': 'websocket', - 'Sec-WebSocket-Version': '13', - 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', - 'Sec-WebSocket-Protocol': 'foobar gazonk'}) - self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ') - self.assertTrue(b'\r\nSec-WebSocket-Protocol: gazonk\r\n' in sock.data) + ws.accept( + sock, + { + "upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==", + "Sec-WebSocket-Protocol": "foobar gazonk", + }, + ) + self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ") + self.assertTrue(b"\r\nSec-WebSocket-Protocol: gazonk\r\n" in sock.data) def test_no_protocol(self): ws = websocket.WebSocket() sock = FakeSocket() - ws.accept(sock, {'upgrade': 'websocket', - 'Sec-WebSocket-Version': '13', - 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) - self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ') - self.assertFalse(b'\r\nSec-WebSocket-Protocol:' in sock.data) + ws.accept( + sock, + { + "upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==", + }, + ) + self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ") + self.assertFalse(b"\r\nSec-WebSocket-Protocol:" in sock.data) def test_missing_protocol(self): ws = websocket.WebSocket() sock = FakeSocket() - self.assertRaises(Exception, ws.accept, - sock, {'upgrade': 'websocket', - 'Sec-WebSocket-Version': '13', - 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', - 'Sec-WebSocket-Protocol': 'foobar gazonk'}) + self.assertRaises( + Exception, + ws.accept, + sock, + { + "upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==", + "Sec-WebSocket-Protocol": "foobar gazonk", + }, + ) def test_protocol(self): class ProtoSocket(websocket.WebSocket): def select_subprotocol(self, protocol): - return 'oddball' + return "oddball" ws = ProtoSocket() sock = FakeSocket() - self.assertRaises(Exception, ws.accept, - sock, {'upgrade': 'websocket', - 'Sec-WebSocket-Version': '13', - 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', - 'Sec-WebSocket-Protocol': 'foobar gazonk'}) + self.assertRaises( + Exception, + ws.accept, + sock, + { + "upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==", + "Sec-WebSocket-Protocol": "foobar gazonk", + }, + ) + class PingPongTest(unittest.TestCase): def setUp(self): self.ws = websocket.WebSocket() self.sock = FakeSocket() - self.ws.accept(self.sock, {'upgrade': 'websocket', - 'Sec-WebSocket-Version': '13', - 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) - self.assertEqual(self.sock.data[:13], b'HTTP/1.1 101 ') - self.sock.data = b'' + self.ws.accept( + self.sock, + { + "upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==", + }, + ) + self.assertEqual(self.sock.data[:13], b"HTTP/1.1 101 ") + self.sock.data = b"" def test_ping(self): self.ws.ping() - self.assertEqual(self.sock.data, b'\x89\x00') + self.assertEqual(self.sock.data, b"\x89\x00") def test_pong(self): self.ws.pong() - self.assertEqual(self.sock.data, b'\x8a\x00') + self.assertEqual(self.sock.data, b"\x8a\x00") def test_ping_data(self): - self.ws.ping(b'foo') - self.assertEqual(self.sock.data, b'\x89\x03foo') + self.ws.ping(b"foo") + self.assertEqual(self.sock.data, b"\x89\x03foo") def test_pong_data(self): - self.ws.pong(b'foo') - self.assertEqual(self.sock.data, b'\x8a\x03foo') + self.ws.pong(b"foo") + self.assertEqual(self.sock.data, b"\x8a\x03foo") + class HyBiEncodeDecodeTestCase(unittest.TestCase): def test_decode_hybi_text(self): - buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58' + buf = b"\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58" ws = websocket.WebSocket() res = ws._decode_hybi(buf) - self.assertEqual(res['fin'], 1) - self.assertEqual(res['opcode'], 0x1) - self.assertEqual(res['masked'], True) - self.assertEqual(res['length'], len(buf)) - self.assertEqual(res['payload'], b'Hello') + self.assertEqual(res["fin"], 1) + self.assertEqual(res["opcode"], 0x1) + self.assertEqual(res["masked"], True) + self.assertEqual(res["length"], len(buf)) + self.assertEqual(res["payload"], b"Hello") def test_decode_hybi_binary(self): - buf = b'\x82\x04\x01\x02\x03\x04' + buf = b"\x82\x04\x01\x02\x03\x04" ws = websocket.WebSocket() res = ws._decode_hybi(buf) - self.assertEqual(res['fin'], 1) - self.assertEqual(res['opcode'], 0x2) - self.assertEqual(res['length'], len(buf)) - self.assertEqual(res['payload'], b'\x01\x02\x03\x04') + self.assertEqual(res["fin"], 1) + self.assertEqual(res["opcode"], 0x2) + self.assertEqual(res["length"], len(buf)) + self.assertEqual(res["payload"], b"\x01\x02\x03\x04") def test_decode_hybi_extended_16bit_binary(self): - data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260 - buf = b'\x82\x7e\x01\x04' + data + data = b"\x01\x02\x03\x04" * 65 # len > 126 -- len == 260 + buf = b"\x82\x7e\x01\x04" + data ws = websocket.WebSocket() res = ws._decode_hybi(buf) - self.assertEqual(res['fin'], 1) - self.assertEqual(res['opcode'], 0x2) - self.assertEqual(res['length'], len(buf)) - self.assertEqual(res['payload'], data) + self.assertEqual(res["fin"], 1) + self.assertEqual(res["opcode"], 0x2) + self.assertEqual(res["length"], len(buf)) + self.assertEqual(res["payload"], data) def test_decode_hybi_extended_64bit_binary(self): - data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260 - buf = b'\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04' + data + data = b"\x01\x02\x03\x04" * 65 # len > 126 -- len == 260 + buf = b"\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04" + data ws = websocket.WebSocket() res = ws._decode_hybi(buf) - self.assertEqual(res['fin'], 1) - self.assertEqual(res['opcode'], 0x2) - self.assertEqual(res['length'], len(buf)) - self.assertEqual(res['payload'], data) + self.assertEqual(res["fin"], 1) + self.assertEqual(res["opcode"], 0x2) + self.assertEqual(res["length"], len(buf)) + self.assertEqual(res["payload"], data) def test_decode_hybi_multi(self): - buf1 = b'\x01\x03\x48\x65\x6c' - buf2 = b'\x80\x02\x6c\x6f' + buf1 = b"\x01\x03\x48\x65\x6c" + buf2 = b"\x80\x02\x6c\x6f" ws = websocket.WebSocket() res1 = ws._decode_hybi(buf1) - self.assertEqual(res1['fin'], 0) - self.assertEqual(res1['opcode'], 0x1) - self.assertEqual(res1['length'], len(buf1)) - self.assertEqual(res1['payload'], b'Hel') + self.assertEqual(res1["fin"], 0) + self.assertEqual(res1["opcode"], 0x1) + self.assertEqual(res1["length"], len(buf1)) + self.assertEqual(res1["payload"], b"Hel") res2 = ws._decode_hybi(buf2) - self.assertEqual(res2['fin'], 1) - self.assertEqual(res2['opcode'], 0x0) - self.assertEqual(res2['length'], len(buf2)) - self.assertEqual(res2['payload'], b'lo') + self.assertEqual(res2["fin"], 1) + self.assertEqual(res2["opcode"], 0x0) + self.assertEqual(res2["length"], len(buf2)) + self.assertEqual(res2["payload"], b"lo") def test_encode_hybi_basic(self): ws = websocket.WebSocket() - res = ws._encode_hybi(0x1, b'Hello') - expected = b'\x81\x05\x48\x65\x6c\x6c\x6f' + res = ws._encode_hybi(0x1, b"Hello") + expected = b"\x81\x05\x48\x65\x6c\x6c\x6f" self.assertEqual(res, expected) diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py index e56bef7..fc69040 100644 --- a/tests/test_websocketproxy.py +++ b/tests/test_websocketproxy.py @@ -14,7 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. -""" Unit tests for websocketproxy """ +"""Unit tests for websocketproxy""" import sys import unittest @@ -30,7 +30,7 @@ from websockify import auth_plugins class FakeSocket: - def __init__(self, data=b''): + def __init__(self, data=b""): self._data = data def recv(self, amt, flags=None): @@ -40,11 +40,11 @@ class FakeSocket: return res - def makefile(self, mode='r', buffsize=None): - if 'b' in mode: + def makefile(self, mode="r", buffsize=None): + if "b" in mode: return BytesIO(self._data) else: - return StringIO(self._data.decode('latin_1')) + return StringIO(self._data.decode("latin_1")) class FakeServer: @@ -58,14 +58,16 @@ class FakeServer: self.ssl_target = None self.unix_target = None + class ProxyRequestHandlerTestCase(unittest.TestCase): def setUp(self): super().setUp() self.handler = websocketproxy.ProxyRequestHandler( - FakeSocket(), "127.0.0.1", FakeServer()) + FakeSocket(), "127.0.0.1", FakeServer() + ) self.handler.path = "https://localhost:6080/websockify?token=blah" self.handler.headers = {} - patch('websockify.websockifyserver.WebSockifyServer.socket').start() + patch("websockify.websockifyserver.WebSockifyServer.socket").start() def tearDown(self): patch.stopall() @@ -76,8 +78,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): def lookup(self, token): return ("some host", "some port") - host, port = self.handler.get_target( - TestPlugin(None)) + host, port = self.handler.get_target(TestPlugin(None)) self.assertEqual(host, "some host") self.assertEqual(port, "some port") @@ -87,8 +88,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): def lookup(self, token): return ("unix_socket", "/tmp/socket") - _, socket = self.handler.get_target( - TestPlugin(None)) + _, socket = self.handler.get_target(TestPlugin(None)) self.assertEqual(socket, "/tmp/socket") @@ -100,11 +100,11 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): with self.assertRaises(FakeServer.EClose): self.handler.get_target(TestPlugin(None)) - @patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock()) + @patch("websockify.websocketproxy.ProxyRequestHandler.send_auth_error", MagicMock()) def test_token_plugin(self): class TestPlugin(token_plugins.BasePlugin): def lookup(self, token): - return (self.source + token).split(',') + return (self.source + token).split(",") self.handler.server.token_plugin = TestPlugin("somehost,") self.handler.validate_connection() @@ -112,7 +112,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): self.assertEqual(self.handler.server.target_host, "somehost") self.assertEqual(self.handler.server.target_port, "blah") - @patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock()) + @patch("websockify.websocketproxy.ProxyRequestHandler.send_auth_error", MagicMock()) def test_auth_plugin(self): class TestPlugin(auth_plugins.BasePlugin): def authenticate(self, headers, target_host, target_port): @@ -128,4 +128,3 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): self.handler.server.target_host = "someotherhost" self.handler.auth_connection() - diff --git a/tests/test_websocketserver.py b/tests/test_websocketserver.py index 0e37e3d..a1732e6 100644 --- a/tests/test_websocketserver.py +++ b/tests/test_websocketserver.py @@ -1,5 +1,5 @@ +"""Unit tests for websocketserver""" -""" Unit tests for websocketserver """ import unittest from unittest.mock import patch, MagicMock @@ -66,4 +66,3 @@ class HttpWebSocketTest(unittest.TestCase): # Then req_obj.end_headers.assert_called_once_with() - diff --git a/tests/test_websockifyserver.py b/tests/test_websockifyserver.py index 42edb4a..0c09126 100644 --- a/tests/test_websockifyserver.py +++ b/tests/test_websockifyserver.py @@ -14,7 +14,8 @@ # License for the specific language governing permissions and limitations # under the License. -""" Unit tests for websockifyserver """ +"""Unit tests for websockifyserver""" + import errno import os import logging @@ -36,11 +37,11 @@ from websockify import websockifyserver def raise_oserror(*args, **kwargs): - raise OSError('fake error') + raise OSError("fake error") class FakeSocket: - def __init__(self, data=b''): + def __init__(self, data=b""): self._data = data def recv(self, amt, flags=None): @@ -50,19 +51,19 @@ class FakeSocket: return res - def makefile(self, mode='r', buffsize=None): - if 'b' in mode: + def makefile(self, mode="r", buffsize=None): + if "b" in mode: return BytesIO(self._data) else: - return StringIO(self._data.decode('latin_1')) + return StringIO(self._data.decode("latin_1")) class WebSockifyRequestHandlerTestCase(unittest.TestCase): def setUp(self): super().setUp() - self.tmpdir = tempfile.mkdtemp('-websockify-tests') + self.tmpdir = tempfile.mkdtemp("-websockify-tests") # Mock this out cause it screws tests up - patch('os.chdir').start() + patch("os.chdir").start() def tearDown(self): """Called automatically after each test.""" @@ -70,31 +71,41 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase): os.rmdir(self.tmpdir) super().tearDown() - def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler, - **kwargs): - web = kwargs.pop('web', self.tmpdir) + def _get_server( + self, handler_class=websockifyserver.WebSockifyRequestHandler, **kwargs + ): + web = kwargs.pop("web", self.tmpdir) return websockifyserver.WebSockifyServer( - handler_class, listen_host='localhost', - listen_port=80, key=self.tmpdir, web=web, - record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1, - **kwargs) + handler_class, + listen_host="localhost", + listen_port=80, + key=self.tmpdir, + web=web, + record=self.tmpdir, + daemon=False, + ssl_only=0, + idle_timeout=1, + **kwargs, + ) - @patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error') + @patch("websockify.websockifyserver.WebSockifyRequestHandler.send_error") def test_normal_get_with_only_upgrade_returns_error(self, send_error): server = self._get_server(web=None) handler = websockifyserver.WebSockifyRequestHandler( - FakeSocket(b'GET /tmp.txt HTTP/1.1'), '127.0.0.1', server) + FakeSocket(b"GET /tmp.txt HTTP/1.1"), "127.0.0.1", server + ) handler.do_GET() send_error.assert_called_with(405) - @patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error') + @patch("websockify.websockifyserver.WebSockifyRequestHandler.send_error") def test_list_dir_with_file_only_returns_error(self, send_error): server = self._get_server(file_only=True) handler = websockifyserver.WebSockifyRequestHandler( - FakeSocket(b'GET / HTTP/1.1'), '127.0.0.1', server) + FakeSocket(b"GET / HTTP/1.1"), "127.0.0.1", server + ) - handler.path = '/' + handler.path = "/" handler.do_GET() send_error.assert_called_with(404) @@ -102,9 +113,9 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase): class WebSockifyServerTestCase(unittest.TestCase): def setUp(self): super().setUp() - self.tmpdir = tempfile.mkdtemp('-websockify-tests') + self.tmpdir = tempfile.mkdtemp("-websockify-tests") # Mock this out cause it screws tests up - patch('os.chdir').start() + patch("os.chdir").start() def tearDown(self): """Called automatically after each test.""" @@ -112,32 +123,38 @@ class WebSockifyServerTestCase(unittest.TestCase): os.rmdir(self.tmpdir) super().tearDown() - def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler, - **kwargs): + def _get_server( + self, handler_class=websockifyserver.WebSockifyRequestHandler, **kwargs + ): return websockifyserver.WebSockifyServer( - handler_class, listen_host='localhost', - listen_port=80, key=self.tmpdir, web=self.tmpdir, - record=self.tmpdir, **kwargs) + handler_class, + listen_host="localhost", + listen_port=80, + key=self.tmpdir, + web=self.tmpdir, + record=self.tmpdir, + **kwargs, + ) def test_daemonize_raises_error_while_closing_fds(self): server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) - patch('os.fork').start().return_value = 0 - patch('signal.signal').start() - patch('os.setsid').start() - patch('os.close').start().side_effect = raise_oserror - self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') + patch("os.fork").start().return_value = 0 + patch("signal.signal").start() + patch("os.setsid").start() + patch("os.close").start().side_effect = raise_oserror + self.assertRaises(OSError, server.daemonize, keepfd=None, chdir="./") def test_daemonize_ignores_ebadf_error_while_closing_fds(self): def raise_oserror_ebadf(fd): - raise OSError(errno.EBADF, 'fake error') + raise OSError(errno.EBADF, "fake error") server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) - patch('os.fork').start().return_value = 0 - patch('signal.signal').start() - patch('os.setsid').start() - patch('os.close').start().side_effect = raise_oserror_ebadf - patch('os.open').start().side_effect = raise_oserror - self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') + patch("os.fork").start().return_value = 0 + patch("signal.signal").start() + patch("os.setsid").start() + patch("os.close").start().side_effect = raise_oserror_ebadf + patch("os.open").start().side_effect = raise_oserror + self.assertRaises(OSError, server.daemonize, keepfd=None, chdir="./") def test_handshake_fails_on_not_ready(self): server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) @@ -145,23 +162,29 @@ class WebSockifyServerTestCase(unittest.TestCase): def fake_select(rlist, wlist, xlist, timeout=None): return ([], [], []) - patch('select.select').start().side_effect = fake_select + patch("select.select").start().side_effect = fake_select self.assertRaises( - websockifyserver.WebSockifyServer.EClose, server.do_handshake, - FakeSocket(), '127.0.0.1') + websockifyserver.WebSockifyServer.EClose, + server.do_handshake, + FakeSocket(), + "127.0.0.1", + ) def test_empty_handshake_fails(self): server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) - sock = FakeSocket('') + sock = FakeSocket("") def fake_select(rlist, wlist, xlist, timeout=None): return ([sock], [], []) - patch('select.select').start().side_effect = fake_select + patch("select.select").start().side_effect = fake_select self.assertRaises( - websockifyserver.WebSockifyServer.EClose, server.do_handshake, - sock, '127.0.0.1') + websockifyserver.WebSockifyServer.EClose, + server.do_handshake, + sock, + "127.0.0.1", + ) def test_handshake_policy_request(self): # TODO(directxman12): implement @@ -170,35 +193,39 @@ class WebSockifyServerTestCase(unittest.TestCase): def test_handshake_ssl_only_without_ssl_raises_error(self): server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) - sock = FakeSocket(b'some initial data') + sock = FakeSocket(b"some initial data") def fake_select(rlist, wlist, xlist, timeout=None): return ([sock], [], []) - patch('select.select').start().side_effect = fake_select + patch("select.select").start().side_effect = fake_select self.assertRaises( - websockifyserver.WebSockifyServer.EClose, server.do_handshake, - sock, '127.0.0.1') + websockifyserver.WebSockifyServer.EClose, + server.do_handshake, + sock, + "127.0.0.1", + ) def test_do_handshake_no_ssl(self): class FakeHandler: CALLED = False + def __init__(self, *args, **kwargs): type(self).CALLED = True FakeHandler.CALLED = False server = self._get_server( - handler_class=FakeHandler, daemon=True, - ssl_only=0, idle_timeout=1) + handler_class=FakeHandler, daemon=True, ssl_only=0, idle_timeout=1 + ) - sock = FakeSocket(b'some initial data') + sock = FakeSocket(b"some initial data") def fake_select(rlist, wlist, xlist, timeout=None): return ([sock], [], []) - patch('select.select').start().side_effect = fake_select - self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock) + patch("select.select").start().side_effect = fake_select + self.assertEqual(server.do_handshake(sock, "127.0.0.1"), sock) self.assertTrue(FakeHandler.CALLED, True) def test_do_handshake_ssl(self): @@ -210,18 +237,22 @@ class WebSockifyServerTestCase(unittest.TestCase): pass def test_do_handshake_ssl_without_cert_raises_error(self): - server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1, - cert='afdsfasdafdsafdsafdsafdas') + server = self._get_server( + daemon=True, ssl_only=0, idle_timeout=1, cert="afdsfasdafdsafdsafdsafdas" + ) sock = FakeSocket(b"\x16some ssl data") def fake_select(rlist, wlist, xlist, timeout=None): return ([sock], [], []) - patch('select.select').start().side_effect = fake_select + patch("select.select").start().side_effect = fake_select self.assertRaises( - websockifyserver.WebSockifyServer.EClose, server.do_handshake, - sock, '127.0.0.1') + websockifyserver.WebSockifyServer.EClose, + server.do_handshake, + sock, + "127.0.0.1", + ) def test_do_handshake_ssl_error_eof_raises_close_error(self): server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) @@ -234,58 +265,79 @@ class WebSockifyServerTestCase(unittest.TestCase): def fake_wrap_socket(*args, **kwargs): raise ssl.SSLError(ssl.SSL_ERROR_EOF) - class fake_create_default_context(): + class fake_create_default_context: def __init__(self, purpose): self.verify_mode = None self.options = 0 + def load_cert_chain(self, certfile, keyfile, password): pass + def set_default_verify_paths(self): pass + def load_verify_locations(self, cafile): pass + def wrap_socket(self, *args, **kwargs): raise ssl.SSLError(ssl.SSL_ERROR_EOF) - patch('select.select').start().side_effect = fake_select - patch('ssl.create_default_context').start().side_effect = fake_create_default_context + patch("select.select").start().side_effect = fake_select + patch( + "ssl.create_default_context" + ).start().side_effect = fake_create_default_context self.assertRaises( - websockifyserver.WebSockifyServer.EClose, server.do_handshake, - sock, '127.0.0.1') + websockifyserver.WebSockifyServer.EClose, + server.do_handshake, + sock, + "127.0.0.1", + ) def test_do_handshake_ssl_sets_ciphers(self): - test_ciphers = 'TEST-CIPHERS-1:TEST-CIPHER-2' + test_ciphers = "TEST-CIPHERS-1:TEST-CIPHER-2" class FakeHandler: def __init__(self, *args, **kwargs): pass - server = self._get_server(handler_class=FakeHandler, daemon=True, - idle_timeout=1, ssl_ciphers=test_ciphers) + server = self._get_server( + handler_class=FakeHandler, + daemon=True, + idle_timeout=1, + ssl_ciphers=test_ciphers, + ) sock = FakeSocket(b"\x16some ssl data") def fake_select(rlist, wlist, xlist, timeout=None): return ([sock], [], []) - class fake_create_default_context(): - CIPHERS = '' + class fake_create_default_context: + CIPHERS = "" + def __init__(self, purpose): self.verify_mode = None self.options = 0 + def load_cert_chain(self, certfile, keyfile, password): pass + def set_default_verify_paths(self): pass + def load_verify_locations(self, cafile): pass + def wrap_socket(self, *args, **kwargs): pass + def set_ciphers(self, ciphers_to_set): fake_create_default_context.CIPHERS = ciphers_to_set - patch('select.select').start().side_effect = fake_select - patch('ssl.create_default_context').start().side_effect = fake_create_default_context - server.do_handshake(sock, '127.0.0.1') + patch("select.select").start().side_effect = fake_select + patch( + "ssl.create_default_context" + ).start().side_effect = fake_create_default_context + server.do_handshake(sock, "127.0.0.1") self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers) def test_do_handshake_ssl_sets_opions(self): @@ -295,8 +347,12 @@ class WebSockifyServerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): pass - server = self._get_server(handler_class=FakeHandler, daemon=True, - idle_timeout=1, ssl_options=test_options) + server = self._get_server( + handler_class=FakeHandler, + daemon=True, + idle_timeout=1, + ssl_options=test_options, + ) sock = FakeSocket(b"\x16some ssl data") def fake_select(rlist, wlist, xlist, timeout=None): @@ -304,26 +360,36 @@ class WebSockifyServerTestCase(unittest.TestCase): class fake_create_default_context: OPTIONS = 0 + def __init__(self, purpose): self.verify_mode = None self._options = 0 + def load_cert_chain(self, certfile, keyfile, password): pass + def set_default_verify_paths(self): pass + def load_verify_locations(self, cafile): pass + def wrap_socket(self, *args, **kwargs): pass + def get_options(self): return self._options + def set_options(self, val): fake_create_default_context.OPTIONS = val + options = property(get_options, set_options) - patch('select.select').start().side_effect = fake_select - patch('ssl.create_default_context').start().side_effect = fake_create_default_context - server.do_handshake(sock, '127.0.0.1') + patch("select.select").start().side_effect = fake_select + patch( + "ssl.create_default_context" + ).start().side_effect = fake_create_default_context + server.do_handshake(sock, "127.0.0.1") self.assertEqual(fake_create_default_context.OPTIONS, test_options) def test_fallback_sigchld_handler(self): @@ -332,38 +398,38 @@ class WebSockifyServerTestCase(unittest.TestCase): def test_start_server_error(self): server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1) - sock = server.socket('localhost') + sock = server.socket("localhost") def fake_select(rlist, wlist, xlist, timeout=None): raise Exception("fake error") - patch('websockify.websockifyserver.WebSockifyServer.socket').start() - patch('websockify.websockifyserver.WebSockifyServer.daemonize').start() - patch('select.select').start().side_effect = fake_select + patch("websockify.websockifyserver.WebSockifyServer.socket").start() + patch("websockify.websockifyserver.WebSockifyServer.daemonize").start() + patch("select.select").start().side_effect = fake_select server.start_server() def test_start_server_keyboardinterrupt(self): server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) - sock = server.socket('localhost') + sock = server.socket("localhost") def fake_select(rlist, wlist, xlist, timeout=None): raise KeyboardInterrupt - patch('websockify.websockifyserver.WebSockifyServer.socket').start() - patch('websockify.websockifyserver.WebSockifyServer.daemonize').start() - patch('select.select').start().side_effect = fake_select + patch("websockify.websockifyserver.WebSockifyServer.socket").start() + patch("websockify.websockifyserver.WebSockifyServer.daemonize").start() + patch("select.select").start().side_effect = fake_select server.start_server() def test_start_server_systemexit(self): server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) - sock = server.socket('localhost') + sock = server.socket("localhost") def fake_select(rlist, wlist, xlist, timeout=None): sys.exit() - patch('websockify.websockifyserver.WebSockifyServer.socket').start() - patch('websockify.websockifyserver.WebSockifyServer.daemonize').start() - patch('select.select').start().side_effect = fake_select + patch("websockify.websockifyserver.WebSockifyServer.socket").start() + patch("websockify.websockifyserver.WebSockifyServer.daemonize").start() + patch("select.select").start().side_effect = fake_select server.start_server() def test_socket_set_keepalive_options(self): @@ -372,29 +438,37 @@ class WebSockifyServerTestCase(unittest.TestCase): keepintvl = 56 server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) - sock = server.socket('localhost', - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) + sock = server.socket( + "localhost", + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl, + ) - if hasattr(socket, 'TCP_KEEPCNT'): - self.assertEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPCNT), keepcnt) - self.assertEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPIDLE), keepidle) - self.assertEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPINTVL), keepintvl) + if hasattr(socket, "TCP_KEEPCNT"): + self.assertEqual( + sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt + ) + self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE), keepidle) + self.assertEqual( + sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl + ) - sock = server.socket('localhost', - tcp_keepalive=False, - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) + sock = server.socket( + "localhost", + tcp_keepalive=False, + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl, + ) - if hasattr(socket, 'TCP_KEEPCNT'): - self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPCNT), keepcnt) - self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPIDLE), keepidle) - self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPINTVL), keepintvl) + if hasattr(socket, "TCP_KEEPCNT"): + self.assertNotEqual( + sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt + ) + self.assertNotEqual( + sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE), keepidle + ) + self.assertNotEqual( + sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl + ) diff --git a/websockify.py b/websockify.py index e5224d5..f5bdd21 120000 --- a/websockify.py +++ b/websockify.py @@ -1 +1 @@ -run \ No newline at end of file +run diff --git a/websockify/__main__.py b/websockify/__main__.py index 8378d46..c16794f 100644 --- a/websockify/__main__.py +++ b/websockify/__main__.py @@ -1,4 +1,4 @@ import websockify -if __name__ == '__main__': +if __name__ == "__main__": websockify.websocketproxy.websockify_init() diff --git a/websockify/auth_plugins.py b/websockify/auth_plugins.py index 36fac52..b08bdfb 100644 --- a/websockify/auth_plugins.py +++ b/websockify/auth_plugins.py @@ -1,4 +1,4 @@ -class BasePlugin(): +class BasePlugin: def __init__(self, src=None): self.source = src @@ -7,7 +7,9 @@ class BasePlugin(): class AuthenticationError(Exception): - def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None): + def __init__( + self, log_msg=None, response_code=403, response_headers={}, response_msg=None + ): self.code = response_code self.headers = response_headers self.msg = response_msg @@ -15,7 +17,7 @@ class AuthenticationError(Exception): if log_msg is None: log_msg = response_msg - super().__init__('%s %s' % (self.code, log_msg)) + super().__init__("%s %s" % (self.code, log_msg)) class InvalidOriginError(AuthenticationError): @@ -24,12 +26,13 @@ class InvalidOriginError(AuthenticationError): self.actual_origin = actual super().__init__( - response_msg='Invalid Origin', + response_msg="Invalid Origin", log_msg="Invalid Origin Header: Expected one of " - "%s, got '%s'" % (expected, actual)) + "%s, got '%s'" % (expected, actual), + ) -class BasicHTTPAuth(): +class BasicHTTPAuth: """Verifies Basic Auth headers. Specify src as username:password""" def __init__(self, src=None): @@ -37,9 +40,10 @@ class BasicHTTPAuth(): def authenticate(self, headers, target_host, target_port): import base64 - auth_header = headers.get('Authorization') + + auth_header = headers.get("Authorization") if auth_header: - if not auth_header.startswith('Basic '): + if not auth_header.startswith("Basic "): self.auth_error() try: @@ -49,11 +53,11 @@ class BasicHTTPAuth(): try: # http://stackoverflow.com/questions/7242316/what-encoding-should-i-use-for-http-basic-authentication - user_pass_as_text = user_pass_raw.decode('ISO-8859-1') + user_pass_as_text = user_pass_raw.decode("ISO-8859-1") except UnicodeDecodeError: self.auth_error() - user_pass = user_pass_as_text.split(':', 1) + user_pass = user_pass_as_text.split(":", 1) if len(user_pass) != 2: self.auth_error() @@ -64,7 +68,7 @@ class BasicHTTPAuth(): self.demand_auth() def validate_creds(self, username, password): - if '%s:%s' % (username, password) == self.src: + if "%s:%s" % (username, password) == self.src: return True else: return False @@ -73,10 +77,13 @@ class BasicHTTPAuth(): raise AuthenticationError(response_code=403) def demand_auth(self): - raise AuthenticationError(response_code=401, - response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'}) + raise AuthenticationError( + response_code=401, + response_headers={"WWW-Authenticate": 'Basic realm="Websockify"'}, + ) -class ExpectOrigin(): + +class ExpectOrigin: def __init__(self, src=None): if src is None: self.source = [] @@ -84,11 +91,12 @@ class ExpectOrigin(): self.source = src.split() def authenticate(self, headers, target_host, target_port): - origin = headers.get('Origin', None) + origin = headers.get("Origin", None) if origin is None or origin not in self.source: raise InvalidOriginError(expected=self.source, actual=origin) -class ClientCertCNAuth(): + +class ClientCertCNAuth: """Verifies client by SSL certificate. Specify src as whitespace separated list of common names.""" def __init__(self, src=None): @@ -98,5 +106,5 @@ class ClientCertCNAuth(): self.source = src.split() def authenticate(self, headers, target_host, target_port): - if headers.get('SSL_CLIENT_S_DN_CN', None) not in self.source: + if headers.get("SSL_CLIENT_S_DN_CN", None) not in self.source: raise AuthenticationError(response_code=403) diff --git a/websockify/sysloghandler.py b/websockify/sysloghandler.py index 8b344a1..6abdf1b 100644 --- a/websockify/sysloghandler.py +++ b/websockify/sysloghandler.py @@ -7,23 +7,26 @@ class WebsockifySysLogHandler(handlers.SysLogHandler): as defined by RFC 5424. """ - _legacy_head_fmt = '<{pri}>{ident}[{pid}]: ' - _rfc5424_head_fmt = '<{pri}>1 {timestamp} {hostname} {ident} {pid} - - ' + _legacy_head_fmt = "<{pri}>{ident}[{pid}]: " + _rfc5424_head_fmt = "<{pri}>1 {timestamp} {hostname} {ident} {pid} - - " _head_fmt = _rfc5424_head_fmt _legacy = False - _timestamp_fmt = '%Y-%m-%dT%H:%M:%SZ' + _timestamp_fmt = "%Y-%m-%dT%H:%M:%SZ" _max_hostname = 255 - _max_ident = 24 #safer for old daemons + _max_ident = 24 # safer for old daemons _send_length = False - _tail = '\n' - + _tail = "\n" ident = None - - def __init__(self, address=('localhost', handlers.SYSLOG_UDP_PORT), - facility=handlers.SysLogHandler.LOG_USER, - socktype=None, ident=None, legacy=False): + def __init__( + self, + address=("localhost", handlers.SYSLOG_UDP_PORT), + facility=handlers.SysLogHandler.LOG_USER, + socktype=None, + ident=None, + legacy=False, + ): """ Initialize a handler. @@ -46,7 +49,6 @@ class WebsockifySysLogHandler(handlers.SysLogHandler): super().__init__(address, facility, socktype) - def emit(self, record): """ Emit a record. @@ -57,46 +59,44 @@ class WebsockifySysLogHandler(handlers.SysLogHandler): try: # Gather info. - text = self.format(record).replace(self._tail, ' ') - if not text: # nothing to log + text = self.format(record).replace(self._tail, " ") + if not text: # nothing to log return - pri = self.encodePriority(self.facility, - self.mapPriority(record.levelname)) + pri = self.encodePriority(self.facility, self.mapPriority(record.levelname)) - timestamp = time.strftime(self._timestamp_fmt, time.gmtime()); - - hostname = socket.gethostname()[:self._max_hostname] + timestamp = time.strftime(self._timestamp_fmt, time.gmtime()) + hostname = socket.gethostname()[: self._max_hostname] if self.ident: - ident = self.ident[:self._max_ident] + ident = self.ident[: self._max_ident] else: - ident = '' + ident = "" - pid = os.getpid() # shouldn't need truncation + pid = os.getpid() # shouldn't need truncation # Format the header. head = { - 'pri': pri, - 'timestamp': timestamp, - 'hostname': hostname, - 'ident': ident, - 'pid': pid, + "pri": pri, + "timestamp": timestamp, + "hostname": hostname, + "ident": ident, + "pid": pid, } - msg = self._head_fmt.format(**head).encode('ascii', 'ignore') + msg = self._head_fmt.format(**head).encode("ascii", "ignore") # Encode text as plain ASCII if possible, else use UTF-8 with BOM. try: - msg += text.encode('ascii') + msg += text.encode("ascii") except UnicodeEncodeError: - msg += text.encode('utf-8-sig') + msg += text.encode("utf-8-sig") # Add length or tail character, if necessary. if self.socktype != socket.SOCK_DGRAM: if self._send_length: - msg = ('%d ' % len(msg)).encode('ascii') + msg + msg = ("%d " % len(msg)).encode("ascii") + msg else: - msg += self._tail.encode('ascii') + msg += self._tail.encode("ascii") # Send the message. if self.unixsocket: diff --git a/websockify/token_plugins.py b/websockify/token_plugins.py index 5a95490..4c46d1b 100644 --- a/websockify/token_plugins.py +++ b/websockify/token_plugins.py @@ -10,8 +10,8 @@ logger = logging.getLogger(__name__) _SOURCE_SPLIT_REGEX = re.compile( r'(?<=^)"([^"]+)"(?=:|$)' r'|(?<=:)"([^"]+)"(?=:|$)' - r'|(?<=^)([^:]*)(?=:|$)' - r'|(?<=:)([^:]*)(?=:|$)', + r"|(?<=^)([^:]*)(?=:|$)" + r"|(?<=:)([^:]*)(?=:|$)", ) @@ -26,7 +26,7 @@ def parse_source_args(src): return [m[0] or m[1] or m[2] or m[3] for m in matches] -class BasePlugin(): +class BasePlugin: def __init__(self, src): self.source = src @@ -44,8 +44,7 @@ class ReadOnlyTokenFile(BasePlugin): def _load_targets(self): if os.path.isdir(self.source): - cfg_files = [os.path.join(self.source, f) for - f in os.listdir(self.source)] + cfg_files = [os.path.join(self.source, f) for f in os.listdir(self.source)] else: cfg_files = [self.source] @@ -53,12 +52,14 @@ class ReadOnlyTokenFile(BasePlugin): index = 1 for f in cfg_files: for line in [l.strip() for l in open(f).readlines()]: - if line and not line.startswith('#'): + if line and not line.startswith("#"): try: - tok, target = re.split(r':\s', line) - self._targets[tok] = target.strip().rsplit(':', 1) + tok, target = re.split(r":\s", line) + self._targets[tok] = target.strip().rsplit(":", 1) except ValueError: - logger.error("Syntax error in %s on line %d" % (self.source, index)) + logger.error( + "Syntax error in %s on line %d" % (self.source, index) + ) index += 1 def lookup(self, token): @@ -83,6 +84,7 @@ class TokenFile(ReadOnlyTokenFile): return super().lookup(token) + class TokenFileName(BasePlugin): # source is a directory # token is filename @@ -91,12 +93,12 @@ class TokenFileName(BasePlugin): super().__init__(src) if not os.path.isdir(src): raise Exception("TokenFileName plugin requires a directory") - + def lookup(self, token): token = os.path.basename(token) path = os.path.join(self.source, token) if os.path.exists(path): - return open(path).read().strip().split(':') + return open(path).read().strip().split(":") else: return None @@ -109,9 +111,9 @@ class BaseTokenAPI(BasePlugin): # in this file can be used w/o unnecessary dependencies def process_result(self, resp): - host, port = resp.text.split(':') - port = port.encode('ascii','ignore') - return [ host, port ] + host, port = resp.text.split(":") + port = port.encode("ascii", "ignore") + return [host, port] def lookup(self, token): import requests @@ -130,7 +132,7 @@ class JSONTokenApi(BaseTokenAPI): def process_result(self, resp): resp_json = resp.json() - return (resp_json['host'], resp_json['port']) + return (resp_json["host"], resp_json["port"]) class JWTTokenApi(BasePlugin): @@ -145,7 +147,7 @@ class JWTTokenApi(BasePlugin): key = jwk.JWK() try: - with open(self.source, 'rb') as key_file: + with open(self.source, "rb") as key_file: key_data = key_file.read() except Exception as e: logger.error("Error loading key file: %s" % str(e)) @@ -155,39 +157,41 @@ class JWTTokenApi(BasePlugin): key.import_from_pem(key_data) except: try: - key.import_key(k=key_data.decode('utf-8'),kty='oct') + key.import_key(k=key_data.decode("utf-8"), kty="oct") except: - logger.error('Failed to correctly parse key data!') + logger.error("Failed to correctly parse key data!") return None try: token = jwt.JWT(key=key, jwt=token) parsed_header = json.loads(token.header) - if 'enc' in parsed_header: + if "enc" in parsed_header: # Token is encrypted, so we need to decrypt by passing the claims to a new instance token = jwt.JWT(key=key, jwt=token.claims) parsed = json.loads(token.claims) - if 'nbf' in parsed: + if "nbf" in parsed: # Not Before is present, so we need to check it - if time.time() < parsed['nbf']: - logger.warning('Token can not be used yet!') + if time.time() < parsed["nbf"]: + logger.warning("Token can not be used yet!") return None - if 'exp' in parsed: + if "exp" in parsed: # Expiration time is present, so we need to check it - if time.time() > parsed['exp']: - logger.warning('Token has expired!') + if time.time() > parsed["exp"]: + logger.warning("Token has expired!") return None - return (parsed['host'], parsed['port']) + return (parsed["host"], parsed["port"]) except Exception as e: logger.error("Failed to parse token: %s" % str(e)) return None except ImportError: - logger.error("package jwcrypto not found, are you sure you've installed it correctly?") + logger.error( + "package jwcrypto not found, are you sure you've installed it correctly?" + ) return None @@ -251,6 +255,7 @@ class TokenRedis(BasePlugin): pip install redis """ + def __init__(self, src): try: import redis @@ -285,7 +290,9 @@ class TokenRedis(BasePlugin): if not self._password: self._password = None elif len(fields) == 5: - self._server, self._port, self._db, self._password, self._namespace = fields + self._server, self._port, self._db, self._password, self._namespace = ( + fields + ) if not self._port: self._port = 6379 if not self._db: @@ -301,24 +308,30 @@ class TokenRedis(BasePlugin): if self._namespace: self._namespace += ":" - logger.info("TokenRedis backend initialized (%s:%s)" % - (self._server, self._port)) + logger.info( + "TokenRedis backend initialized (%s:%s)" % (self._server, self._port) + ) except ValueError: - logger.error("The provided --token-source='%s' is not in the " - "expected format [:[:[:[:]]]]" % - src) + logger.error( + "The provided --token-source='%s' is not in the " + "expected format [:[:[:[:]]]]" + % src + ) sys.exit() def lookup(self, token): try: import redis except ImportError: - logger.error("package redis not found, are you sure you've installed them correctly?") + logger.error( + "package redis not found, are you sure you've installed them correctly?" + ) sys.exit() logger.info("resolving token '%s'" % token) - client = redis.Redis(host=self._server, port=self._port, - db=self._db, password=self._password) + client = redis.Redis( + host=self._server, port=self._port, db=self._db, password=self._password + ) stuff = client.get(self._namespace + token) if stuff is None: return None @@ -330,14 +343,14 @@ class TokenRedis(BasePlugin): combo = json.loads(responseStr) host, port = combo["host"].split(":") except ValueError: - logger.error("Unable to decode JSON token: %s" % - responseStr) + logger.error("Unable to decode JSON token: %s" % responseStr) return None except KeyError: - logger.error("Unable to find 'host' key in JSON token: %s" % - responseStr) + logger.error( + "Unable to find 'host' key in JSON token: %s" % responseStr + ) return None - elif re.match(r'\S+:\S+', responseStr): + elif re.match(r"\S+:\S+", responseStr): host, port = responseStr.split(":") else: logger.error("Unable to parse token: %s" % responseStr) @@ -368,7 +381,7 @@ class UnixDomainSocketDirectory(BasePlugin): if not stat.S_ISSOCK(os.stat(uds_path).st_mode): return None - return [ 'unix_socket', uds_path ] + return ["unix_socket", uds_path] except Exception as e: - logger.error("Error finding unix domain socket: %s" % str(e)) - return None + logger.error("Error finding unix domain socket: %s" % str(e)) + return None diff --git a/websockify/websocket.py b/websockify/websocket.py index 1bed8cc..b249ee9 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -''' +""" Python WebSocket library Copyright 2011 Joel Martin Copyright 2016 Pierre Ossman @@ -10,7 +10,7 @@ Supports following protocol versions: - http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-07 - http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 - http://tools.ietf.org/html/rfc6455 -''' +""" import sys import array @@ -28,14 +28,19 @@ try: import numpy except ImportError: import warnings + warnings.warn("no 'numpy' module, HyBi protocol will be slower") numpy = None + class WebSocketWantReadError(ssl.SSLWantReadError): pass + + class WebSocketWantWriteError(ssl.SSLWantWriteError): pass + class WebSocket: """WebSocket protocol socket like class. @@ -73,11 +78,11 @@ class WebSocket: self._state = "new" - self._partial_msg = b'' + self._partial_msg = b"" - self._recv_buffer = b'' + self._recv_buffer = b"" self._recv_queue = [] - self._send_buffer = b'' + self._send_buffer = b"" self._previous_sendmsg = None @@ -91,16 +96,22 @@ class WebSocket: def __getattr__(self, name): # These methods are just redirected to the underlying socket - if name in ["fileno", - "getpeername", "getsockname", - "getsockopt", "setsockopt", - "gettimeout", "settimeout", - "setblocking"]: + if name in [ + "fileno", + "getpeername", + "getsockname", + "getsockopt", + "setsockopt", + "gettimeout", + "settimeout", + "setblocking", + ]: assert self.socket is not None return getattr(self.socket, name) else: - raise AttributeError("%s instance has no attribute '%s'" % - (self.__class__.__name__, name)) + raise AttributeError( + "%s instance has no attribute '%s'" % (self.__class__.__name__, name) + ) def connect(self, uri, origin=None, protocols=[]): """Establishes a new connection to a WebSocket server. @@ -118,8 +129,7 @@ class WebSocket: connect() must retain the same arguments. """ - self.client = True; - + self.client = True uri = urlparse(uri) port = uri.port @@ -140,8 +150,9 @@ class WebSocket: if uri.scheme in ("wss", "https"): context = ssl.create_default_context() - self.socket = context.wrap_socket(self.socket, - server_hostname=uri.hostname) + self.socket = context.wrap_socket( + self.socket, server_hostname=uri.hostname + ) self._state = "ssl_handshake" else: self._state = "headers" @@ -151,7 +162,7 @@ class WebSocket: self._state = "headers" if self._state == "headers": - self._key = '' + self._key = "" for i in range(16): self._key += chr(random.randrange(256)) self._key = b64encode(self._key.encode("latin-1")).decode("ascii") @@ -184,10 +195,10 @@ class WebSocket: if not self._recv(): raise Exception("Socket closed unexpectedly") - if self._recv_buffer.find(b'\r\n\r\n') == -1: + if self._recv_buffer.find(b"\r\n\r\n") == -1: raise WebSocketWantReadError - (request, self._recv_buffer) = self._recv_buffer.split(b'\r\n', 1) + (request, self._recv_buffer) = self._recv_buffer.split(b"\r\n", 1) request = request.decode("latin-1") words = request.split() @@ -196,17 +207,17 @@ class WebSocket: if words[1] != "101": raise Exception("WebSocket request denied: %s" % " ".join(words[1:])) - (headers, self._recv_buffer) = self._recv_buffer.split(b'\r\n\r\n', 1) - headers = headers.decode('latin-1') + '\r\n' + (headers, self._recv_buffer) = self._recv_buffer.split(b"\r\n\r\n", 1) + headers = headers.decode("latin-1") + "\r\n" headers = email.message_from_string(headers) if headers.get("Upgrade", "").lower() != "websocket": print(type(headers)) raise Exception("Missing or incorrect upgrade header") - accept = headers.get('Sec-WebSocket-Accept') + accept = headers.get("Sec-WebSocket-Accept") if accept is None: - raise Exception("Missing Sec-WebSocket-Accept header"); + raise Exception("Missing Sec-WebSocket-Accept header") expected = sha1((self._key + self.GUID).encode("ascii")).digest() expected = b64encode(expected).decode("ascii") @@ -214,9 +225,9 @@ class WebSocket: del self._key if accept != expected: - raise Exception("Invalid Sec-WebSocket-Accept header"); + raise Exception("Invalid Sec-WebSocket-Accept header") - self.protocol = headers.get('Sec-WebSocket-Protocol') + self.protocol = headers.get("Sec-WebSocket-Protocol") if len(protocols) == 0: if self.protocol is not None: raise Exception("Unexpected Sec-WebSocket-Protocol header") @@ -256,34 +267,34 @@ class WebSocket: if headers.get("upgrade", "").lower() != "websocket": raise Exception("Missing or incorrect upgrade header") - ver = headers.get('Sec-WebSocket-Version') + ver = headers.get("Sec-WebSocket-Version") if ver is None: - raise Exception("Missing Sec-WebSocket-Version header"); + raise Exception("Missing Sec-WebSocket-Version header") # HyBi-07 report version 7 # HyBi-08 - HyBi-12 report version 8 # HyBi-13 reports version 13 - if ver in ['7', '8', '13']: + if ver in ["7", "8", "13"]: self.version = "hybi-%02d" % int(ver) else: raise Exception("Unsupported protocol version %s" % ver) - key = headers.get('Sec-WebSocket-Key') + key = headers.get("Sec-WebSocket-Key") if key is None: - raise Exception("Missing Sec-WebSocket-Key header"); + raise Exception("Missing Sec-WebSocket-Key header") # Generate the hash value for the accept header accept = sha1((key + self.GUID).encode("ascii")).digest() accept = b64encode(accept).decode("ascii") - self.protocol = '' - protocols = headers.get('Sec-WebSocket-Protocol', '').split(',') + self.protocol = "" + protocols = headers.get("Sec-WebSocket-Protocol", "").split(",") if protocols: self.protocol = self.select_subprotocol(protocols) # We are required to choose one of the protocols # presented by the client if self.protocol not in protocols: - raise Exception('Invalid protocol selected') + raise Exception("Invalid protocol selected") self.send_response(101, "Switching Protocols") self.send_header("Upgrade", "websocket") @@ -461,7 +472,7 @@ class WebSocket: def send_request(self, type, path): self._queue_str("%s %s HTTP/1.1\r\n" % (type.upper(), path)) - def ping(self, data=b''): + def ping(self, data=b""): """Write a ping message to the WebSocket WebSocketWantWriteError can be raised if there is insufficient @@ -486,7 +497,7 @@ class WebSocket: self._previous_sendmsg = data raise - def pong(self, data=b''): + def pong(self, data=b""): """Write a pong message to the WebSocket WebSocketWantWriteError can be raised if there is insufficient @@ -540,7 +551,7 @@ class WebSocket: self._sent_close = True - msg = b'' + msg = b"" if code is not None: msg += struct.pack(">H", code) if reason is not None: @@ -602,7 +613,7 @@ class WebSocket: frame = self._decode_hybi(self._recv_buffer) if frame is None: break - self._recv_buffer = self._recv_buffer[frame['length']:] + self._recv_buffer = self._recv_buffer[frame["length"] :] self._recv_queue.append(frame) return True @@ -612,29 +623,39 @@ class WebSocket: while self._recv_queue: frame = self._recv_queue.pop(0) - if not self.client and not frame['masked']: - self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked") + if not self.client and not frame["masked"]: + self.shutdown( + socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked" + ) continue - if self.client and frame['masked']: + if self.client and frame["masked"]: self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame masked") continue if frame["opcode"] == 0x0: if not self._partial_msg: - self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected continuation frame") + self.shutdown( + socket.SHUT_RDWR, + 1002, + "Procotol error: Unexpected continuation frame", + ) continue self._partial_msg += frame["payload"] if frame["fin"]: msg = self._partial_msg - self._partial_msg = b'' + self._partial_msg = b"" return msg elif frame["opcode"] == 0x1: - self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported") + self.shutdown( + socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported" + ) elif frame["opcode"] == 0x2: if self._partial_msg: - self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame") + self.shutdown( + socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame" + ) continue if frame["fin"]: @@ -652,7 +673,9 @@ class WebSocket: return None if not frame["fin"]: - self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close") + self.shutdown( + socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close" + ) continue code = None @@ -664,7 +687,11 @@ class WebSocket: try: reason = reason.decode("UTF-8") except UnicodeDecodeError: - self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Invalid UTF-8 in close") + self.shutdown( + socket.SHUT_RDWR, + 1002, + "Procotol error: Invalid UTF-8 in close", + ) continue if code is None: @@ -679,18 +706,26 @@ class WebSocket: return None elif frame["opcode"] == 0x9: if not frame["fin"]: - self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping") + self.shutdown( + socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping" + ) continue self.handle_ping(frame["payload"]) elif frame["opcode"] == 0xA: if not frame["fin"]: - self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong") + self.shutdown( + socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong" + ) continue self.handle_pong(frame["payload"]) else: - self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Unknown opcode 0x%02x" % frame["opcode"]) + self.shutdown( + socket.SHUT_RDWR, + 1003, + "Unsupported: Unknown opcode 0x%02x" % frame["opcode"], + ) raise WebSocketWantReadError @@ -731,7 +766,7 @@ class WebSocket: def _sendmsg(self, opcode, msg): # Sends a standard data message if self.client: - mask = b'' + mask = b"" for i in range(4): mask += random.randrange(256).to_bytes() frame = self._encode_hybi(opcode, msg, mask) @@ -755,35 +790,36 @@ class WebSocket: plen = len(buf) pstart = 0 pend = plen - b = c = b'' + b = c = b"" if plen >= 4: - dtype=numpy.dtype('') + dtype = numpy.dtype("") mask = numpy.frombuffer(mask, dtype, count=1) data = numpy.frombuffer(buf, dtype, count=int(plen / 4)) - #b = numpy.bitwise_xor(data, mask).data + # b = numpy.bitwise_xor(data, mask).data b = numpy.bitwise_xor(data, mask).tobytes() if plen % 4: - dtype=numpy.dtype('B') - if sys.byteorder == 'big': - dtype = dtype.newbyteorder('>') + dtype = numpy.dtype("B") + if sys.byteorder == "big": + dtype = dtype.newbyteorder(">") mask = numpy.frombuffer(mask, dtype, count=(plen % 4)) - data = numpy.frombuffer(buf, dtype, - offset=plen - (plen % 4), count=(plen % 4)) + data = numpy.frombuffer( + buf, dtype, offset=plen - (plen % 4), count=(plen % 4) + ) c = numpy.bitwise_xor(data, mask).tobytes() return b + c else: # Slower fallback - data = array.array('B') + data = array.array("B") data.frombytes(buf) for i in range(len(data)): data[i] ^= mask[i % 4] return data.tobytes() def _encode_hybi(self, opcode, buf, mask_key=None, fin=True): - """ Encode a HyBi style WebSocket frame. + """Encode a HyBi style WebSocket frame. Optional opcode: 0x0 - continuation 0x1 - text frame @@ -793,7 +829,7 @@ class WebSocket: 0xA - pong """ - b1 = opcode & 0x0f + b1 = opcode & 0x0F if fin: b1 |= 0x80 @@ -804,11 +840,11 @@ class WebSocket: payload_len = len(buf) if payload_len <= 125: - header = struct.pack('>BB', b1, payload_len | mask_bit) + header = struct.pack(">BB", b1, payload_len | mask_bit) elif payload_len > 125 and payload_len < 65536: - header = struct.pack('>BBH', b1, 126 | mask_bit, payload_len) + header = struct.pack(">BBH", b1, 126 | mask_bit, payload_len) elif payload_len >= 65536: - header = struct.pack('>BBQ', b1, 127 | mask_bit, payload_len) + header = struct.pack(">BBQ", b1, 127 | mask_bit, payload_len) if mask_key is not None: return header + mask_key + buf @@ -816,7 +852,7 @@ class WebSocket: return header + buf def _decode_hybi(self, buf): - """ Decode HyBi style WebSocket packets. + """Decode HyBi style WebSocket packets. Returns: {'fin' : boolean, 'opcode' : number, @@ -825,11 +861,7 @@ class WebSocket: 'payload' : decoded_buffer} """ - f = {'fin' : 0, - 'opcode' : 0, - 'masked' : False, - 'length' : 0, - 'payload' : None} + f = {"fin": 0, "opcode": 0, "masked": False, "length": 0, "payload": None} blen = len(buf) hlen = 2 @@ -838,39 +870,38 @@ class WebSocket: return None b1, b2 = struct.unpack(">BB", buf[:2]) - f['opcode'] = b1 & 0x0f - f['fin'] = not not (b1 & 0x80) - f['masked'] = not not (b2 & 0x80) + f["opcode"] = b1 & 0x0F + f["fin"] = not not (b1 & 0x80) + f["masked"] = not not (b2 & 0x80) - if f['masked']: + if f["masked"]: hlen += 4 if blen < hlen: return None - length = b2 & 0x7f + length = b2 & 0x7F if length == 126: hlen += 2 if blen < hlen: return None - length, = struct.unpack('>H', buf[2:4]) + (length,) = struct.unpack(">H", buf[2:4]) elif length == 127: hlen += 8 if blen < hlen: return None - length, = struct.unpack('>Q', buf[2:10]) + (length,) = struct.unpack(">Q", buf[2:10]) - f['length'] = hlen + length + f["length"] = hlen + length - if blen < f['length']: + if blen < f["length"]: return None - if f['masked']: + if f["masked"]: # unmask payload - mask_key = buf[hlen-4:hlen] - f['payload'] = self._unmask(buf[hlen:(hlen+length)], mask_key) + mask_key = buf[hlen - 4 : hlen] + f["payload"] = self._unmask(buf[hlen : (hlen + length)], mask_key) else: - f['payload'] = buf[hlen:(hlen+length)] + f["payload"] = buf[hlen : (hlen + length)] return f - diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index 5afec81..8076d03 100644 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -''' +""" A WebSocket to TCP socket proxy with support for "wss://" encryption. Copyright 2011 Joel Martin Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) @@ -9,7 +9,7 @@ You can make a cert/key with openssl using: openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem as taken from http://docs.python.org/dev/library/ssl.html#certificates -''' +""" import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl, stat from socketserver import ThreadingMixIn @@ -20,8 +20,8 @@ from websockify import websockifyserver from websockify import auth_plugins as auth from urllib.parse import parse_qs, urlparse -class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler): +class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler): buffer_size = 65536 traffic_legend = """ @@ -38,7 +38,7 @@ Traffic Legend: def send_auth_error(self, ex): self.send_response(ex.code, ex.msg) - self.send_header('Content-Type', 'text/html') + self.send_header("Content-Type", "text/html") for name, val in ex.headers.items(): self.send_header(name, val) @@ -49,7 +49,7 @@ Traffic Legend: return host, port = self.get_target(self.server.token_plugin) - if host == 'unix_socket': + if host == "unix_socket": self.server.unix_target = port else: @@ -62,7 +62,7 @@ Traffic Legend: # clear out any existing SSL_ headers that the client might # have maliciously set - ssl_headers = [ h for h in self.headers if h.startswith('SSL_') ] + ssl_headers = [h for h in self.headers if h.startswith("SSL_")] for h in ssl_headers: del self.headers[h] @@ -70,19 +70,21 @@ Traffic Legend: # get client certificate data client_cert_data = self.request.getpeercert() # extract subject information - client_cert_subject = client_cert_data['subject'] + client_cert_subject = client_cert_data["subject"] # flatten data structure client_cert_subject = dict([x[0] for x in client_cert_subject]) # add common name to headers (apache +StdEnvVars style) - self.headers['SSL_CLIENT_S_DN_CN'] = client_cert_subject['commonName'] + self.headers["SSL_CLIENT_S_DN_CN"] = client_cert_subject["commonName"] except (TypeError, AttributeError, KeyError): # not a SSL connection or client presented no certificate with valid data pass try: self.server.auth_plugin.authenticate( - headers=self.headers, target_host=self.server.target_host, - target_port=self.server.target_port) + headers=self.headers, + target_host=self.server.target_host, + target_port=self.server.target_port, + ) except auth.AuthenticationError: ex = sys.exc_info()[1] self.send_auth_error(ex) @@ -96,26 +98,37 @@ Traffic Legend: # Connect to the target if self.server.wrap_cmd: - msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port) + msg = "connecting to command: '%s' (port %s)" % ( + " ".join(self.server.wrap_cmd), + self.server.target_port, + ) elif self.server.unix_target: msg = "connecting to unix socket: %s" % self.server.unix_target else: msg = "connecting to: %s:%s" % ( - self.server.target_host, self.server.target_port) + self.server.target_host, + self.server.target_port, + ) if self.server.ssl_target: msg += " (using SSL)" self.log_message(msg) try: - tsock = websockifyserver.WebSockifyServer.socket(self.server.target_host, - self.server.target_port, - connect=True, - use_ssl=self.server.ssl_target, - unix_socket=self.server.unix_target) + tsock = websockifyserver.WebSockifyServer.socket( + self.server.target_host, + self.server.target_port, + connect=True, + use_ssl=self.server.ssl_target, + unix_socket=self.server.unix_target, + ) except Exception as e: - self.log_message("Failed to connect to %s:%s: %s", - self.server.target_host, self.server.target_port, e) + self.log_message( + "Failed to connect to %s:%s: %s", + self.server.target_host, + self.server.target_port, + e, + ) raise self.CClose(1011, "Failed to connect to downstream server") # Option unavailable when listening to unix socket @@ -134,8 +147,11 @@ Traffic Legend: tsock.shutdown(socket.SHUT_RDWR) tsock.close() if self.verbose: - self.log_message("%s:%s: Closed target", - self.server.target_host, self.server.target_port) + self.log_message( + "%s:%s: Closed target", + self.server.target_host, + self.server.target_port, + ) def get_target(self, target_plugin): """ @@ -149,20 +165,20 @@ Traffic Legend: if self.host_token: # Use hostname as token - token = self.headers.get('Host') + token = self.headers.get("Host") # Remove port from hostname, as it'll always be the one where # websockify listens (unless something between the client and # websockify is redirecting traffic, but that's beside the point) if token: - token = token.partition(':')[0] + token = token.partition(":")[0] else: # Extract the token parameter from url - args = parse_qs(urlparse(self.path)[4]) # 4 is the query from url + args = parse_qs(urlparse(self.path)[4]) # 4 is the query from url - if 'token' in args and len(args['token']): - token = args['token'][0].rstrip('\n') + if "token" in args and len(args["token"]): + token = args["token"][0].rstrip("\n") else: token = None @@ -200,13 +216,15 @@ Traffic Legend: self.heartbeat = now + self.server.heartbeat self.send_ping() - if tqueue: wlist.append(target) - if cqueue or c_pend: wlist.append(self.request) + if tqueue: + wlist.append(target) + if cqueue or c_pend: + wlist.append(self.request) try: ins, outs, excepts = select.select(rlist, wlist, [], 1) except OSError: exc = sys.exc_info()[1] - if hasattr(exc, 'errno'): + if hasattr(exc, "errno"): err = exc.errno else: err = exc[0] @@ -216,7 +234,8 @@ Traffic Legend: else: continue - if excepts: raise Exception("Socket exception") + if excepts: + raise Exception("Socket exception") if self.request in outs: # Send queued target data to the client @@ -230,8 +249,7 @@ Traffic Legend: tqueue.extend(bufs) if closed: - - while (len(tqueue) != 0): + while len(tqueue) != 0: # Send queued client data to the target dat = tqueue.pop(0) sent = target.send(dat) @@ -244,10 +262,12 @@ Traffic Legend: # TODO: What about blocking on client socket? if self.verbose: - self.log_message("%s:%s: Client closed connection", - self.server.target_host, self.server.target_port) - raise self.CClose(closed['code'], closed['reason']) - + self.log_message( + "%s:%s: Client closed connection", + self.server.target_host, + self.server.target_port, + ) + raise self.CClose(closed["code"], closed["reason"]) if target in outs: # Send queued client data to the target @@ -260,29 +280,31 @@ Traffic Legend: tqueue.insert(0, dat[sent:]) self.print_traffic(".>") - if target in ins: # Receive target data, encode it and queue for client buf = target.recv(self.buffer_size) if len(buf) == 0: - # Target socket closed, flushing queues and closing client-side websocket # Send queued target data to the client if len(cqueue) != 0: c_pend = True - while(c_pend): + while c_pend: c_pend = self.send_frames(cqueue) cqueue = [] if self.verbose: - self.log_message("%s:%s: Target closed connection", - self.server.target_host, self.server.target_port) + self.log_message( + "%s:%s: Target closed connection", + self.server.target_host, + self.server.target_port, + ) raise self.CClose(1000, "Target closed") cqueue.append(buf) self.print_traffic("{") + class WebSocketProxy(websockifyserver.WebSockifyServer): """ Proxy traffic to and from a WebSockets client to a normal TCP @@ -293,27 +315,29 @@ class WebSocketProxy(websockifyserver.WebSockifyServer): def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs): # Save off proxy specific options - self.target_host = kwargs.pop('target_host', None) - self.target_port = kwargs.pop('target_port', None) - self.wrap_cmd = kwargs.pop('wrap_cmd', None) - self.wrap_mode = kwargs.pop('wrap_mode', None) - self.unix_target = kwargs.pop('unix_target', None) - self.ssl_target = kwargs.pop('ssl_target', None) - self.heartbeat = kwargs.pop('heartbeat', None) + self.target_host = kwargs.pop("target_host", None) + self.target_port = kwargs.pop("target_port", None) + self.wrap_cmd = kwargs.pop("wrap_cmd", None) + self.wrap_mode = kwargs.pop("wrap_mode", None) + self.unix_target = kwargs.pop("unix_target", None) + self.ssl_target = kwargs.pop("ssl_target", None) + self.heartbeat = kwargs.pop("heartbeat", None) - self.token_plugin = kwargs.pop('token_plugin', None) - self.host_token = kwargs.pop('host_token', None) - self.auth_plugin = kwargs.pop('auth_plugin', None) + self.token_plugin = kwargs.pop("token_plugin", None) + self.host_token = kwargs.pop("host_token", None) + self.auth_plugin = kwargs.pop("auth_plugin", None) # Last 3 timestamps command was run - self.wrap_times = [0, 0, 0] + self.wrap_times = [0, 0, 0] if self.wrap_cmd: wsdir = os.path.dirname(sys.argv[0]) - rebinder_path = [os.path.join(wsdir, "..", "lib"), - os.path.join(wsdir, "..", "lib", "websockify"), - os.path.join(wsdir, ".."), - wsdir] + rebinder_path = [ + os.path.join(wsdir, "..", "lib"), + os.path.join(wsdir, "..", "lib", "websockify"), + os.path.join(wsdir, ".."), + wsdir, + ] self.rebinder = None for rdir in rebinder_path: @@ -329,17 +353,22 @@ class WebSocketProxy(websockifyserver.WebSockifyServer): self.target_host = "127.0.0.1" # Loopback # Find a free high port sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(('', 0)) + sock.bind(("", 0)) self.target_port = sock.getsockname()[1] sock.close() # Insert rebinder at the head of the (possibly empty) LD_PRELOAD pathlist - ld_preloads = filter(None, [ self.rebinder, os.environ.get("LD_PRELOAD", None) ]) + ld_preloads = filter( + None, [self.rebinder, os.environ.get("LD_PRELOAD", None)] + ) - os.environ.update({ - "LD_PRELOAD": os.pathsep.join(ld_preloads), - "REBIND_OLD_PORT": str(kwargs['listen_port']), - "REBIND_NEW_PORT": str(self.target_port)}) + os.environ.update( + { + "LD_PRELOAD": os.pathsep.join(ld_preloads), + "REBIND_OLD_PORT": str(kwargs["listen_port"]), + "REBIND_NEW_PORT": str(self.target_port), + } + ) super().__init__(RequestHandlerClass, *args, **kwargs) @@ -348,7 +377,8 @@ class WebSocketProxy(websockifyserver.WebSockifyServer): self.wrap_times.append(time.time()) self.wrap_times.pop(0) self.cmd = subprocess.Popen( - self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup) + self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup + ) self.spawn_message = True def started(self): @@ -371,10 +401,11 @@ class WebSocketProxy(websockifyserver.WebSockifyServer): if self.token_plugin: msg = " - proxying from %s to targets generated by %s" % ( - src_string, type(self.token_plugin).__name__) + src_string, + type(self.token_plugin).__name__, + ) else: - msg = " - proxying from %s to %s" % ( - src_string, dst_string) + msg = " - proxying from %s to %s" % (src_string, dst_string) if self.ssl_target: msg += " (using SSL)" @@ -401,7 +432,7 @@ class WebSocketProxy(websockifyserver.WebSockifyServer): sys.exit(ret) elif self.wrap_mode == "respawn": now = time.time() - avg = sum(self.wrap_times)/len(self.wrap_times) + avg = sum(self.wrap_times) / len(self.wrap_times) if (now - avg) < 10: # 3 times in the last 10 seconds if self.spawn_message: @@ -418,15 +449,25 @@ def _subprocess_setup(): SSL_OPTIONS = { - 'default': ssl.OP_ALL, - 'tlsv1_1': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | - ssl.OP_NO_TLSv1, - 'tlsv1_2': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | - ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1, - 'tlsv1_3': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | - ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2, + "default": ssl.OP_ALL, + "tlsv1_1": ssl.PROTOCOL_SSLv23 + | ssl.OP_NO_SSLv2 + | ssl.OP_NO_SSLv3 + | ssl.OP_NO_TLSv1, + "tlsv1_2": ssl.PROTOCOL_SSLv23 + | ssl.OP_NO_SSLv2 + | ssl.OP_NO_SSLv3 + | ssl.OP_NO_TLSv1 + | ssl.OP_NO_TLSv1_1, + "tlsv1_3": ssl.PROTOCOL_SSLv23 + | ssl.OP_NO_SSLv2 + | ssl.OP_NO_SSLv3 + | ssl.OP_NO_TLSv1 + | ssl.OP_NO_TLSv1_1 + | ssl.OP_NO_TLSv1_2, } + def select_ssl_version(version): """Returns SSL options for the most secure TSL version available on this Python version""" @@ -439,11 +480,11 @@ def select_ssl_version(version): keys.sort() fallback = keys[-1] logger = logging.getLogger(WebSocketProxy.log_prefix) - logger.warn("TLS version %s unsupported. Falling back to %s", - version, fallback) + logger.warn("TLS version %s unsupported. Falling back to %s", version, fallback) return SSL_OPTIONS[fallback] + def websockify_init(): # Setup basic logging to stderr. stderr_handler = logging.StreamHandler() @@ -464,105 +505,198 @@ def websockify_init(): usage += "\n %prog [options]" usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE" parser = optparse.OptionParser(usage=usage) - parser.add_option("--verbose", "-v", action="store_true", - help="verbose messages") - parser.add_option("--traffic", action="store_true", - help="per frame traffic") - parser.add_option("--record", - help="record sessions to FILE.[session_number]", metavar="FILE") - parser.add_option("--daemon", "-D", - dest="daemon", action="store_true", - help="become a daemon (background process)") - parser.add_option("--run-once", action="store_true", - help="handle a single WebSocket connection and exit") - parser.add_option("--timeout", type=int, default=0, - help="after TIMEOUT seconds exit when not connected") - parser.add_option("--idle-timeout", type=int, default=0, - help="server exits after TIMEOUT seconds if there are no " - "active connections") - parser.add_option("--cert", default="self.pem", - help="SSL certificate file") - parser.add_option("--key", default=None, - help="SSL key file (if separate from cert)") - parser.add_option("--key-password", default=None, - help="SSL key password") - parser.add_option("--ssl-only", action="store_true", - help="disallow non-encrypted client connections") - parser.add_option("--ssl-target", action="store_true", - help="connect to SSL target as SSL client") - parser.add_option("--verify-client", action="store_true", - help="require encrypted client to present a valid certificate " - "(needs Python 2.7.9 or newer or Python 3.4 or newer)") - parser.add_option("--cafile", metavar="FILE", - help="file of concatenated certificates of authorities trusted " - "for validating clients (only effective with --verify-client). " - "If omitted, system default list of CAs is used.") - parser.add_option("--ssl-version", type="choice", default="default", - choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"], action="store", - help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)") - parser.add_option("--ssl-ciphers", action="store", - help="list of ciphers allowed for connection. For a list of " - "supported ciphers run `openssl ciphers`") - parser.add_option("--unix-listen", - help="listen to unix socket", metavar="FILE", default=None) - parser.add_option("--unix-listen-mode", default=None, - help="specify mode for unix socket (defaults to 0600)") - parser.add_option("--unix-target", - help="connect to unix socket target", metavar="FILE") - parser.add_option("--inetd", - help="inetd mode, receive listening socket from stdin", action="store_true") - parser.add_option("--web", default=None, metavar="DIR", - help="run webserver on same port. Serve files from DIR.") - parser.add_option("--web-auth", action="store_true", - help="require authentication to access webserver.") - parser.add_option("--wrap-mode", default="exit", metavar="MODE", - choices=["exit", "ignore", "respawn"], - help="action to take when the wrapped program exits " - "or daemonizes: exit (default), ignore, respawn") - parser.add_option("--prefer-ipv6", "-6", - action="store_true", dest="source_is_ipv6", - help="prefer IPv6 when resolving source_addr") - parser.add_option("--libserver", action="store_true", - help="use Python library SocketServer engine") - parser.add_option("--target-config", metavar="FILE", - dest="target_cfg", - help="Configuration file containing valid targets " - "in the form 'token: host:port' or, alternatively, a " - "directory containing configuration files of this form " - "(DEPRECATED: use `--token-plugin TokenFile --token-source " - " path/to/token/file` instead)") - parser.add_option("--token-plugin", default=None, metavar="CLASS", - help="use a Python class, usually one from websockify.token_plugins, " - "such as TokenFile, to process tokens into host:port pairs") - parser.add_option("--token-source", default=None, metavar="ARG", - help="an argument to be passed to the token plugin " - "on instantiation") - parser.add_option("--host-token", action="store_true", - help="use the host HTTP header as token instead of the " - "token URL query parameter") - parser.add_option("--auth-plugin", default=None, metavar="CLASS", - help="use a Python class, usually one from websockify.auth_plugins, " - "such as BasicHTTPAuth, to determine if a connection is allowed") - parser.add_option("--auth-source", default=None, metavar="ARG", - help="an argument to be passed to the auth plugin " - "on instantiation") - parser.add_option("--heartbeat", type=int, default=0, metavar="INTERVAL", - help="send a ping to the client every INTERVAL seconds") - parser.add_option("--log-file", metavar="FILE", - dest="log_file", - help="File where logs will be saved") - parser.add_option("--syslog", default=None, metavar="SERVER", - help="Log to syslog server. SERVER can be local socket, " - "such as /dev/log, or a UDP host:port pair.") - parser.add_option("--legacy-syslog", action="store_true", - help="Use the old syslog protocol instead of RFC 5424. " - "Use this if the messages produced by websockify seem abnormal.") - parser.add_option("--file-only", action="store_true", - help="use this to disable directory listings in web server.") + parser.add_option("--verbose", "-v", action="store_true", help="verbose messages") + parser.add_option("--traffic", action="store_true", help="per frame traffic") + parser.add_option( + "--record", help="record sessions to FILE.[session_number]", metavar="FILE" + ) + parser.add_option( + "--daemon", + "-D", + dest="daemon", + action="store_true", + help="become a daemon (background process)", + ) + parser.add_option( + "--run-once", + action="store_true", + help="handle a single WebSocket connection and exit", + ) + parser.add_option( + "--timeout", + type=int, + default=0, + help="after TIMEOUT seconds exit when not connected", + ) + parser.add_option( + "--idle-timeout", + type=int, + default=0, + help="server exits after TIMEOUT seconds if there are no active connections", + ) + parser.add_option("--cert", default="self.pem", help="SSL certificate file") + parser.add_option( + "--key", default=None, help="SSL key file (if separate from cert)" + ) + parser.add_option("--key-password", default=None, help="SSL key password") + parser.add_option( + "--ssl-only", + action="store_true", + help="disallow non-encrypted client connections", + ) + parser.add_option( + "--ssl-target", action="store_true", help="connect to SSL target as SSL client" + ) + parser.add_option( + "--verify-client", + action="store_true", + help="require encrypted client to present a valid certificate " + "(needs Python 2.7.9 or newer or Python 3.4 or newer)", + ) + parser.add_option( + "--cafile", + metavar="FILE", + help="file of concatenated certificates of authorities trusted " + "for validating clients (only effective with --verify-client). " + "If omitted, system default list of CAs is used.", + ) + parser.add_option( + "--ssl-version", + type="choice", + default="default", + choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"], + action="store", + help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)", + ) + parser.add_option( + "--ssl-ciphers", + action="store", + help="list of ciphers allowed for connection. For a list of " + "supported ciphers run `openssl ciphers`", + ) + parser.add_option( + "--unix-listen", help="listen to unix socket", metavar="FILE", default=None + ) + parser.add_option( + "--unix-listen-mode", + default=None, + help="specify mode for unix socket (defaults to 0600)", + ) + parser.add_option( + "--unix-target", help="connect to unix socket target", metavar="FILE" + ) + parser.add_option( + "--inetd", + help="inetd mode, receive listening socket from stdin", + action="store_true", + ) + parser.add_option( + "--web", + default=None, + metavar="DIR", + help="run webserver on same port. Serve files from DIR.", + ) + parser.add_option( + "--web-auth", + action="store_true", + help="require authentication to access webserver.", + ) + parser.add_option( + "--wrap-mode", + default="exit", + metavar="MODE", + choices=["exit", "ignore", "respawn"], + help="action to take when the wrapped program exits " + "or daemonizes: exit (default), ignore, respawn", + ) + parser.add_option( + "--prefer-ipv6", + "-6", + action="store_true", + dest="source_is_ipv6", + help="prefer IPv6 when resolving source_addr", + ) + parser.add_option( + "--libserver", + action="store_true", + help="use Python library SocketServer engine", + ) + parser.add_option( + "--target-config", + metavar="FILE", + dest="target_cfg", + help="Configuration file containing valid targets " + "in the form 'token: host:port' or, alternatively, a " + "directory containing configuration files of this form " + "(DEPRECATED: use `--token-plugin TokenFile --token-source " + " path/to/token/file` instead)", + ) + parser.add_option( + "--token-plugin", + default=None, + metavar="CLASS", + help="use a Python class, usually one from websockify.token_plugins, " + "such as TokenFile, to process tokens into host:port pairs", + ) + parser.add_option( + "--token-source", + default=None, + metavar="ARG", + help="an argument to be passed to the token plugin on instantiation", + ) + parser.add_option( + "--host-token", + action="store_true", + help="use the host HTTP header as token instead of the " + "token URL query parameter", + ) + parser.add_option( + "--auth-plugin", + default=None, + metavar="CLASS", + help="use a Python class, usually one from websockify.auth_plugins, " + "such as BasicHTTPAuth, to determine if a connection is allowed", + ) + parser.add_option( + "--auth-source", + default=None, + metavar="ARG", + help="an argument to be passed to the auth plugin on instantiation", + ) + parser.add_option( + "--heartbeat", + type=int, + default=0, + metavar="INTERVAL", + help="send a ping to the client every INTERVAL seconds", + ) + parser.add_option( + "--log-file", + metavar="FILE", + dest="log_file", + help="File where logs will be saved", + ) + parser.add_option( + "--syslog", + default=None, + metavar="SERVER", + help="Log to syslog server. SERVER can be local socket, " + "such as /dev/log, or a UDP host:port pair.", + ) + parser.add_option( + "--legacy-syslog", + action="store_true", + help="Use the old syslog protocol instead of RFC 5424. " + "Use this if the messages produced by websockify seem abnormal.", + ) + parser.add_option( + "--file-only", + action="store_true", + help="use this to disable directory listings in web server.", + ) (opts, args) = parser.parse_args() - # Validate options. if opts.token_source and not opts.token_plugin: @@ -583,11 +717,9 @@ def websockify_init(): if opts.legacy_syslog and not opts.syslog: parser.error("You must use --syslog to use --legacy-syslog") - opts.ssl_options = select_ssl_version(opts.ssl_version) del opts.ssl_version - if opts.log_file: # Setup logging to user-specified file. opts.log_file = os.path.abspath(opts.log_file) @@ -601,9 +733,9 @@ def websockify_init(): if opts.syslog: # Determine how to connect to syslog... - if opts.syslog.count(':'): + if opts.syslog.count(":"): # User supplied a host:port pair. - syslog_host, syslog_port = opts.syslog.rsplit(':', 1) + syslog_host, syslog_port = opts.syslog.rsplit(":", 1) try: syslog_port = int(syslog_port) except ValueError: @@ -622,10 +754,12 @@ def websockify_init(): syslog_facility = WebsockifySysLogHandler.LOG_USER # Start logging to syslog. - syslog_handler = WebsockifySysLogHandler(address=syslog_dest, - facility=syslog_facility, - ident='websockify', - legacy=opts.legacy_syslog) + syslog_handler = WebsockifySysLogHandler( + address=syslog_dest, + facility=syslog_facility, + ident="websockify", + legacy=opts.legacy_syslog, + ) syslog_handler.setLevel(logging.DEBUG) syslog_handler.setFormatter(log_formatter) root = logging.getLogger() @@ -638,24 +772,23 @@ def websockify_init(): root = logging.getLogger() root.setLevel(logging.DEBUG) - # Transform to absolute path as daemon may chdir if opts.target_cfg: opts.target_cfg = os.path.abspath(opts.target_cfg) if opts.target_cfg: - opts.token_plugin = 'TokenFile' + opts.token_plugin = "TokenFile" opts.token_source = opts.target_cfg del opts.target_cfg - if sys.argv.count('--'): + if sys.argv.count("--"): opts.wrap_cmd = args[1:] else: opts.wrap_cmd = None if not websockifyserver.ssl and opts.ssl_target: - parser.error("SSL target requested and Python SSL module not loaded."); + parser.error("SSL target requested and Python SSL module not loaded.") if opts.ssl_only and not os.path.exists(opts.cert): parser.error("SSL only and %s not found" % opts.cert) @@ -677,11 +810,11 @@ def websockify_init(): parser.error("Too few arguments") arg = args.pop(0) # Parse host:port and convert ports to numbers - if arg.count(':') > 0: - opts.listen_host, opts.listen_port = arg.rsplit(':', 1) - opts.listen_host = opts.listen_host.strip('[]') + if arg.count(":") > 0: + opts.listen_host, opts.listen_port = arg.rsplit(":", 1) + opts.listen_host = opts.listen_host.strip("[]") else: - opts.listen_host, opts.listen_port = '', arg + opts.listen_host, opts.listen_port = "", arg try: opts.listen_port = int(opts.listen_port) @@ -697,9 +830,9 @@ def websockify_init(): if len(args) < 1: parser.error("Too few arguments") arg = args.pop(0) - if arg.count(':') > 0: - opts.target_host, opts.target_port = arg.rsplit(':', 1) - opts.target_host = opts.target_host.strip('[]') + if arg.count(":") > 0: + opts.target_host, opts.target_port = arg.rsplit(":", 1) + opts.target_host = opts.target_host.strip("[]") else: parser.error("Error parsing target") @@ -712,11 +845,10 @@ def websockify_init(): parser.error("Too many arguments") if opts.token_plugin is not None: - if '.' not in opts.token_plugin: - opts.token_plugin = ( - 'websockify.token_plugins.%s' % opts.token_plugin) + if "." not in opts.token_plugin: + opts.token_plugin = "websockify.token_plugins.%s" % opts.token_plugin - token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit('.', 1) + token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit(".", 1) __import__(token_plugin_module) token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls) @@ -726,10 +858,10 @@ def websockify_init(): del opts.token_source if opts.auth_plugin is not None: - if '.' not in opts.auth_plugin: - opts.auth_plugin = 'websockify.auth_plugins.%s' % opts.auth_plugin + if "." not in opts.auth_plugin: + opts.auth_plugin = "websockify.auth_plugins.%s" % opts.auth_plugin - auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit('.', 1) + auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit(".", 1) __import__(auth_plugin_module) auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls) @@ -759,32 +891,32 @@ class LibProxyServer(ThreadingMixIn, HTTPServer): def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs): # Save off proxy specific options - self.target_host = kwargs.pop('target_host', None) - self.target_port = kwargs.pop('target_port', None) - self.wrap_cmd = kwargs.pop('wrap_cmd', None) - self.wrap_mode = kwargs.pop('wrap_mode', None) - self.unix_target = kwargs.pop('unix_target', None) - self.ssl_target = kwargs.pop('ssl_target', None) - self.token_plugin = kwargs.pop('token_plugin', None) - self.auth_plugin = kwargs.pop('auth_plugin', None) - self.heartbeat = kwargs.pop('heartbeat', None) + self.target_host = kwargs.pop("target_host", None) + self.target_port = kwargs.pop("target_port", None) + self.wrap_cmd = kwargs.pop("wrap_cmd", None) + self.wrap_mode = kwargs.pop("wrap_mode", None) + self.unix_target = kwargs.pop("unix_target", None) + self.ssl_target = kwargs.pop("ssl_target", None) + self.token_plugin = kwargs.pop("token_plugin", None) + self.auth_plugin = kwargs.pop("auth_plugin", None) + self.heartbeat = kwargs.pop("heartbeat", None) self.token_plugin = None self.auth_plugin = None self.daemon = False # Server configuration - listen_host = kwargs.pop('listen_host', '') - listen_port = kwargs.pop('listen_port', None) - web = kwargs.pop('web', '') + listen_host = kwargs.pop("listen_host", "") + listen_port = kwargs.pop("listen_port", None) + web = kwargs.pop("web", "") # Configuration affecting base request handler - self.only_upgrade = not web - self.verbose = kwargs.pop('verbose', False) - record = kwargs.pop('record', '') + self.only_upgrade = not web + self.verbose = kwargs.pop("verbose", False) + record = kwargs.pop("record", "") if record: self.record = os.path.abspath(record) - self.run_once = kwargs.pop('run_once', False) + self.run_once = kwargs.pop("run_once", False) self.handler_id = 0 for arg in kwargs.keys(): @@ -795,12 +927,11 @@ class LibProxyServer(ThreadingMixIn, HTTPServer): super().__init__((listen_host, listen_port), RequestHandlerClass) - def process_request(self, request, client_address): """Override process_request to implement a counter""" self.handler_id += 1 super().process_request(request, client_address) -if __name__ == '__main__': +if __name__ == "__main__": websockify_init() diff --git a/websockify/websocketserver.py b/websockify/websocketserver.py index 4e62f2e..cd86e2d 100644 --- a/websockify/websocketserver.py +++ b/websockify/websocketserver.py @@ -1,19 +1,25 @@ #!/usr/bin/env python -''' +""" Python WebSocket server base Copyright 2011 Joel Martin Copyright 2016-2018 Pierre Ossman Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) -''' +""" import sys from http.server import BaseHTTPRequestHandler, HTTPServer -from websockify.websocket import WebSocket, WebSocketWantReadError, WebSocketWantWriteError +from websockify.websocket import ( + WebSocket, + WebSocketWantReadError, + WebSocketWantWriteError, +) + class HttpWebSocket(WebSocket): """Class to glue websocket and http request functionality together""" + def __init__(self, request_handler): super().__init__() @@ -62,8 +68,10 @@ class WebSocketRequestHandlerMixIn: # Checks if it is a websocket request and redirects self.do_GET = self._real_do_GET - if (self.headers.get('upgrade') and - self.headers.get('upgrade').lower() == 'websocket'): + if ( + self.headers.get("upgrade") + and self.headers.get("upgrade").lower() == "websocket" + ): self.handle_upgrade() else: self.do_GET() @@ -93,18 +101,20 @@ class WebSocketRequestHandlerMixIn: def handle_websocket(self): """Handle a WebSocket connection. - + This is called when the WebSocket is ready to be used. A sub-class should perform the necessary communication here and return once done. """ pass + # Convenient ready made classes -class WebSocketRequestHandler(WebSocketRequestHandlerMixIn, - BaseHTTPRequestHandler): + +class WebSocketRequestHandler(WebSocketRequestHandlerMixIn, BaseHTTPRequestHandler): pass + class WebSocketServer(HTTPServer): pass diff --git a/websockify/websockifyserver.py b/websockify/websockifyserver.py index 78613fb..fe8a85a 100644 --- a/websockify/websockifyserver.py +++ b/websockify/websockifyserver.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -''' +""" Python WebSocket server base with support for "wss://" encryption. Copyright 2011 Joel Martin Copyright 2016 Pierre Ossman @@ -10,35 +10,39 @@ You can make a cert/key with openssl using: openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem as taken from http://docs.python.org/dev/library/ssl.html#certificates -''' +""" import os, sys, time, errno, signal, socket, select, logging import multiprocessing from http.server import SimpleHTTPRequestHandler # Degraded functionality if these imports are missing -for mod, msg in [('ssl', 'TLS/SSL/wss is disabled'), - ('resource', 'daemonizing is disabled')]: +for mod, msg in [ + ("ssl", "TLS/SSL/wss is disabled"), + ("resource", "daemonizing is disabled"), +]: try: globals()[mod] = __import__(mod) except ImportError: globals()[mod] = None print("WARNING: no '%s' module, %s" % (mod, msg)) -if sys.platform == 'win32': +if sys.platform == "win32": # make sockets pickle-able/inheritable import multiprocessing.reduction from websockify.websocket import WebSocketWantReadError, WebSocketWantWriteError from websockify.websocketserver import WebSocketRequestHandlerMixIn + class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass): def select_subprotocol(self, protocols): # Handle old websockify clients that still specify a sub-protocol - if 'binary' in protocols: - return 'binary' + if "binary" in protocols: + return "binary" else: - return '' + return "" + # HTTP handler with WebSocket upgrade support class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHandler): @@ -56,6 +60,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa * run_once: Handle a single request * handler_id: A sequence number for this connection, appended to record filename """ + server_version = "WebSockify" protocol_version = "HTTP/1.1" @@ -73,7 +78,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa self.daemon = getattr(server, "daemon", False) self.record = getattr(server, "record", False) self.run_once = getattr(server, "run_once", False) - self.rec = None + self.rec = None self.handler_id = getattr(server, "handler_id", False) self.file_only = getattr(server, "file_only", False) self.traffic = getattr(server, "traffic", False) @@ -87,30 +92,33 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa super().__init__(req, addr, server) def log_message(self, format, *args): - self.logger.info("%s - - [%s] %s" % (self.client_address[0], self.log_date_time_string(), format % args)) + self.logger.info( + "%s - - [%s] %s" + % (self.client_address[0], self.log_date_time_string(), format % args) + ) # # WebSocketRequestHandler logging/output functions # def print_traffic(self, token="."): - """ Show traffic flow mode. """ + """Show traffic flow mode.""" if self.traffic: sys.stdout.write(token) sys.stdout.flush() def msg(self, msg, *args, **kwargs): - """ Output message with handler_id prefix. """ + """Output message with handler_id prefix.""" prefix = "% 3d: " % self.handler_id self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs) def vmsg(self, msg, *args, **kwargs): - """ Same as msg() but as debug. """ + """Same as msg() but as debug.""" prefix = "% 3d: " % self.handler_id self.logger.log(logging.DEBUG, "%s%s" % (prefix, msg), *args, **kwargs) def warn(self, msg, *args, **kwargs): - """ Same as msg() but as warning. """ + """Same as msg() but as warning.""" prefix = "% 3d: " % self.handler_id self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs) @@ -118,19 +126,24 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa # Main WebSocketRequestHandler methods # def send_frames(self, bufs=None): - """ Encode and send WebSocket frames. Any frames already + """Encode and send WebSocket frames. Any frames already queued will be sent first. If buf is not set then only queued frames will be sent. Returns True if any frames could not be fully sent, in which case the caller should call again when - the socket is ready. """ + the socket is ready.""" - tdelta = int(time.time()*1000) - self.start_time + tdelta = int(time.time() * 1000) - self.start_time if bufs: for buf in bufs: if self.rec: # Python 3 compatible conversion - bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'") + bufstr = ( + buf.decode("latin1") + .encode("unicode_escape") + .decode("ascii") + .replace("'", "\\'") + ) self.rec.write("'{{{0}{{{1}',\n".format(tdelta, bufstr)) self.send_parts.append(buf) @@ -147,7 +160,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa return False def recv_frames(self): - """ Receive and decode WebSocket frames. + """Receive and decode WebSocket frames. Returns: (bufs_list, closed_string) @@ -155,7 +168,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa closed = False bufs = [] - tdelta = int(time.time()*1000) - self.start_time + tdelta = int(time.time() * 1000) - self.start_time while True: try: @@ -165,15 +178,22 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa break if buf is None: - closed = {'code': self.request.close_code, - 'reason': self.request.close_reason} + closed = { + "code": self.request.close_code, + "reason": self.request.close_reason, + } return bufs, closed self.print_traffic("}") if self.rec: # Python 3 compatible conversion - bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'") + bufstr = ( + buf.decode("latin1") + .encode("unicode_escape") + .decode("ascii") + .replace("'", "\\'") + ) self.rec.write("'}}{0}}}{1}',\n".format(tdelta, bufstr)) bufs.append(buf) @@ -183,16 +203,16 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa return bufs, closed - def send_close(self, code=1000, reason=''): - """ Send a WebSocket orderly close frame. """ + def send_close(self, code=1000, reason=""): + """Send a WebSocket orderly close frame.""" self.request.shutdown(socket.SHUT_RDWR, code, reason) - def send_pong(self, data=b''): - """ Send a WebSocket pong frame. """ + def send_pong(self, data=b""): + """Send a WebSocket pong frame.""" self.request.pong(data) - def send_ping(self, data=b''): - """ Send a WebSocket ping frame. """ + def send_ping(self, data=b""): + """Send a WebSocket ping frame.""" self.request.ping(data) def handle_upgrade(self): @@ -207,8 +227,8 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa self.server.ws_connection = True # Initialize per client settings self.send_parts = [] - self.recv_part = None - self.start_time = int(time.time()*1000) + self.recv_part = None + self.start_time = int(time.time() * 1000) # client_address is empty with, say, UNIX domain sockets client_addr = "" @@ -224,17 +244,15 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa else: self.stype = "Plain non-SSL (ws://)" - self.log_message("%s: %s WebSocket connection", client_addr, - self.stype) - if self.path != '/': + self.log_message("%s: %s WebSocket connection", client_addr, self.stype) + if self.path != "/": self.log_message("%s: Path: '%s'", client_addr, self.path) if self.record: # Record raw frame data as JavaScript array - fname = "%s.%s" % (self.record, - self.handler_id) + fname = "%s.%s" % (self.record, self.handler_id) self.log_message("opening record file: %s", fname) - self.rec = open(fname, 'w+') + self.rec = open(fname, "w+") self.rec.write("var VNC_frame_data = [\n") try: @@ -261,15 +279,17 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa return super().list_directory(path) def new_websocket_client(self): - """ Do something with a WebSockets client connection. """ - raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded") + """Do something with a WebSockets client connection.""" + raise Exception( + "WebSocketRequestHandler.new_websocket_client() must be overloaded" + ) def validate_connection(self): - """ Ensure that the connection has a valid token, and set the target. """ + """Ensure that the connection has a valid token, and set the target.""" pass def auth_connection(self): - """ Ensure that the connection is authorized. """ + """Ensure that the connection is authorized.""" pass def do_HEAD(self): @@ -296,12 +316,12 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa else: super().handle() - def log_request(self, code='-', size='-'): + def log_request(self, code="-", size="-"): if self.verbose: super().log_request(code, size) -class WebSockifyServer(): +class WebSockifyServer: """ WebSockets server class. As an alternative, the standard library SocketServer can be used @@ -317,48 +337,69 @@ class WebSockifyServer(): class Terminate(Exception): pass - def __init__(self, RequestHandlerClass, listen_fd=None, - listen_host='', listen_port=None, source_is_ipv6=False, - verbose=False, cert='', key='', key_password=None, ssl_only=None, - verify_client=False, cafile=None, - daemon=False, record='', web='', web_auth=False, - file_only=False, - run_once=False, timeout=0, idle_timeout=0, traffic=False, - tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, - tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0, - unix_listen=None, unix_listen_mode=None): - + def __init__( + self, + RequestHandlerClass, + listen_fd=None, + listen_host="", + listen_port=None, + source_is_ipv6=False, + verbose=False, + cert="", + key="", + key_password=None, + ssl_only=None, + verify_client=False, + cafile=None, + daemon=False, + record="", + web="", + web_auth=False, + file_only=False, + run_once=False, + timeout=0, + idle_timeout=0, + traffic=False, + tcp_keepalive=True, + tcp_keepcnt=None, + tcp_keepidle=None, + tcp_keepintvl=None, + ssl_ciphers=None, + ssl_options=0, + unix_listen=None, + unix_listen_mode=None, + ): # settings self.RequestHandlerClass = RequestHandlerClass - self.verbose = verbose - self.listen_fd = listen_fd - self.unix_listen = unix_listen - self.unix_listen_mode = unix_listen_mode - self.listen_host = listen_host - self.listen_port = listen_port - self.prefer_ipv6 = source_is_ipv6 - self.ssl_only = ssl_only - self.ssl_ciphers = ssl_ciphers - self.ssl_options = ssl_options - self.verify_client = verify_client - self.daemon = daemon - self.run_once = run_once - self.timeout = timeout - self.idle_timeout = idle_timeout - self.traffic = traffic - self.file_only = file_only - self.web_auth = web_auth + self.verbose = verbose + self.listen_fd = listen_fd + self.unix_listen = unix_listen + self.unix_listen_mode = unix_listen_mode + self.listen_host = listen_host + self.listen_port = listen_port + self.prefer_ipv6 = source_is_ipv6 + self.ssl_only = ssl_only + self.ssl_ciphers = ssl_ciphers + self.ssl_options = ssl_options + self.verify_client = verify_client + self.daemon = daemon + self.run_once = run_once + self.timeout = timeout + self.idle_timeout = idle_timeout + self.traffic = traffic + self.file_only = file_only + self.web_auth = web_auth - self.launch_time = time.time() - self.ws_connection = False - self.handler_id = 1 - self.terminating = False + self.launch_time = time.time() + self.ws_connection = False + self.handler_id = 1 + self.terminating = False - self.logger = self.get_logger() - self.tcp_keepalive = tcp_keepalive - self.tcp_keepcnt = tcp_keepcnt - self.tcp_keepidle = tcp_keepidle - self.tcp_keepintvl = tcp_keepintvl + self.logger = self.get_logger() + self.tcp_keepalive = tcp_keepalive + self.tcp_keepcnt = tcp_keepcnt + self.tcp_keepidle = tcp_keepidle + self.tcp_keepintvl = tcp_keepintvl # keyfile path must be None if not specified self.key = None @@ -366,7 +407,7 @@ class WebSockifyServer(): # Make paths settings absolute self.cert = os.path.abspath(cert) - self.web = self.record = self.cafile = '' + self.web = self.record = self.cafile = "" if key: self.key = os.path.abspath(key) if web: @@ -393,11 +434,12 @@ class WebSockifyServer(): elif self.unix_listen != None: self.msg(" - Listen on unix socket %s", self.unix_listen) else: - self.msg(" - Listen on %s:%s", - self.listen_host, self.listen_port) + self.msg(" - Listen on %s:%s", self.listen_host, self.listen_port) if self.web: if self.file_only: - self.msg(" - Web server (no directory listings). Web root: %s", self.web) + self.msg( + " - Web server (no directory listings). Web root: %s", self.web + ) else: self.msg(" - Web server. Web root: %s", self.web) if ssl: @@ -420,34 +462,45 @@ class WebSockifyServer(): @staticmethod def get_logger(): - return logging.getLogger("%s.%s" % ( - WebSockifyServer.log_prefix, - WebSockifyServer.__class__.__name__)) + return logging.getLogger( + "%s.%s" % (WebSockifyServer.log_prefix, WebSockifyServer.__class__.__name__) + ) @staticmethod - def socket(host, port=None, connect=False, prefer_ipv6=False, - unix_socket=None, unix_socket_mode=None, unix_socket_listen=False, - use_ssl=False, tcp_keepalive=True, tcp_keepcnt=None, - tcp_keepidle=None, tcp_keepintvl=None): - """ Resolve a host (and optional port) to an IPv4 or IPv6 + def socket( + host, + port=None, + connect=False, + prefer_ipv6=False, + unix_socket=None, + unix_socket_mode=None, + unix_socket_listen=False, + use_ssl=False, + tcp_keepalive=True, + tcp_keepcnt=None, + tcp_keepidle=None, + tcp_keepintvl=None, + ): + """Resolve a host (and optional port) to an IPv4 or IPv6 address. Create a socket. Bind to it if listen is set, otherwise connect to it. Return the socket. """ flags = 0 - if host == '': + if host == "": host = None if connect and not (port or unix_socket): raise Exception("Connect mode requires a port") if use_ssl and not ssl: - raise Exception("SSL socket requested but Python SSL module not loaded."); + raise Exception("SSL socket requested but Python SSL module not loaded.") if not connect and use_ssl: raise Exception("SSL only supported in connect mode (for now)") if not connect: flags = flags | socket.AI_PASSIVE if not unix_socket: - addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, - socket.IPPROTO_TCP, flags) + addrs = socket.getaddrinfo( + host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, flags + ) if not addrs: raise Exception("Could not resolve host '%s'" % host) addrs.sort(key=lambda x: x[0]) @@ -455,17 +508,14 @@ class WebSockifyServer(): addrs.reverse() sock = socket.socket(addrs[0][0], addrs[0][1]) - if tcp_keepalive: + if tcp_keepalive: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) if tcp_keepcnt: - sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, - tcp_keepcnt) + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, tcp_keepcnt) if tcp_keepidle: - sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, - tcp_keepidle) + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, tcp_keepidle) if tcp_keepintvl: - sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, - tcp_keepintvl) + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, tcp_keepintvl) if connect: sock.connect(addrs[0][4]) @@ -497,8 +547,7 @@ class WebSockifyServer(): return sock @staticmethod - def daemonize(keepfd=None, chdir='/'): - + def daemonize(keepfd=None, chdir="/"): if keepfd is None: keepfd = [] @@ -506,14 +555,16 @@ class WebSockifyServer(): if chdir: os.chdir(chdir) else: - os.chdir('/') + os.chdir("/") os.setgid(os.getgid()) # relinquish elevations os.setuid(os.getuid()) # relinquish elevations # Double fork to daemonize - if os.fork() > 0: os._exit(0) # Parent exits - os.setsid() # Obtain new process group - if os.fork() > 0: os._exit(0) # Parent exits + if os.fork() > 0: + os._exit(0) # Parent exits + os.setsid() # Obtain new process group + if os.fork() > 0: + os._exit(0) # Parent exits # Signal handling signal.signal(signal.SIGTERM, signal.SIG_IGN) @@ -521,14 +572,16 @@ class WebSockifyServer(): # Close open files maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] - if maxfd == resource.RLIM_INFINITY: maxfd = 256 + if maxfd == resource.RLIM_INFINITY: + maxfd = 256 for fd in reversed(range(maxfd)): try: if fd not in keepfd: os.close(fd) except OSError: _, exc, _ = sys.exc_info() - if exc.errno != errno.EBADF: raise + if exc.errno != errno.EBADF: + raise # Redirect I/O to /dev/null os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) @@ -557,7 +610,7 @@ class WebSockifyServer(): # Peek, but do not read the data so that we have a opportunity # to SSL wrap the socket first handshake = sock.recv(1024, socket.MSG_PEEK) - #self.msg("Handshake [%s]" % handshake) + # self.msg("Handshake [%s]" % handshake) if not handshake: raise self.EClose("") @@ -567,8 +620,7 @@ class WebSockifyServer(): if not ssl: raise self.EClose("SSL connection but no 'ssl' module") if not os.path.exists(self.cert): - raise self.EClose("SSL connection but '%s' not found" - % self.cert) + raise self.EClose("SSL connection but '%s' not found" % self.cert) retsock = None try: # create new-style SSL wrapping for extended features @@ -576,16 +628,16 @@ class WebSockifyServer(): if self.ssl_ciphers is not None: context.set_ciphers(self.ssl_ciphers) context.options = self.ssl_options - context.load_cert_chain(certfile=self.cert, keyfile=self.key, password=self.key_password) + context.load_cert_chain( + certfile=self.cert, keyfile=self.key, password=self.key_password + ) if self.verify_client: context.verify_mode = ssl.CERT_REQUIRED if self.cafile: context.load_verify_locations(cafile=self.cafile) else: context.set_default_verify_paths() - retsock = context.wrap_socket( - sock, - server_side=True) + retsock = context.wrap_socket(sock, server_side=True) except ssl.SSLError: _, x, _ = sys.exc_info() if x.args[0] == ssl.SSL_ERROR_EOF: @@ -618,28 +670,27 @@ class WebSockifyServer(): # def msg(self, *args, **kwargs): - """ Output message as info """ + """Output message as info""" self.logger.log(logging.INFO, *args, **kwargs) def vmsg(self, *args, **kwargs): - """ Same as msg() but as debug. """ + """Same as msg() but as debug.""" self.logger.log(logging.DEBUG, *args, **kwargs) def warn(self, *args, **kwargs): - """ Same as msg() but as warning. """ + """Same as msg() but as warning.""" self.logger.log(logging.WARN, *args, **kwargs) - # # Events that can/should be overridden in sub-classes # def started(self): - """ Called after WebSockets startup """ + """Called after WebSockets startup""" self.vmsg("WebSockets server started") def poll(self): - """ Run periodically while waiting for connections. """ - #self.vmsg("Running poll()") + """Run periodically while waiting for connections.""" + # self.vmsg("Running poll()") pass def terminate(self): @@ -661,7 +712,7 @@ class WebSockifyServer(): while result[0]: self.vmsg("Reaped child process %s" % result[0]) result = os.waitpid(-1, os.WNOHANG) - except (OSError): + except OSError: pass def do_SIGINT(self, sig, stack): @@ -675,7 +726,7 @@ class WebSockifyServer(): self.terminate() def top_new_client(self, startsock, address): - """ Do something with a WebSockets client connection. """ + """Do something with a WebSockets client connection.""" # handler process client = None try: @@ -693,7 +744,6 @@ class WebSockifyServer(): self.msg("handler exception: %s" % str(exc)) self.vmsg("exception", exc_info=True) finally: - if client and client != startsock: # Close the SSL wrapped socket # Original socket closed by caller @@ -721,19 +771,27 @@ class WebSockifyServer(): try: if self.listen_fd != None: - lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM) + lsock = socket.fromfd( + self.listen_fd, socket.AF_INET, socket.SOCK_STREAM + ) elif self.unix_listen != None: - lsock = self.socket(host=None, - unix_socket=self.unix_listen, - unix_socket_mode=self.unix_listen_mode, - unix_socket_listen=True) + lsock = self.socket( + host=None, + unix_socket=self.unix_listen, + unix_socket_mode=self.unix_listen_mode, + unix_socket_listen=True, + ) else: - lsock = self.socket(self.listen_host, self.listen_port, False, - self.prefer_ipv6, - tcp_keepalive=self.tcp_keepalive, - tcp_keepcnt=self.tcp_keepcnt, - tcp_keepidle=self.tcp_keepidle, - tcp_keepintvl=self.tcp_keepintvl) + lsock = self.socket( + self.listen_host, + self.listen_port, + False, + self.prefer_ipv6, + tcp_keepalive=self.tcp_keepalive, + tcp_keepcnt=self.tcp_keepcnt, + tcp_keepidle=self.tcp_keepidle, + tcp_keepintvl=self.tcp_keepintvl, + ) except OSError as e: self.msg("Openening socket failed: %s", str(e)) self.vmsg("exception", exc_info=True) @@ -751,13 +809,13 @@ class WebSockifyServer(): signal.SIGINT: signal.getsignal(signal.SIGINT), signal.SIGTERM: signal.getsignal(signal.SIGTERM), } - if getattr(signal, 'SIGCHLD', None) is not None: + if getattr(signal, "SIGCHLD", None) is not None: original_signals[signal.SIGCHLD] = signal.getsignal(signal.SIGCHLD) signal.signal(signal.SIGINT, self.do_SIGINT) signal.signal(signal.SIGTERM, self.do_SIGTERM) # make sure that _cleanup is called when children die # by calling active_children on SIGCHLD - if getattr(signal, 'SIGCHLD', None) is not None: + if getattr(signal, "SIGCHLD", None) is not None: signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD) last_active_time = self.launch_time @@ -774,8 +832,7 @@ class WebSockifyServer(): time_elapsed = time.time() - self.launch_time if self.timeout and time_elapsed > self.timeout: - self.msg('listener exit due to --timeout %s' - % self.timeout) + self.msg("listener exit due to --timeout %s" % self.timeout) break if self.idle_timeout: @@ -787,8 +844,10 @@ class WebSockifyServer(): last_active_time = time.time() if idle_time > self.idle_timeout and child_count == 0: - self.msg('listener exit due to --idle-timeout %s' - % self.idle_timeout) + self.msg( + "listener exit due to --idle-timeout %s" + % self.idle_timeout + ) break try: @@ -799,16 +858,16 @@ class WebSockifyServer(): startsock, address = lsock.accept() # Unix Socket will not report address (empty string), but address[0] is logged a bunch if self.unix_listen != None: - address = [ self.unix_listen ] + address = [self.unix_listen] else: continue except self.Terminate: raise except Exception: _, exc, _ = sys.exc_info() - if hasattr(exc, 'errno'): + if hasattr(exc, "errno"): err = exc.errno - elif hasattr(exc, 'args'): + elif hasattr(exc, "args"): err = exc.args[0] else: err = exc[0] @@ -821,15 +880,14 @@ class WebSockifyServer(): if self.run_once: # Run in same process if run_once self.top_new_client(startsock, address) - if self.ws_connection : - self.msg('%s: exiting due to --run-once' - % address[0]) + if self.ws_connection: + self.msg("%s: exiting due to --run-once" % address[0]) break else: - self.vmsg('%s: new handler Process' % address[0]) + self.vmsg("%s: new handler Process" % address[0]) p = multiprocessing.Process( - target=self.top_new_client, - args=(startsock, address)) + target=self.top_new_client, args=(startsock, address) + ) p.start() # child will not return @@ -857,12 +915,11 @@ class WebSockifyServer(): startsock.close() finally: # Close listen port - self.vmsg("Closing socket listening at %s:%s", - self.listen_host, self.listen_port) + self.vmsg( + "Closing socket listening at %s:%s", self.listen_host, self.listen_port + ) lsock.close() # Restore signals for sig, func in original_signals.items(): signal.signal(sig, func) - -