From c6557d3d81552abb1c2c937255d116d782022c15 Mon Sep 17 00:00:00 2001 From: Solly Ross Date: Tue, 28 Apr 2015 16:17:47 -0400 Subject: [PATCH 1/5] Introduce Auth Plugins Auth plugins provide a generic interface for authenticating requests. The plugin name is specified using the '--auth-plugin' option, and may either be the name of a class from `websockify.auth_plugins`, or a fully qualified python path to the auth plugin class (see below). An optional plugin parameter can be specified using the '--auth-source' option (a value of `None` will be used if no '--auth-source' option is specified). Auth plugins should inherit from `websockify.auth_plugins.BasePlugin`, and should implement the `authenticate(headers, target_host, target_port)` method. The value of the '--auth-source' option is available as `self.source`. One plugin is currently included: `ExpectOrigin`. The `ExpectOrigin` plugin checks that the 'Origin' header is an expected value. The list of acceptable origins is passed using the plugin source, as a space-separated list. --- websockify/auth_plugins.py | 33 +++++++++++++++++++++++++++++++++ websockify/websocketproxy.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 websockify/auth_plugins.py diff --git a/websockify/auth_plugins.py b/websockify/auth_plugins.py new file mode 100644 index 0000000..dd1b8de --- /dev/null +++ b/websockify/auth_plugins.py @@ -0,0 +1,33 @@ +class BasePlugin(object): + def __init__(self, src=None): + self.source = src + + def authenticate(self, headers, target_host, target_port): + pass + + +class AuthenticationError(Exception): + pass + + +class InvalidOriginError(AuthenticationError): + def __init__(self, expected, actual): + self.expected_origin = expected + self.actual_origin = actual + + super(InvalidOriginError, self).__init__( + "Invalid Origin Header: Expected one of " + "%s, got '%s'" % (expected, actual)) + + +class ExpectOrigin(object): + def __init__(self, src=None): + if src is None: + self.source = [] + else: + self.source = src.split() + + def authenticate(self, headers, target_host, target_port): + origin = headers.getheader('Origin', None) + if origin is None or origin not in self.source: + raise InvalidOriginError(expected=self.source, actual=origin) diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index 94454c0..e1fe012 100755 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -47,6 +47,11 @@ Traffic Legend: if self.server.token_plugin: (self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path) + if self.server.auth_plugin: + self.server.auth_plugin.authenticate( + headers=self.headers, target_host=self.server.target_host, + target_port=self.server.target_port) + # Connect to the target if self.server.wrap_cmd: msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port) @@ -197,6 +202,9 @@ class WebSocketProxy(websocket.WebSocketServer): token_plugin = kwargs.pop('token_plugin', None) token_source = kwargs.pop('token_source', None) + auth_plugin = kwargs.pop('auth_plugin', None) + auth_source = kwargs.pop('auth_source', None) + if token_plugin is not None: if '.' not in token_plugin: token_plugin = 'websockify.token_plugins.%s' % token_plugin @@ -210,6 +218,19 @@ class WebSocketProxy(websocket.WebSocketServer): else: self.token_plugin = None + if auth_plugin is not None: + if '.' not in auth_plugin: + auth_plugin = 'websockify.auth_plugins.%s' % auth_plugin + + auth_plugin_module, auth_plugin_cls = auth_plugin.rsplit('.', 1) + + __import__(auth_plugin_module) + auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls) + + self.auth_plugin = auth_plugin_cls(auth_source) + else: + self.auth_plugin = None + # Last 3 timestamps command was run self.wrap_times = [0, 0, 0] @@ -381,6 +402,12 @@ def websockify_init(): parser.add_option("--token-source", default=None, metavar="ARG", help="an argument to be passed to the token plugin" "on instantiation") + parser.add_option("--auth-plugin", default=None, metavar="PLUGIN", + help="use the given Python class to determine if " + "a connection is allowed") + parser.add_option("--auth-source", default=None, metavar="ARG", + help="an argument to be passed to the auth plugin" + "on instantiation") parser.add_option("--auto-pong", action="store_true", help="Automatically respond to ping frames with a pong") parser.add_option("--heartbeat", type=int, default=0, @@ -394,6 +421,10 @@ def websockify_init(): if opts.token_source and not opts.token_plugin: parser.error("You must use --token-plugin to use --token-source") + if opts.auth_source and not opts.auth_plugin: + parser.error("You must use --auth-plugin to use --auth-source") + + # Transform to absolute path as daemon may chdir if opts.target_cfg: opts.target_cfg = os.path.abspath(opts.target_cfg) @@ -471,9 +502,11 @@ class LibProxyServer(ForkingMixIn, HTTPServer): self.ssl_target = kwargs.pop('ssl_target', None) self.token_plugin = kwargs.pop('token_plugin', None) self.token_source = kwargs.pop('token_source', None) + self.auth_plugin = kwargs.pop('auth_plugin', None) self.heartbeat = kwargs.pop('heartbeat', None) self.token_plugin = None + self.auth_plugin = None self.daemon = False # Server configuration From aaefff0bf0544fead4ecc92646ea4f9a1afc0e1e Mon Sep 17 00:00:00 2001 From: Solly Ross Date: Tue, 28 Apr 2015 17:01:01 -0400 Subject: [PATCH 2/5] Fix bug in ReadOnlyTokenFile ReadOnlyTokenFile didn't initialize `self._targets` to `None`, causing issues. --- websockify/token_plugins.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/websockify/token_plugins.py b/websockify/token_plugins.py index fcb6080..b75ef00 100644 --- a/websockify/token_plugins.py +++ b/websockify/token_plugins.py @@ -12,6 +12,10 @@ class ReadOnlyTokenFile(BasePlugin): # source is a token file with lines like # token: host:port # or a directory of such files + def __init__(self, *args, **kwargs): + super(ReadOnlyTokenFile, self).__init__(*args, **kwargs) + self._targets = None + def _load_targets(self): if os.path.isdir(self.source): cfg_files = [os.path.join(self.source, f) for From f6e9fbe5cf096d3bbab1ce64f765ee934a0adefe Mon Sep 17 00:00:00 2001 From: Solly Ross Date: Tue, 28 Apr 2015 17:08:36 -0400 Subject: [PATCH 3/5] Process plugin parameters in main Previously, we just passed the values of '--*-plugin' and '--*-source' directly to `LibProxyServer` and `WebSocketProxy`, which handled turning that into an instance of the plugin class. Now, that's done in main, and the classes receive an instance directly. --- websockify/websocketproxy.py | 59 +++++++++++++++++------------------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index e1fe012..b21f539 100755 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -199,37 +199,8 @@ class WebSocketProxy(websocket.WebSocketServer): self.ssl_target = kwargs.pop('ssl_target', None) self.heartbeat = kwargs.pop('heartbeat', None) - token_plugin = kwargs.pop('token_plugin', None) - token_source = kwargs.pop('token_source', None) - - auth_plugin = kwargs.pop('auth_plugin', None) - auth_source = kwargs.pop('auth_source', None) - - if token_plugin is not None: - if '.' not in token_plugin: - token_plugin = 'websockify.token_plugins.%s' % token_plugin - - token_plugin_module, token_plugin_cls = token_plugin.rsplit('.', 1) - - __import__(token_plugin_module) - token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls) - - self.token_plugin = token_plugin_cls(token_source) - else: - self.token_plugin = None - - if auth_plugin is not None: - if '.' not in auth_plugin: - auth_plugin = 'websockify.auth_plugins.%s' % auth_plugin - - auth_plugin_module, auth_plugin_cls = auth_plugin.rsplit('.', 1) - - __import__(auth_plugin_module) - auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls) - - self.auth_plugin = auth_plugin_cls(auth_source) - else: - self.auth_plugin = None + self.token_plugin = kwargs.pop('token_plugin', None) + self.auth_plugin = kwargs.pop('auth_plugin', None) # Last 3 timestamps command was run self.wrap_times = [0, 0, 0] @@ -473,6 +444,31 @@ def websockify_init(): try: opts.target_port = int(opts.target_port) except: parser.error("Error parsing target port") + if opts.token_plugin is not None: + if '.' not in opts.token_plugin: + opts.token_plugin = ( + 'websockify.token_plugins.%s' % opts.token_plugin) + + token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit('.', 1) + + __import__(token_plugin_module) + token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls) + + opts.token_plugin = token_plugin_cls(opts.token_source) + del opts.token_source + + if opts.auth_plugin is not None: + if '.' not in opts.auth_plugin: + opts.auth_plugin = 'websockify.auth_plugins.%s' % opts.auth_plugin + + auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit('.', 1) + + __import__(auth_plugin_module) + auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls) + + opts.auth_plugin = auth_plugin_cls(opts.auth_source) + del opts.auth_source + # Create and start the WebSockets proxy libserver = opts.libserver del opts.libserver @@ -501,7 +497,6 @@ class LibProxyServer(ForkingMixIn, HTTPServer): self.unix_target = kwargs.pop('unix_target', None) self.ssl_target = kwargs.pop('ssl_target', None) self.token_plugin = kwargs.pop('token_plugin', None) - self.token_source = kwargs.pop('token_source', None) self.auth_plugin = kwargs.pop('auth_plugin', None) self.heartbeat = kwargs.pop('heartbeat', None) From 52f6830852216fc80ab2505b0ca0f5bbcd450d5a Mon Sep 17 00:00:00 2001 From: Solly Ross Date: Wed, 6 May 2015 13:49:13 -0400 Subject: [PATCH 4/5] Update Tests and Test Plugins This commit updates the unit tests to work with the current code and adds in tests for the auth and token plugin functionality. --- .gitignore | 2 + tests/test_websocket.py | 404 ++++++++++++++++++++++------------- tests/test_websocketproxy.py | 181 ++++++++-------- tests/tox.ini => tox.ini | 5 +- websockify/websocket.py | 2 +- 5 files changed, 353 insertions(+), 241 deletions(-) rename tests/tox.ini => tox.ini (78%) diff --git a/.gitignore b/.gitignore index 3bf91dd..3ac5ff1 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,5 @@ other/node_modules .pydevproject target.cfg target.cfg.d +.tox +*.egg-info diff --git a/tests/test_websocket.py b/tests/test_websocket.py index c7a106f..acd7699 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -26,201 +26,303 @@ import stubout import sys import tempfile import unittest -from ssl import SSLError -from websockify import websocket as websocket -from SimpleHTTPServer import SimpleHTTPRequestHandler +import socket +import signal +from websockify import websocket + +try: + from SimpleHTTPServer import SimpleHTTPRequestHandler +except ImportError: + from http.server import SimpleHTTPRequestHandler + +try: + from StringIO import StringIO + BytesIO = StringIO +except ImportError: + from io import StringIO + from io import BytesIO -class MockConnection(object): - def __init__(self, path): - self.path = path - - def makefile(self, mode='r', bufsize=-1): - return open(self.path, mode, bufsize) -class WebSocketTestCase(unittest.TestCase): +def raise_oserror(*args, **kwargs): + raise OSError('fake error') - def _init_logger(self, tmpdir): - name = 'websocket-unittest' - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - logger.propagate = True - filename = "%s.log" % (name) - handler = logging.FileHandler(filename) - handler.setFormatter(logging.Formatter("%(message)s")) - logger.addHandler(handler) +class FakeSocket(object): + def __init__(self, data=''): + if isinstance(data, bytes): + self._data = data + else: + self._data = data.encode('latin_1') + + def recv(self, amt, flags=None): + res = self._data[0:amt] + if not (flags & socket.MSG_PEEK): + self._data = self._data[amt:] + + return res + + def makefile(self, mode='r', buffsize=None): + if 'b' in mode: + return BytesIO(self._data) + else: + return StringIO(self._data.decode('latin_1')) + + +class WebSocketRequestHandlerTestCase(unittest.TestCase): def setUp(self): - """Called automatically before each test.""" - super(WebSocketTestCase, self).setUp() + super(WebSocketRequestHandlerTestCase, self).setUp() self.stubs = stubout.StubOutForTesting() - # Temporary dir for test data - self.tmpdir = tempfile.mkdtemp() - # Put log somewhere persistent - self._init_logger('./') + self.tmpdir = tempfile.mkdtemp('-websockify-tests') # Mock this out cause it screws tests up self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) - self.server = self._get_websockserver(daemon=True, - ssl_only=False) - self.soc = self.server.socket('localhost') + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + lambda *args, **kwargs: None) def tearDown(self): """Called automatically after each test.""" self.stubs.UnsetAll() - shutil.rmtree(self.tmpdir) - super(WebSocketTestCase, self).tearDown() + os.rmdir(self.tmpdir) + super(WebSocketRequestHandlerTestCase, self).tearDown() - def _get_websockserver(self, **kwargs): - return websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key=self.tmpdir, - web=self.tmpdir, - record=self.tmpdir, - **kwargs) + def _get_server(self, handler_class=websocket.WebSocketRequestHandler, + **kwargs): + web = kwargs.pop('web', self.tmpdir) + return websocket.WebSocketServer( + handler_class, listen_host='localhost', + listen_port=80, key=self.tmpdir, web=web, + record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1, + **kwargs) - def _mock_os_open_oserror(self, file, flags): - raise OSError('') + def test_normal_get_with_only_upgrade_returns_error(self): + server = self._get_server(web=None) + handler = websocket.WebSocketRequestHandler( + FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server) - def _mock_os_close_oserror(self, fd): - raise OSError('') + def fake_send_response(self, code, message=None): + self.last_code = code - def _mock_os_close_oserror_EBADF(self, fd): - raise OSError(errno.EBADF, '') + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + fake_send_response) - def _mock_socket(self, *args, **kwargs): - return self.soc + handler.do_GET() + self.assertEqual(handler.last_code, 405) - def _mock_select(self, rlist, wlist, xlist, timeout=None): - return '_mock_select' + def test_list_dir_with_file_only_returns_error(self): + server = self._get_server(file_only=True) + handler = websocket.WebSocketRequestHandler( + FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server) - def _mock_select_exception(self, rlist, wlist, xlist, timeout=None): - raise Exception + def fake_send_response(self, code, message=None): + self.last_code = code - def _mock_select_keyboardinterrupt(self, rlist, wlist, - xlist, timeout=None): - raise KeyboardInterrupt + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + fake_send_response) - def _mock_select_systemexit(self, rlist, wlist, xlist, timeout=None): - sys.exit() + handler.path = '/' + handler.do_GET() + self.assertEqual(handler.last_code, 404) - def test_daemonize_error(self): - soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1) - self.stubs.Set(os, 'fork', lambda *args: None) + +class WebSocketServerTestCase(unittest.TestCase): + def setUp(self): + super(WebSocketServerTestCase, self).setUp() + self.stubs = stubout.StubOutForTesting() + self.tmpdir = tempfile.mkdtemp('-websockify-tests') + # Mock this out cause it screws tests up + self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) + + def tearDown(self): + """Called automatically after each test.""" + self.stubs.UnsetAll() + os.rmdir(self.tmpdir) + super(WebSocketServerTestCase, self).tearDown() + + def _get_server(self, handler_class=websocket.WebSocketRequestHandler, + **kwargs): + return websocket.WebSocketServer( + handler_class, listen_host='localhost', + listen_port=80, key=self.tmpdir, web=self.tmpdir, + record=self.tmpdir, **kwargs) + + def test_daemonize_raises_error_while_closing_fds(self): + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: 0) + self.stubs.Set(signal, 'signal', lambda *args: None) self.stubs.Set(os, 'setsid', lambda *args: None) - self.stubs.Set(os, 'close', self._mock_os_close_oserror) - self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') + self.stubs.Set(os, 'close', raise_oserror) + self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') - def test_daemonize_EBADF_error(self): - soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1) - self.stubs.Set(os, 'fork', lambda *args: None) + def test_daemonize_ignores_ebadf_error_while_closing_fds(self): + def raise_oserror_ebadf(fd): + raise OSError(errno.EBADF, 'fake error') + + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: 0) self.stubs.Set(os, 'setsid', lambda *args: None) - self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF) - self.stubs.Set(os, 'open', self._mock_os_open_oserror) - self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') + self.stubs.Set(signal, 'signal', lambda *args: None) + self.stubs.Set(os, 'close', raise_oserror_ebadf) + self.stubs.Set(os, 'open', raise_oserror) + self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') - def test_decode_hybi(self): - soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) - self.assertRaises(Exception, soc.decode_hybi, 'a' * 128, - base64=True) + def test_handshake_fails_on_not_ready(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) - def test_do_websocket_handshake(self): - soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) - soc.scheme = 'scheme' - headers = {'Sec-WebSocket-Protocol': 'binary', - 'Sec-WebSocket-Version': '7', - 'Sec-WebSocket-Key': 'foo'} - soc.do_websocket_handshake(headers, '127.0.0.1') + def fake_select(rlist, wlist, xlist, timeout=None): + return ([], [], []) - def test_do_handshake(self): - soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) - self.stubs.Set(select, 'select', self._mock_select) - self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv') - self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1') + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocket.WebSocketServer.EClose, server.do_handshake, + FakeSocket(), '127.0.0.1') - def test_do_handshake_ssl_error(self): - soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) + def test_empty_handshake_fails(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) - def _mock_wrap_socket(*args, **kwargs): - from ssl import SSLError - raise SSLError('unit test exception') + sock = FakeSocket('') - self.stubs.Set(select, 'select', self._mock_select) - self.stubs.Set(socket._socketobject, 'recv', lambda *args: '\x16') - self.stubs.Set(ssl, 'wrap_socket', _mock_wrap_socket) - self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1') + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) - def test_fallback_SIGCHILD(self): - soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) - soc.fallback_SIGCHLD(None, None) + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocket.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') - def test_start_server_Exception(self): - soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) - self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + def test_handshake_policy_request(self): + # TODO(directxman12): implement + pass + + def test_handshake_ssl_only_without_ssl_raises_error(self): + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + + sock = FakeSocket('some initial data') + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocket.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_do_handshake_no_ssl(self): + class FakeHandler(object): + CALLED = False + def __init__(self, *args, **kwargs): + type(self).CALLED = True + + FakeHandler.CALLED = False + + server = self._get_server( + handler_class=FakeHandler, daemon=True, + ssl_only=0, idle_timeout=1) + + sock = FakeSocket('some initial data') + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock) + self.assertTrue(FakeHandler.CALLED, True) + + def test_do_handshake_ssl(self): + # TODO(directxman12): implement this + pass + + def test_do_handshake_ssl_without_ssl_raises_error(self): + # TODO(directxman12): implement this + pass + + def test_do_handshake_ssl_without_cert_raises_error(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1, + cert='afdsfasdafdsafdsafdsafdas') + + sock = FakeSocket("\x16some ssl data") + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocket.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_do_handshake_ssl_error_eof_raises_close_error(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) + + sock = FakeSocket("\x16some ssl data") + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + def fake_wrap_socket(*args, **kwargs): + raise ssl.SSLError(ssl.SSL_ERROR_EOF) + + self.stubs.Set(select, 'select', fake_select) + self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket) + self.assertRaises( + websocket.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_fallback_sigchld_handler(self): + # TODO(directxman12): implement this + pass + + def test_start_server_error(self): + server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + raise Exception("fake error") + + self.stubs.Set(websocket.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) self.stubs.Set(websocket.WebSocketServer, 'daemonize', lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', self._mock_select_exception) - self.assertEqual(None, soc.start_server()) + self.stubs.Set(select, 'select', fake_select) + server.start_server() - def test_start_server_KeyboardInterrupt(self): - soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) - self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + def test_start_server_keyboardinterrupt(self): + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + raise KeyboardInterrupt + + self.stubs.Set(websocket.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) self.stubs.Set(websocket.WebSocketServer, 'daemonize', lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', self._mock_select_keyboardinterrupt) - self.assertEqual(None, soc.start_server()) + self.stubs.Set(select, 'select', fake_select) + server.start_server() def test_start_server_systemexit(self): - websocket.ssl = None - self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + sys.exit() + + self.stubs.Set(websocket.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) self.stubs.Set(websocket.WebSocketServer, 'daemonize', lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', self._mock_select_systemexit) - soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1, - verbose=True) - self.assertEqual(None, soc.start_server()) + self.stubs.Set(select, 'select', fake_select) + server.start_server() - def test_WSRequestHandle_do_GET_nofile(self): - request = 'GET /tmp.txt HTTP/0.9' - with tempfile.NamedTemporaryFile() as test_file: - test_file.write(request) - test_file.flush() - test_file.seek(0) - con = MockConnection(test_file.name) - soc = websocket.WSRequestHandler(con, "127.0.0.1", file_only=True) - soc.path = '' - soc.headers = {'upgrade': ''} - self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', - lambda *args: None) - soc.do_GET() - self.assertEqual(404, soc.last_code) - - def test_WSRequestHandle_do_GET_hidden_resource(self): - request = 'GET /tmp.txt HTTP/0.9' - with tempfile.NamedTemporaryFile() as test_file: - test_file.write(request) - test_file.flush() - test_file.seek(0) - con = MockConnection(test_file.name) - soc = websocket.WSRequestHandler(con, '127.0.0.1', no_parent=True) - soc.path = test_file.name + '?' - soc.headers = {'upgrade': ''} - soc.webroot = 'no match startswith' - self.stubs.Set(SimpleHTTPRequestHandler, - 'send_response', - lambda *args: None) - soc.do_GET() - self.assertEqual(403, soc.last_code) - - def testsocket_set_keepalive_options(self): + def test_socket_set_keepalive_options(self): keepcnt = 12 keepidle = 34 keepintvl = 56 - sock = self.server.socket('localhost', - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost', + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt) @@ -229,11 +331,11 @@ class WebSocketTestCase(unittest.TestCase): self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl) - sock = self.server.socket('localhost', - tcp_keepalive=False, - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) + sock = server.socket('localhost', + tcp_keepalive=False, + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt) diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py index cf940ae..8103ef6 100644 --- a/tests/test_websocketproxy.py +++ b/tests/test_websocketproxy.py @@ -1,6 +1,6 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 -# Copyright(c)2013 NTT corp. All Rights Reserved. +# Copyright(c) 2015 Red Hat, Inc All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -15,113 +15,122 @@ # under the License. """ Unit tests for websocketproxy """ -import os -import logging -import select -import shutil -import stubout -import subprocess -import tempfile -import time + import unittest +import unittest +import socket +import stubout + +from websockify import websocket from websockify import websocketproxy +from websockify import token_plugins +from websockify import auth_plugins + +try: + from StringIO import StringIO + BytesIO = StringIO +except ImportError: + from io import StringIO + from io import BytesIO -class MockSocket(object): - def __init__(*args, **kwargs): +class FakeSocket(object): + def __init__(self, data=''): + if isinstance(data, bytes): + self._data = data + else: + self._data = data.encode('latin_1') + + def recv(self, amt, flags=None): + res = self._data[0:amt] + if not (flags & socket.MSG_PEEK): + self._data = self._data[amt:] + + return res + + def makefile(self, mode='r', buffsize=None): + if 'b' in mode: + return BytesIO(self._data) + else: + return StringIO(self._data.decode('latin_1')) + + +class FakeServer(object): + class EClose(Exception): pass - def shutdown(*args): - pass - - def close(*args): - pass - - -class WebSocketProxyTest(unittest.TestCase): - - def _init_logger(self, tmpdir): - name = 'websocket-unittest' - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - logger.propagate = True - filename = "%s.log" % (name) - handler = logging.FileHandler(filename) - handler.setFormatter(logging.Formatter("%(message)s")) - logger.addHandler(handler) + def __init__(self): + self.token_plugin = None + self.auth_plugin = None + self.wrap_cmd = None + self.ssl_target = None + self.unix_target = None +class ProxyRequestHandlerTestCase(unittest.TestCase): def setUp(self): - """Called automatically before each test.""" - super(WebSocketProxyTest, self).setUp() - self.soc = '' + super(ProxyRequestHandlerTestCase, self).setUp() self.stubs = stubout.StubOutForTesting() - # Temporary dir for test data - self.tmpdir = tempfile.mkdtemp() - # Put log somewhere persistent - self._init_logger('./') - # Mock this out cause it screws tests up - self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) + self.handler = websocketproxy.ProxyRequestHandler( + FakeSocket(''), "127.0.0.1", FakeServer()) + self.handler.path = "https://localhost:6080/websockify?token=blah" + self.handler.headers = None + self.stubs.Set(websocket.WebSocketServer, 'socket', + staticmethod(lambda *args, **kwargs: None)) def tearDown(self): - """Called automatically after each test.""" self.stubs.UnsetAll() - shutil.rmtree(self.tmpdir) - super(WebSocketProxyTest, self).tearDown() + super(ProxyRequestHandlerTestCase, self).tearDown() - def _get_websockproxy(self, **kwargs): - return websocketproxy.WebSocketProxy(key=self.tmpdir, - web=self.tmpdir, - record=self.tmpdir, - **kwargs) + def test_get_target(self): + class TestPlugin(token_plugins.BasePlugin): + def lookup(self, token): + return ("some host", "some port") - def test_run_wrap_cmd(self): - web_socket_proxy = self._get_websockproxy() - web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" + host, port = self.handler.get_target( + TestPlugin(None), self.handler.path) - def mock_Popen(*args, **kwargs): - return '_mock_cmd' + self.assertEqual(host, "some host") + self.assertEqual(port, "some port") - self.stubs.Set(subprocess, 'Popen', mock_Popen) - web_socket_proxy.run_wrap_cmd() - self.assertEquals(web_socket_proxy.spawn_message, True) + def test_get_target_raises_error_on_unknown_token(self): + class TestPlugin(token_plugins.BasePlugin): + def lookup(self, token): + return None - def test_started(self): - web_socket_proxy = self._get_websockproxy() - web_socket_proxy.__dict__["spawn_message"] = False - web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" + self.assertRaises(FakeServer.EClose, self.handler.get_target, + TestPlugin(None), "https://localhost:6080/websockify?token=blah") - def mock_run_wrap_cmd(*args, **kwargs): - web_socket_proxy.__dict__["spawn_message"] = True + def test_token_plugin(self): + class TestPlugin(token_plugins.BasePlugin): + def lookup(self, token): + return (self.source + token).split(',') - self.stubs.Set(web_socket_proxy, 'run_wrap_cmd', mock_run_wrap_cmd) - web_socket_proxy.started() - self.assertEquals(web_socket_proxy.__dict__["spawn_message"], True) + self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy', + lambda *args, **kwargs: None) - def test_poll(self): - web_socket_proxy = self._get_websockproxy() - web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" - web_socket_proxy.__dict__["wrap_mode"] = "respawn" - web_socket_proxy.__dict__["wrap_times"] = [99999999] - web_socket_proxy.__dict__["spawn_message"] = True - web_socket_proxy.__dict__["cmd"] = None - self.stubs.Set(time, 'time', lambda: 100000000.000) - web_socket_proxy.poll() - self.assertEquals(web_socket_proxy.spawn_message, False) + self.handler.server.token_plugin = TestPlugin("somehost,") + self.handler.new_websocket_client() - def test_new_client(self): - web_socket_proxy = self._get_websockproxy() - web_socket_proxy.__dict__["verbose"] = "verbose" - web_socket_proxy.__dict__["daemon"] = None - web_socket_proxy.__dict__["client"] = "client" + self.assertEqual(self.handler.server.target_host, "somehost") + self.assertEqual(self.handler.server.target_port, "blah") - self.stubs.Set(web_socket_proxy, 'socket', MockSocket) + def test_auth_plugin(self): + class TestPlugin(auth_plugins.BasePlugin): + def authenticate(self, headers, target_host, target_port): + if target_host == self.source: + raise auth_plugins.AuthenticationError("some error") - def mock_select(*args, **kwargs): - ins = None - outs = None - excepts = "excepts" - return ins, outs, excepts + self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy', + staticmethod(lambda *args, **kwargs: None)) + + self.handler.server.auth_plugin = TestPlugin("somehost") + self.handler.server.target_host = "somehost" + self.handler.server.target_port = "someport" + + self.assertRaises(auth_plugins.AuthenticationError, + self.handler.new_websocket_client) + + self.handler.server.target_host = "someotherhost" + self.handler.new_websocket_client() - self.stubs.Set(select, 'select', mock_select) - self.assertRaises(Exception, web_socket_proxy.new_websocket_client) diff --git a/tests/tox.ini b/tox.ini similarity index 78% rename from tests/tox.ini rename to tox.ini index 098e89c..012d349 100644 --- a/tests/tox.ini +++ b/tox.ini @@ -4,8 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = py24,py25,py26,py27,py30 -setupdir = ../ +envlist = py24,py26,py27,py33,py34 [testenv] commands = nosetests {posargs} @@ -13,7 +12,7 @@ deps = mox nose -# At some point we should enable this since tox epdctes it to exist but +# At some point we should enable this since tox expects it to exist but # the code will need pep8ising first. #[testenv:pep8] #commands = flake8 diff --git a/websockify/websocket.py b/websockify/websocket.py index 20305b8..727413a 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -790,7 +790,7 @@ class WebSocketServer(object): handshake = sock.recv(1024, socket.MSG_PEEK) #self.msg("Handshake [%s]" % handshake) - if handshake == "": + if not handshake: raise self.EClose("ignoring empty handshake") elif handshake.startswith(s2b("")): From afb5d2801fd477411dad13bf974e36631bc1c5b7 Mon Sep 17 00:00:00 2001 From: Solly Ross Date: Wed, 6 May 2015 14:29:51 -0400 Subject: [PATCH 5/5] Enable Travis This commit enables running the unit tests on Travis for Python 2.6, 2.7, 3.3, and 3.4. Note that Travis does not support Python 2.4, so we cannot test there. --- .travis.yml | 10 ++++++++++ test-requirements.txt | 2 ++ tox.ini | 6 ++---- 3 files changed, 14 insertions(+), 4 deletions(-) create mode 100644 .travis.yml create mode 100644 test-requirements.txt diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..4aced3a --- /dev/null +++ b/.travis.yml @@ -0,0 +1,10 @@ +language: python +python: + - 2.6 + - 2.7 + - 3.3 + - 3.4 + +install: pip install -r test-requirements.txt + +script: python setup.py nosetests --verbosity=3 diff --git a/test-requirements.txt b/test-requirements.txt new file mode 100644 index 0000000..93207c1 --- /dev/null +++ b/test-requirements.txt @@ -0,0 +1,2 @@ +mox +nose diff --git a/tox.ini b/tox.ini index 012d349..79f7201 100644 --- a/tox.ini +++ b/tox.ini @@ -8,12 +8,10 @@ envlist = py24,py26,py27,py33,py34 [testenv] commands = nosetests {posargs} -deps = - mox - nose +deps = -r{toxinidir}/test-requirements.txt # At some point we should enable this since tox expects it to exist but -# the code will need pep8ising first. +# the code will need pep8ising first. #[testenv:pep8] #commands = flake8 #dep = flake8