Merge pull request #172 from kanaka/feature/generic-auth-hook
Auth Plugins / Plugin Cleanup / Test Update
This commit is contained in:
commit
ca9fe06184
|
|
@ -9,3 +9,5 @@ other/node_modules
|
|||
.pydevproject
|
||||
target.cfg
|
||||
target.cfg.d
|
||||
.tox
|
||||
*.egg-info
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
language: python
|
||||
python:
|
||||
- 2.6
|
||||
- 2.7
|
||||
- 3.3
|
||||
- 3.4
|
||||
|
||||
install: pip install -r test-requirements.txt
|
||||
|
||||
script: python setup.py nosetests --verbosity=3
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
mox
|
||||
nose
|
||||
|
|
@ -26,201 +26,303 @@ import stubout
|
|||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from ssl import SSLError
|
||||
from websockify import websocket as websocket
|
||||
from SimpleHTTPServer import SimpleHTTPRequestHandler
|
||||
import socket
|
||||
import signal
|
||||
from websockify import websocket
|
||||
|
||||
try:
|
||||
from SimpleHTTPServer import SimpleHTTPRequestHandler
|
||||
except ImportError:
|
||||
from http.server import SimpleHTTPRequestHandler
|
||||
|
||||
try:
|
||||
from StringIO import StringIO
|
||||
BytesIO = StringIO
|
||||
except ImportError:
|
||||
from io import StringIO
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class MockConnection(object):
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
|
||||
def makefile(self, mode='r', bufsize=-1):
|
||||
return open(self.path, mode, bufsize)
|
||||
|
||||
|
||||
class WebSocketTestCase(unittest.TestCase):
|
||||
def raise_oserror(*args, **kwargs):
|
||||
raise OSError('fake error')
|
||||
|
||||
def _init_logger(self, tmpdir):
|
||||
name = 'websocket-unittest'
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.propagate = True
|
||||
filename = "%s.log" % (name)
|
||||
handler = logging.FileHandler(filename)
|
||||
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
logger.addHandler(handler)
|
||||
|
||||
class FakeSocket(object):
|
||||
def __init__(self, data=''):
|
||||
if isinstance(data, bytes):
|
||||
self._data = data
|
||||
else:
|
||||
self._data = data.encode('latin_1')
|
||||
|
||||
def recv(self, amt, flags=None):
|
||||
res = self._data[0:amt]
|
||||
if not (flags & socket.MSG_PEEK):
|
||||
self._data = self._data[amt:]
|
||||
|
||||
return res
|
||||
|
||||
def makefile(self, mode='r', buffsize=None):
|
||||
if 'b' in mode:
|
||||
return BytesIO(self._data)
|
||||
else:
|
||||
return StringIO(self._data.decode('latin_1'))
|
||||
|
||||
|
||||
class WebSocketRequestHandlerTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Called automatically before each test."""
|
||||
super(WebSocketTestCase, self).setUp()
|
||||
super(WebSocketRequestHandlerTestCase, self).setUp()
|
||||
self.stubs = stubout.StubOutForTesting()
|
||||
# Temporary dir for test data
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
# Put log somewhere persistent
|
||||
self._init_logger('./')
|
||||
self.tmpdir = tempfile.mkdtemp('-websockify-tests')
|
||||
# Mock this out cause it screws tests up
|
||||
self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
|
||||
self.server = self._get_websockserver(daemon=True,
|
||||
ssl_only=False)
|
||||
self.soc = self.server.socket('localhost')
|
||||
self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
|
||||
lambda *args, **kwargs: None)
|
||||
|
||||
def tearDown(self):
|
||||
"""Called automatically after each test."""
|
||||
self.stubs.UnsetAll()
|
||||
shutil.rmtree(self.tmpdir)
|
||||
super(WebSocketTestCase, self).tearDown()
|
||||
os.rmdir(self.tmpdir)
|
||||
super(WebSocketRequestHandlerTestCase, self).tearDown()
|
||||
|
||||
def _get_websockserver(self, **kwargs):
|
||||
return websocket.WebSocketServer(listen_host='localhost',
|
||||
listen_port=80,
|
||||
key=self.tmpdir,
|
||||
web=self.tmpdir,
|
||||
record=self.tmpdir,
|
||||
**kwargs)
|
||||
def _get_server(self, handler_class=websocket.WebSocketRequestHandler,
|
||||
**kwargs):
|
||||
web = kwargs.pop('web', self.tmpdir)
|
||||
return websocket.WebSocketServer(
|
||||
handler_class, listen_host='localhost',
|
||||
listen_port=80, key=self.tmpdir, web=web,
|
||||
record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
|
||||
**kwargs)
|
||||
|
||||
def _mock_os_open_oserror(self, file, flags):
|
||||
raise OSError('')
|
||||
def test_normal_get_with_only_upgrade_returns_error(self):
|
||||
server = self._get_server(web=None)
|
||||
handler = websocket.WebSocketRequestHandler(
|
||||
FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)
|
||||
|
||||
def _mock_os_close_oserror(self, fd):
|
||||
raise OSError('')
|
||||
def fake_send_response(self, code, message=None):
|
||||
self.last_code = code
|
||||
|
||||
def _mock_os_close_oserror_EBADF(self, fd):
|
||||
raise OSError(errno.EBADF, '')
|
||||
self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
|
||||
fake_send_response)
|
||||
|
||||
def _mock_socket(self, *args, **kwargs):
|
||||
return self.soc
|
||||
handler.do_GET()
|
||||
self.assertEqual(handler.last_code, 405)
|
||||
|
||||
def _mock_select(self, rlist, wlist, xlist, timeout=None):
|
||||
return '_mock_select'
|
||||
def test_list_dir_with_file_only_returns_error(self):
|
||||
server = self._get_server(file_only=True)
|
||||
handler = websocket.WebSocketRequestHandler(
|
||||
FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server)
|
||||
|
||||
def _mock_select_exception(self, rlist, wlist, xlist, timeout=None):
|
||||
raise Exception
|
||||
def fake_send_response(self, code, message=None):
|
||||
self.last_code = code
|
||||
|
||||
def _mock_select_keyboardinterrupt(self, rlist, wlist,
|
||||
xlist, timeout=None):
|
||||
raise KeyboardInterrupt
|
||||
self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
|
||||
fake_send_response)
|
||||
|
||||
def _mock_select_systemexit(self, rlist, wlist, xlist, timeout=None):
|
||||
sys.exit()
|
||||
handler.path = '/'
|
||||
handler.do_GET()
|
||||
self.assertEqual(handler.last_code, 404)
|
||||
|
||||
def test_daemonize_error(self):
|
||||
soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1)
|
||||
self.stubs.Set(os, 'fork', lambda *args: None)
|
||||
|
||||
class WebSocketServerTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super(WebSocketServerTestCase, self).setUp()
|
||||
self.stubs = stubout.StubOutForTesting()
|
||||
self.tmpdir = tempfile.mkdtemp('-websockify-tests')
|
||||
# Mock this out cause it screws tests up
|
||||
self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
|
||||
|
||||
def tearDown(self):
|
||||
"""Called automatically after each test."""
|
||||
self.stubs.UnsetAll()
|
||||
os.rmdir(self.tmpdir)
|
||||
super(WebSocketServerTestCase, self).tearDown()
|
||||
|
||||
def _get_server(self, handler_class=websocket.WebSocketRequestHandler,
|
||||
**kwargs):
|
||||
return websocket.WebSocketServer(
|
||||
handler_class, listen_host='localhost',
|
||||
listen_port=80, key=self.tmpdir, web=self.tmpdir,
|
||||
record=self.tmpdir, **kwargs)
|
||||
|
||||
def test_daemonize_raises_error_while_closing_fds(self):
|
||||
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
|
||||
self.stubs.Set(os, 'fork', lambda *args: 0)
|
||||
self.stubs.Set(signal, 'signal', lambda *args: None)
|
||||
self.stubs.Set(os, 'setsid', lambda *args: None)
|
||||
self.stubs.Set(os, 'close', self._mock_os_close_oserror)
|
||||
self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
|
||||
self.stubs.Set(os, 'close', raise_oserror)
|
||||
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
|
||||
|
||||
def test_daemonize_EBADF_error(self):
|
||||
soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1)
|
||||
self.stubs.Set(os, 'fork', lambda *args: None)
|
||||
def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
|
||||
def raise_oserror_ebadf(fd):
|
||||
raise OSError(errno.EBADF, 'fake error')
|
||||
|
||||
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
|
||||
self.stubs.Set(os, 'fork', lambda *args: 0)
|
||||
self.stubs.Set(os, 'setsid', lambda *args: None)
|
||||
self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF)
|
||||
self.stubs.Set(os, 'open', self._mock_os_open_oserror)
|
||||
self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
|
||||
self.stubs.Set(signal, 'signal', lambda *args: None)
|
||||
self.stubs.Set(os, 'close', raise_oserror_ebadf)
|
||||
self.stubs.Set(os, 'open', raise_oserror)
|
||||
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
|
||||
|
||||
def test_decode_hybi(self):
|
||||
soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
|
||||
self.assertRaises(Exception, soc.decode_hybi, 'a' * 128,
|
||||
base64=True)
|
||||
def test_handshake_fails_on_not_ready(self):
|
||||
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
|
||||
|
||||
def test_do_websocket_handshake(self):
|
||||
soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
|
||||
soc.scheme = 'scheme'
|
||||
headers = {'Sec-WebSocket-Protocol': 'binary',
|
||||
'Sec-WebSocket-Version': '7',
|
||||
'Sec-WebSocket-Key': 'foo'}
|
||||
soc.do_websocket_handshake(headers, '127.0.0.1')
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([], [], [])
|
||||
|
||||
def test_do_handshake(self):
|
||||
soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
|
||||
self.stubs.Set(select, 'select', self._mock_select)
|
||||
self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv')
|
||||
self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1')
|
||||
self.stubs.Set(select, 'select', fake_select)
|
||||
self.assertRaises(
|
||||
websocket.WebSocketServer.EClose, server.do_handshake,
|
||||
FakeSocket(), '127.0.0.1')
|
||||
|
||||
def test_do_handshake_ssl_error(self):
|
||||
soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
|
||||
def test_empty_handshake_fails(self):
|
||||
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
|
||||
|
||||
def _mock_wrap_socket(*args, **kwargs):
|
||||
from ssl import SSLError
|
||||
raise SSLError('unit test exception')
|
||||
sock = FakeSocket('')
|
||||
|
||||
self.stubs.Set(select, 'select', self._mock_select)
|
||||
self.stubs.Set(socket._socketobject, 'recv', lambda *args: '\x16')
|
||||
self.stubs.Set(ssl, 'wrap_socket', _mock_wrap_socket)
|
||||
self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1')
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([sock], [], [])
|
||||
|
||||
def test_fallback_SIGCHILD(self):
|
||||
soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
|
||||
soc.fallback_SIGCHLD(None, None)
|
||||
self.stubs.Set(select, 'select', fake_select)
|
||||
self.assertRaises(
|
||||
websocket.WebSocketServer.EClose, server.do_handshake,
|
||||
sock, '127.0.0.1')
|
||||
|
||||
def test_start_server_Exception(self):
|
||||
soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
|
||||
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
|
||||
def test_handshake_policy_request(self):
|
||||
# TODO(directxman12): implement
|
||||
pass
|
||||
|
||||
def test_handshake_ssl_only_without_ssl_raises_error(self):
|
||||
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
|
||||
|
||||
sock = FakeSocket('some initial data')
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([sock], [], [])
|
||||
|
||||
self.stubs.Set(select, 'select', fake_select)
|
||||
self.assertRaises(
|
||||
websocket.WebSocketServer.EClose, server.do_handshake,
|
||||
sock, '127.0.0.1')
|
||||
|
||||
def test_do_handshake_no_ssl(self):
|
||||
class FakeHandler(object):
|
||||
CALLED = False
|
||||
def __init__(self, *args, **kwargs):
|
||||
type(self).CALLED = True
|
||||
|
||||
FakeHandler.CALLED = False
|
||||
|
||||
server = self._get_server(
|
||||
handler_class=FakeHandler, daemon=True,
|
||||
ssl_only=0, idle_timeout=1)
|
||||
|
||||
sock = FakeSocket('some initial data')
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([sock], [], [])
|
||||
|
||||
self.stubs.Set(select, 'select', fake_select)
|
||||
self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
|
||||
self.assertTrue(FakeHandler.CALLED, True)
|
||||
|
||||
def test_do_handshake_ssl(self):
|
||||
# TODO(directxman12): implement this
|
||||
pass
|
||||
|
||||
def test_do_handshake_ssl_without_ssl_raises_error(self):
|
||||
# TODO(directxman12): implement this
|
||||
pass
|
||||
|
||||
def test_do_handshake_ssl_without_cert_raises_error(self):
|
||||
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
|
||||
cert='afdsfasdafdsafdsafdsafdas')
|
||||
|
||||
sock = FakeSocket("\x16some ssl data")
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([sock], [], [])
|
||||
|
||||
self.stubs.Set(select, 'select', fake_select)
|
||||
self.assertRaises(
|
||||
websocket.WebSocketServer.EClose, server.do_handshake,
|
||||
sock, '127.0.0.1')
|
||||
|
||||
def test_do_handshake_ssl_error_eof_raises_close_error(self):
|
||||
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
|
||||
|
||||
sock = FakeSocket("\x16some ssl data")
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([sock], [], [])
|
||||
|
||||
def fake_wrap_socket(*args, **kwargs):
|
||||
raise ssl.SSLError(ssl.SSL_ERROR_EOF)
|
||||
|
||||
self.stubs.Set(select, 'select', fake_select)
|
||||
self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket)
|
||||
self.assertRaises(
|
||||
websocket.WebSocketServer.EClose, server.do_handshake,
|
||||
sock, '127.0.0.1')
|
||||
|
||||
def test_fallback_sigchld_handler(self):
|
||||
# TODO(directxman12): implement this
|
||||
pass
|
||||
|
||||
def test_start_server_error(self):
|
||||
server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
|
||||
sock = server.socket('localhost')
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
raise Exception("fake error")
|
||||
|
||||
self.stubs.Set(websocket.WebSocketServer, 'socket',
|
||||
lambda *args, **kwargs: sock)
|
||||
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
|
||||
lambda *args, **kwargs: None)
|
||||
self.stubs.Set(select, 'select', self._mock_select_exception)
|
||||
self.assertEqual(None, soc.start_server())
|
||||
self.stubs.Set(select, 'select', fake_select)
|
||||
server.start_server()
|
||||
|
||||
def test_start_server_KeyboardInterrupt(self):
|
||||
soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
|
||||
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
|
||||
def test_start_server_keyboardinterrupt(self):
|
||||
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
|
||||
sock = server.socket('localhost')
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
raise KeyboardInterrupt
|
||||
|
||||
self.stubs.Set(websocket.WebSocketServer, 'socket',
|
||||
lambda *args, **kwargs: sock)
|
||||
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
|
||||
lambda *args, **kwargs: None)
|
||||
self.stubs.Set(select, 'select', self._mock_select_keyboardinterrupt)
|
||||
self.assertEqual(None, soc.start_server())
|
||||
self.stubs.Set(select, 'select', fake_select)
|
||||
server.start_server()
|
||||
|
||||
def test_start_server_systemexit(self):
|
||||
websocket.ssl = None
|
||||
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
|
||||
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
|
||||
sock = server.socket('localhost')
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
sys.exit()
|
||||
|
||||
self.stubs.Set(websocket.WebSocketServer, 'socket',
|
||||
lambda *args, **kwargs: sock)
|
||||
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
|
||||
lambda *args, **kwargs: None)
|
||||
self.stubs.Set(select, 'select', self._mock_select_systemexit)
|
||||
soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1,
|
||||
verbose=True)
|
||||
self.assertEqual(None, soc.start_server())
|
||||
self.stubs.Set(select, 'select', fake_select)
|
||||
server.start_server()
|
||||
|
||||
def test_WSRequestHandle_do_GET_nofile(self):
|
||||
request = 'GET /tmp.txt HTTP/0.9'
|
||||
with tempfile.NamedTemporaryFile() as test_file:
|
||||
test_file.write(request)
|
||||
test_file.flush()
|
||||
test_file.seek(0)
|
||||
con = MockConnection(test_file.name)
|
||||
soc = websocket.WSRequestHandler(con, "127.0.0.1", file_only=True)
|
||||
soc.path = ''
|
||||
soc.headers = {'upgrade': ''}
|
||||
self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
|
||||
lambda *args: None)
|
||||
soc.do_GET()
|
||||
self.assertEqual(404, soc.last_code)
|
||||
|
||||
def test_WSRequestHandle_do_GET_hidden_resource(self):
|
||||
request = 'GET /tmp.txt HTTP/0.9'
|
||||
with tempfile.NamedTemporaryFile() as test_file:
|
||||
test_file.write(request)
|
||||
test_file.flush()
|
||||
test_file.seek(0)
|
||||
con = MockConnection(test_file.name)
|
||||
soc = websocket.WSRequestHandler(con, '127.0.0.1', no_parent=True)
|
||||
soc.path = test_file.name + '?'
|
||||
soc.headers = {'upgrade': ''}
|
||||
soc.webroot = 'no match startswith'
|
||||
self.stubs.Set(SimpleHTTPRequestHandler,
|
||||
'send_response',
|
||||
lambda *args: None)
|
||||
soc.do_GET()
|
||||
self.assertEqual(403, soc.last_code)
|
||||
|
||||
def testsocket_set_keepalive_options(self):
|
||||
def test_socket_set_keepalive_options(self):
|
||||
keepcnt = 12
|
||||
keepidle = 34
|
||||
keepintvl = 56
|
||||
|
||||
sock = self.server.socket('localhost',
|
||||
tcp_keepcnt=keepcnt,
|
||||
tcp_keepidle=keepidle,
|
||||
tcp_keepintvl=keepintvl)
|
||||
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
|
||||
sock = server.socket('localhost',
|
||||
tcp_keepcnt=keepcnt,
|
||||
tcp_keepidle=keepidle,
|
||||
tcp_keepintvl=keepintvl)
|
||||
|
||||
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
|
||||
socket.TCP_KEEPCNT), keepcnt)
|
||||
|
|
@ -229,11 +331,11 @@ class WebSocketTestCase(unittest.TestCase):
|
|||
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
|
||||
socket.TCP_KEEPINTVL), keepintvl)
|
||||
|
||||
sock = self.server.socket('localhost',
|
||||
tcp_keepalive=False,
|
||||
tcp_keepcnt=keepcnt,
|
||||
tcp_keepidle=keepidle,
|
||||
tcp_keepintvl=keepintvl)
|
||||
sock = server.socket('localhost',
|
||||
tcp_keepalive=False,
|
||||
tcp_keepcnt=keepcnt,
|
||||
tcp_keepidle=keepidle,
|
||||
tcp_keepintvl=keepintvl)
|
||||
|
||||
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
|
||||
socket.TCP_KEEPCNT), keepcnt)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright(c)2013 NTT corp. All Rights Reserved.
|
||||
# Copyright(c) 2015 Red Hat, Inc All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
|
|
@ -15,113 +15,122 @@
|
|||
# under the License.
|
||||
|
||||
""" Unit tests for websocketproxy """
|
||||
import os
|
||||
import logging
|
||||
import select
|
||||
import shutil
|
||||
import stubout
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import unittest
|
||||
import unittest
|
||||
import socket
|
||||
|
||||
import stubout
|
||||
|
||||
from websockify import websocket
|
||||
from websockify import websocketproxy
|
||||
from websockify import token_plugins
|
||||
from websockify import auth_plugins
|
||||
|
||||
try:
|
||||
from StringIO import StringIO
|
||||
BytesIO = StringIO
|
||||
except ImportError:
|
||||
from io import StringIO
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class MockSocket(object):
|
||||
def __init__(*args, **kwargs):
|
||||
class FakeSocket(object):
|
||||
def __init__(self, data=''):
|
||||
if isinstance(data, bytes):
|
||||
self._data = data
|
||||
else:
|
||||
self._data = data.encode('latin_1')
|
||||
|
||||
def recv(self, amt, flags=None):
|
||||
res = self._data[0:amt]
|
||||
if not (flags & socket.MSG_PEEK):
|
||||
self._data = self._data[amt:]
|
||||
|
||||
return res
|
||||
|
||||
def makefile(self, mode='r', buffsize=None):
|
||||
if 'b' in mode:
|
||||
return BytesIO(self._data)
|
||||
else:
|
||||
return StringIO(self._data.decode('latin_1'))
|
||||
|
||||
|
||||
class FakeServer(object):
|
||||
class EClose(Exception):
|
||||
pass
|
||||
|
||||
def shutdown(*args):
|
||||
pass
|
||||
|
||||
def close(*args):
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketProxyTest(unittest.TestCase):
|
||||
|
||||
def _init_logger(self, tmpdir):
|
||||
name = 'websocket-unittest'
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.propagate = True
|
||||
filename = "%s.log" % (name)
|
||||
handler = logging.FileHandler(filename)
|
||||
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
logger.addHandler(handler)
|
||||
def __init__(self):
|
||||
self.token_plugin = None
|
||||
self.auth_plugin = None
|
||||
self.wrap_cmd = None
|
||||
self.ssl_target = None
|
||||
self.unix_target = None
|
||||
|
||||
class ProxyRequestHandlerTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Called automatically before each test."""
|
||||
super(WebSocketProxyTest, self).setUp()
|
||||
self.soc = ''
|
||||
super(ProxyRequestHandlerTestCase, self).setUp()
|
||||
self.stubs = stubout.StubOutForTesting()
|
||||
# Temporary dir for test data
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
# Put log somewhere persistent
|
||||
self._init_logger('./')
|
||||
# Mock this out cause it screws tests up
|
||||
self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
|
||||
self.handler = websocketproxy.ProxyRequestHandler(
|
||||
FakeSocket(''), "127.0.0.1", FakeServer())
|
||||
self.handler.path = "https://localhost:6080/websockify?token=blah"
|
||||
self.handler.headers = None
|
||||
self.stubs.Set(websocket.WebSocketServer, 'socket',
|
||||
staticmethod(lambda *args, **kwargs: None))
|
||||
|
||||
def tearDown(self):
|
||||
"""Called automatically after each test."""
|
||||
self.stubs.UnsetAll()
|
||||
shutil.rmtree(self.tmpdir)
|
||||
super(WebSocketProxyTest, self).tearDown()
|
||||
super(ProxyRequestHandlerTestCase, self).tearDown()
|
||||
|
||||
def _get_websockproxy(self, **kwargs):
|
||||
return websocketproxy.WebSocketProxy(key=self.tmpdir,
|
||||
web=self.tmpdir,
|
||||
record=self.tmpdir,
|
||||
**kwargs)
|
||||
def test_get_target(self):
|
||||
class TestPlugin(token_plugins.BasePlugin):
|
||||
def lookup(self, token):
|
||||
return ("some host", "some port")
|
||||
|
||||
def test_run_wrap_cmd(self):
|
||||
web_socket_proxy = self._get_websockproxy()
|
||||
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
|
||||
host, port = self.handler.get_target(
|
||||
TestPlugin(None), self.handler.path)
|
||||
|
||||
def mock_Popen(*args, **kwargs):
|
||||
return '_mock_cmd'
|
||||
self.assertEqual(host, "some host")
|
||||
self.assertEqual(port, "some port")
|
||||
|
||||
self.stubs.Set(subprocess, 'Popen', mock_Popen)
|
||||
web_socket_proxy.run_wrap_cmd()
|
||||
self.assertEquals(web_socket_proxy.spawn_message, True)
|
||||
def test_get_target_raises_error_on_unknown_token(self):
|
||||
class TestPlugin(token_plugins.BasePlugin):
|
||||
def lookup(self, token):
|
||||
return None
|
||||
|
||||
def test_started(self):
|
||||
web_socket_proxy = self._get_websockproxy()
|
||||
web_socket_proxy.__dict__["spawn_message"] = False
|
||||
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
|
||||
self.assertRaises(FakeServer.EClose, self.handler.get_target,
|
||||
TestPlugin(None), "https://localhost:6080/websockify?token=blah")
|
||||
|
||||
def mock_run_wrap_cmd(*args, **kwargs):
|
||||
web_socket_proxy.__dict__["spawn_message"] = True
|
||||
def test_token_plugin(self):
|
||||
class TestPlugin(token_plugins.BasePlugin):
|
||||
def lookup(self, token):
|
||||
return (self.source + token).split(',')
|
||||
|
||||
self.stubs.Set(web_socket_proxy, 'run_wrap_cmd', mock_run_wrap_cmd)
|
||||
web_socket_proxy.started()
|
||||
self.assertEquals(web_socket_proxy.__dict__["spawn_message"], True)
|
||||
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
|
||||
lambda *args, **kwargs: None)
|
||||
|
||||
def test_poll(self):
|
||||
web_socket_proxy = self._get_websockproxy()
|
||||
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
|
||||
web_socket_proxy.__dict__["wrap_mode"] = "respawn"
|
||||
web_socket_proxy.__dict__["wrap_times"] = [99999999]
|
||||
web_socket_proxy.__dict__["spawn_message"] = True
|
||||
web_socket_proxy.__dict__["cmd"] = None
|
||||
self.stubs.Set(time, 'time', lambda: 100000000.000)
|
||||
web_socket_proxy.poll()
|
||||
self.assertEquals(web_socket_proxy.spawn_message, False)
|
||||
self.handler.server.token_plugin = TestPlugin("somehost,")
|
||||
self.handler.new_websocket_client()
|
||||
|
||||
def test_new_client(self):
|
||||
web_socket_proxy = self._get_websockproxy()
|
||||
web_socket_proxy.__dict__["verbose"] = "verbose"
|
||||
web_socket_proxy.__dict__["daemon"] = None
|
||||
web_socket_proxy.__dict__["client"] = "client"
|
||||
self.assertEqual(self.handler.server.target_host, "somehost")
|
||||
self.assertEqual(self.handler.server.target_port, "blah")
|
||||
|
||||
self.stubs.Set(web_socket_proxy, 'socket', MockSocket)
|
||||
def test_auth_plugin(self):
|
||||
class TestPlugin(auth_plugins.BasePlugin):
|
||||
def authenticate(self, headers, target_host, target_port):
|
||||
if target_host == self.source:
|
||||
raise auth_plugins.AuthenticationError("some error")
|
||||
|
||||
def mock_select(*args, **kwargs):
|
||||
ins = None
|
||||
outs = None
|
||||
excepts = "excepts"
|
||||
return ins, outs, excepts
|
||||
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
|
||||
staticmethod(lambda *args, **kwargs: None))
|
||||
|
||||
self.handler.server.auth_plugin = TestPlugin("somehost")
|
||||
self.handler.server.target_host = "somehost"
|
||||
self.handler.server.target_port = "someport"
|
||||
|
||||
self.assertRaises(auth_plugins.AuthenticationError,
|
||||
self.handler.new_websocket_client)
|
||||
|
||||
self.handler.server.target_host = "someotherhost"
|
||||
self.handler.new_websocket_client()
|
||||
|
||||
self.stubs.Set(select, 'select', mock_select)
|
||||
self.assertRaises(Exception, web_socket_proxy.new_websocket_client)
|
||||
|
|
|
|||
|
|
@ -4,17 +4,14 @@
|
|||
# and then run "tox" from this directory.
|
||||
|
||||
[tox]
|
||||
envlist = py24,py25,py26,py27,py30
|
||||
setupdir = ../
|
||||
envlist = py24,py26,py27,py33,py34
|
||||
|
||||
[testenv]
|
||||
commands = nosetests {posargs}
|
||||
deps =
|
||||
mox
|
||||
nose
|
||||
deps = -r{toxinidir}/test-requirements.txt
|
||||
|
||||
# At some point we should enable this since tox epdctes it to exist but
|
||||
# the code will need pep8ising first.
|
||||
# At some point we should enable this since tox expects it to exist but
|
||||
# the code will need pep8ising first.
|
||||
#[testenv:pep8]
|
||||
#commands = flake8
|
||||
#dep = flake8
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
class BasePlugin(object):
|
||||
def __init__(self, src=None):
|
||||
self.source = src
|
||||
|
||||
def authenticate(self, headers, target_host, target_port):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidOriginError(AuthenticationError):
|
||||
def __init__(self, expected, actual):
|
||||
self.expected_origin = expected
|
||||
self.actual_origin = actual
|
||||
|
||||
super(InvalidOriginError, self).__init__(
|
||||
"Invalid Origin Header: Expected one of "
|
||||
"%s, got '%s'" % (expected, actual))
|
||||
|
||||
|
||||
class ExpectOrigin(object):
|
||||
def __init__(self, src=None):
|
||||
if src is None:
|
||||
self.source = []
|
||||
else:
|
||||
self.source = src.split()
|
||||
|
||||
def authenticate(self, headers, target_host, target_port):
|
||||
origin = headers.getheader('Origin', None)
|
||||
if origin is None or origin not in self.source:
|
||||
raise InvalidOriginError(expected=self.source, actual=origin)
|
||||
|
|
@ -12,6 +12,10 @@ class ReadOnlyTokenFile(BasePlugin):
|
|||
# source is a token file with lines like
|
||||
# token: host:port
|
||||
# or a directory of such files
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ReadOnlyTokenFile, self).__init__(*args, **kwargs)
|
||||
self._targets = None
|
||||
|
||||
def _load_targets(self):
|
||||
if os.path.isdir(self.source):
|
||||
cfg_files = [os.path.join(self.source, f) for
|
||||
|
|
|
|||
|
|
@ -790,7 +790,7 @@ class WebSocketServer(object):
|
|||
handshake = sock.recv(1024, socket.MSG_PEEK)
|
||||
#self.msg("Handshake [%s]" % handshake)
|
||||
|
||||
if handshake == "":
|
||||
if not handshake:
|
||||
raise self.EClose("ignoring empty handshake")
|
||||
|
||||
elif handshake.startswith(s2b("<policy-file-request/>")):
|
||||
|
|
|
|||
|
|
@ -47,6 +47,11 @@ Traffic Legend:
|
|||
if self.server.token_plugin:
|
||||
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)
|
||||
|
||||
if self.server.auth_plugin:
|
||||
self.server.auth_plugin.authenticate(
|
||||
headers=self.headers, target_host=self.server.target_host,
|
||||
target_port=self.server.target_port)
|
||||
|
||||
# Connect to the target
|
||||
if self.server.wrap_cmd:
|
||||
msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port)
|
||||
|
|
@ -194,21 +199,8 @@ class WebSocketProxy(websocket.WebSocketServer):
|
|||
self.ssl_target = kwargs.pop('ssl_target', None)
|
||||
self.heartbeat = kwargs.pop('heartbeat', None)
|
||||
|
||||
token_plugin = kwargs.pop('token_plugin', None)
|
||||
token_source = kwargs.pop('token_source', None)
|
||||
|
||||
if token_plugin is not None:
|
||||
if '.' not in token_plugin:
|
||||
token_plugin = 'websockify.token_plugins.%s' % token_plugin
|
||||
|
||||
token_plugin_module, token_plugin_cls = token_plugin.rsplit('.', 1)
|
||||
|
||||
__import__(token_plugin_module)
|
||||
token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls)
|
||||
|
||||
self.token_plugin = token_plugin_cls(token_source)
|
||||
else:
|
||||
self.token_plugin = None
|
||||
self.token_plugin = kwargs.pop('token_plugin', None)
|
||||
self.auth_plugin = kwargs.pop('auth_plugin', None)
|
||||
|
||||
# Last 3 timestamps command was run
|
||||
self.wrap_times = [0, 0, 0]
|
||||
|
|
@ -381,6 +373,12 @@ def websockify_init():
|
|||
parser.add_option("--token-source", default=None, metavar="ARG",
|
||||
help="an argument to be passed to the token plugin"
|
||||
"on instantiation")
|
||||
parser.add_option("--auth-plugin", default=None, metavar="PLUGIN",
|
||||
help="use the given Python class to determine if "
|
||||
"a connection is allowed")
|
||||
parser.add_option("--auth-source", default=None, metavar="ARG",
|
||||
help="an argument to be passed to the auth plugin"
|
||||
"on instantiation")
|
||||
parser.add_option("--auto-pong", action="store_true",
|
||||
help="Automatically respond to ping frames with a pong")
|
||||
parser.add_option("--heartbeat", type=int, default=0,
|
||||
|
|
@ -394,6 +392,10 @@ def websockify_init():
|
|||
if opts.token_source and not opts.token_plugin:
|
||||
parser.error("You must use --token-plugin to use --token-source")
|
||||
|
||||
if opts.auth_source and not opts.auth_plugin:
|
||||
parser.error("You must use --auth-plugin to use --auth-source")
|
||||
|
||||
|
||||
# Transform to absolute path as daemon may chdir
|
||||
if opts.target_cfg:
|
||||
opts.target_cfg = os.path.abspath(opts.target_cfg)
|
||||
|
|
@ -442,6 +444,31 @@ def websockify_init():
|
|||
try: opts.target_port = int(opts.target_port)
|
||||
except: parser.error("Error parsing target port")
|
||||
|
||||
if opts.token_plugin is not None:
|
||||
if '.' not in opts.token_plugin:
|
||||
opts.token_plugin = (
|
||||
'websockify.token_plugins.%s' % opts.token_plugin)
|
||||
|
||||
token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit('.', 1)
|
||||
|
||||
__import__(token_plugin_module)
|
||||
token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls)
|
||||
|
||||
opts.token_plugin = token_plugin_cls(opts.token_source)
|
||||
del opts.token_source
|
||||
|
||||
if opts.auth_plugin is not None:
|
||||
if '.' not in opts.auth_plugin:
|
||||
opts.auth_plugin = 'websockify.auth_plugins.%s' % opts.auth_plugin
|
||||
|
||||
auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit('.', 1)
|
||||
|
||||
__import__(auth_plugin_module)
|
||||
auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls)
|
||||
|
||||
opts.auth_plugin = auth_plugin_cls(opts.auth_source)
|
||||
del opts.auth_source
|
||||
|
||||
# Create and start the WebSockets proxy
|
||||
libserver = opts.libserver
|
||||
del opts.libserver
|
||||
|
|
@ -470,10 +497,11 @@ class LibProxyServer(ForkingMixIn, HTTPServer):
|
|||
self.unix_target = kwargs.pop('unix_target', None)
|
||||
self.ssl_target = kwargs.pop('ssl_target', None)
|
||||
self.token_plugin = kwargs.pop('token_plugin', None)
|
||||
self.token_source = kwargs.pop('token_source', None)
|
||||
self.auth_plugin = kwargs.pop('auth_plugin', None)
|
||||
self.heartbeat = kwargs.pop('heartbeat', None)
|
||||
|
||||
self.token_plugin = None
|
||||
self.auth_plugin = None
|
||||
self.daemon = False
|
||||
|
||||
# Server configuration
|
||||
|
|
|
|||
Loading…
Reference in New Issue