Merge pull request #172 from kanaka/feature/generic-auth-hook

Auth Plugins / Plugin Cleanup / Test Update
This commit is contained in:
Solly 2015-05-13 13:40:00 -04:00
commit ca9fe06184
10 changed files with 448 additions and 261 deletions

2
.gitignore vendored
View File

@ -9,3 +9,5 @@ other/node_modules
.pydevproject .pydevproject
target.cfg target.cfg
target.cfg.d target.cfg.d
.tox
*.egg-info

10
.travis.yml Normal file
View File

@ -0,0 +1,10 @@
language: python
python:
- 2.6
- 2.7
- 3.3
- 3.4
install: pip install -r test-requirements.txt
script: python setup.py nosetests --verbosity=3

2
test-requirements.txt Normal file
View File

@ -0,0 +1,2 @@
mox
nose

View File

@ -26,201 +26,303 @@ import stubout
import sys import sys
import tempfile import tempfile
import unittest import unittest
from ssl import SSLError import socket
from websockify import websocket as websocket import signal
from SimpleHTTPServer import SimpleHTTPRequestHandler from websockify import websocket
try:
from SimpleHTTPServer import SimpleHTTPRequestHandler
except ImportError:
from http.server import SimpleHTTPRequestHandler
try:
from StringIO import StringIO
BytesIO = StringIO
except ImportError:
from io import StringIO
from io import BytesIO
class MockConnection(object):
def __init__(self, path):
self.path = path
def makefile(self, mode='r', bufsize=-1):
return open(self.path, mode, bufsize)
class WebSocketTestCase(unittest.TestCase): def raise_oserror(*args, **kwargs):
raise OSError('fake error')
def _init_logger(self, tmpdir):
name = 'websocket-unittest'
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = True
filename = "%s.log" % (name)
handler = logging.FileHandler(filename)
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)
class FakeSocket(object):
def __init__(self, data=''):
if isinstance(data, bytes):
self._data = data
else:
self._data = data.encode('latin_1')
def recv(self, amt, flags=None):
res = self._data[0:amt]
if not (flags & socket.MSG_PEEK):
self._data = self._data[amt:]
return res
def makefile(self, mode='r', buffsize=None):
if 'b' in mode:
return BytesIO(self._data)
else:
return StringIO(self._data.decode('latin_1'))
class WebSocketRequestHandlerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
"""Called automatically before each test.""" super(WebSocketRequestHandlerTestCase, self).setUp()
super(WebSocketTestCase, self).setUp()
self.stubs = stubout.StubOutForTesting() self.stubs = stubout.StubOutForTesting()
# Temporary dir for test data self.tmpdir = tempfile.mkdtemp('-websockify-tests')
self.tmpdir = tempfile.mkdtemp()
# Put log somewhere persistent
self._init_logger('./')
# Mock this out cause it screws tests up # Mock this out cause it screws tests up
self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
self.server = self._get_websockserver(daemon=True, self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
ssl_only=False) lambda *args, **kwargs: None)
self.soc = self.server.socket('localhost')
def tearDown(self): def tearDown(self):
"""Called automatically after each test.""" """Called automatically after each test."""
self.stubs.UnsetAll() self.stubs.UnsetAll()
shutil.rmtree(self.tmpdir) os.rmdir(self.tmpdir)
super(WebSocketTestCase, self).tearDown() super(WebSocketRequestHandlerTestCase, self).tearDown()
def _get_websockserver(self, **kwargs): def _get_server(self, handler_class=websocket.WebSocketRequestHandler,
return websocket.WebSocketServer(listen_host='localhost', **kwargs):
listen_port=80, web = kwargs.pop('web', self.tmpdir)
key=self.tmpdir, return websocket.WebSocketServer(
web=self.tmpdir, handler_class, listen_host='localhost',
record=self.tmpdir, listen_port=80, key=self.tmpdir, web=web,
**kwargs) record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
**kwargs)
def _mock_os_open_oserror(self, file, flags): def test_normal_get_with_only_upgrade_returns_error(self):
raise OSError('') server = self._get_server(web=None)
handler = websocket.WebSocketRequestHandler(
FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)
def _mock_os_close_oserror(self, fd): def fake_send_response(self, code, message=None):
raise OSError('') self.last_code = code
def _mock_os_close_oserror_EBADF(self, fd): self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
raise OSError(errno.EBADF, '') fake_send_response)
def _mock_socket(self, *args, **kwargs): handler.do_GET()
return self.soc self.assertEqual(handler.last_code, 405)
def _mock_select(self, rlist, wlist, xlist, timeout=None): def test_list_dir_with_file_only_returns_error(self):
return '_mock_select' server = self._get_server(file_only=True)
handler = websocket.WebSocketRequestHandler(
FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server)
def _mock_select_exception(self, rlist, wlist, xlist, timeout=None): def fake_send_response(self, code, message=None):
raise Exception self.last_code = code
def _mock_select_keyboardinterrupt(self, rlist, wlist, self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
xlist, timeout=None): fake_send_response)
raise KeyboardInterrupt
def _mock_select_systemexit(self, rlist, wlist, xlist, timeout=None): handler.path = '/'
sys.exit() handler.do_GET()
self.assertEqual(handler.last_code, 404)
def test_daemonize_error(self):
soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1) class WebSocketServerTestCase(unittest.TestCase):
self.stubs.Set(os, 'fork', lambda *args: None) def setUp(self):
super(WebSocketServerTestCase, self).setUp()
self.stubs = stubout.StubOutForTesting()
self.tmpdir = tempfile.mkdtemp('-websockify-tests')
# Mock this out cause it screws tests up
self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
def tearDown(self):
"""Called automatically after each test."""
self.stubs.UnsetAll()
os.rmdir(self.tmpdir)
super(WebSocketServerTestCase, self).tearDown()
def _get_server(self, handler_class=websocket.WebSocketRequestHandler,
**kwargs):
return websocket.WebSocketServer(
handler_class, listen_host='localhost',
listen_port=80, key=self.tmpdir, web=self.tmpdir,
record=self.tmpdir, **kwargs)
def test_daemonize_raises_error_while_closing_fds(self):
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
self.stubs.Set(os, 'fork', lambda *args: 0)
self.stubs.Set(signal, 'signal', lambda *args: None)
self.stubs.Set(os, 'setsid', lambda *args: None) self.stubs.Set(os, 'setsid', lambda *args: None)
self.stubs.Set(os, 'close', self._mock_os_close_oserror) self.stubs.Set(os, 'close', raise_oserror)
self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
def test_daemonize_EBADF_error(self): def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1) def raise_oserror_ebadf(fd):
self.stubs.Set(os, 'fork', lambda *args: None) raise OSError(errno.EBADF, 'fake error')
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
self.stubs.Set(os, 'fork', lambda *args: 0)
self.stubs.Set(os, 'setsid', lambda *args: None) self.stubs.Set(os, 'setsid', lambda *args: None)
self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF) self.stubs.Set(signal, 'signal', lambda *args: None)
self.stubs.Set(os, 'open', self._mock_os_open_oserror) self.stubs.Set(os, 'close', raise_oserror_ebadf)
self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') self.stubs.Set(os, 'open', raise_oserror)
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
def test_decode_hybi(self): def test_handshake_fails_on_not_ready(self):
soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
self.assertRaises(Exception, soc.decode_hybi, 'a' * 128,
base64=True)
def test_do_websocket_handshake(self): def fake_select(rlist, wlist, xlist, timeout=None):
soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) return ([], [], [])
soc.scheme = 'scheme'
headers = {'Sec-WebSocket-Protocol': 'binary',
'Sec-WebSocket-Version': '7',
'Sec-WebSocket-Key': 'foo'}
soc.do_websocket_handshake(headers, '127.0.0.1')
def test_do_handshake(self): self.stubs.Set(select, 'select', fake_select)
soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) self.assertRaises(
self.stubs.Set(select, 'select', self._mock_select) websocket.WebSocketServer.EClose, server.do_handshake,
self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv') FakeSocket(), '127.0.0.1')
self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1')
def test_do_handshake_ssl_error(self): def test_empty_handshake_fails(self):
soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
def _mock_wrap_socket(*args, **kwargs): sock = FakeSocket('')
from ssl import SSLError
raise SSLError('unit test exception')
self.stubs.Set(select, 'select', self._mock_select) def fake_select(rlist, wlist, xlist, timeout=None):
self.stubs.Set(socket._socketobject, 'recv', lambda *args: '\x16') return ([sock], [], [])
self.stubs.Set(ssl, 'wrap_socket', _mock_wrap_socket)
self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1')
def test_fallback_SIGCHILD(self): self.stubs.Set(select, 'select', fake_select)
soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) self.assertRaises(
soc.fallback_SIGCHLD(None, None) websocket.WebSocketServer.EClose, server.do_handshake,
sock, '127.0.0.1')
def test_start_server_Exception(self): def test_handshake_policy_request(self):
soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) # TODO(directxman12): implement
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) pass
def test_handshake_ssl_only_without_ssl_raises_error(self):
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
sock = FakeSocket('some initial data')
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
self.stubs.Set(select, 'select', fake_select)
self.assertRaises(
websocket.WebSocketServer.EClose, server.do_handshake,
sock, '127.0.0.1')
def test_do_handshake_no_ssl(self):
class FakeHandler(object):
CALLED = False
def __init__(self, *args, **kwargs):
type(self).CALLED = True
FakeHandler.CALLED = False
server = self._get_server(
handler_class=FakeHandler, daemon=True,
ssl_only=0, idle_timeout=1)
sock = FakeSocket('some initial data')
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
self.stubs.Set(select, 'select', fake_select)
self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
self.assertTrue(FakeHandler.CALLED, True)
def test_do_handshake_ssl(self):
# TODO(directxman12): implement this
pass
def test_do_handshake_ssl_without_ssl_raises_error(self):
# TODO(directxman12): implement this
pass
def test_do_handshake_ssl_without_cert_raises_error(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
cert='afdsfasdafdsafdsafdsafdas')
sock = FakeSocket("\x16some ssl data")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
self.stubs.Set(select, 'select', fake_select)
self.assertRaises(
websocket.WebSocketServer.EClose, server.do_handshake,
sock, '127.0.0.1')
def test_do_handshake_ssl_error_eof_raises_close_error(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
sock = FakeSocket("\x16some ssl data")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
def fake_wrap_socket(*args, **kwargs):
raise ssl.SSLError(ssl.SSL_ERROR_EOF)
self.stubs.Set(select, 'select', fake_select)
self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket)
self.assertRaises(
websocket.WebSocketServer.EClose, server.do_handshake,
sock, '127.0.0.1')
def test_fallback_sigchld_handler(self):
# TODO(directxman12): implement this
pass
def test_start_server_error(self):
server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
sock = server.socket('localhost')
def fake_select(rlist, wlist, xlist, timeout=None):
raise Exception("fake error")
self.stubs.Set(websocket.WebSocketServer, 'socket',
lambda *args, **kwargs: sock)
self.stubs.Set(websocket.WebSocketServer, 'daemonize', self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None) lambda *args, **kwargs: None)
self.stubs.Set(select, 'select', self._mock_select_exception) self.stubs.Set(select, 'select', fake_select)
self.assertEqual(None, soc.start_server()) server.start_server()
def test_start_server_KeyboardInterrupt(self): def test_start_server_keyboardinterrupt(self):
soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) sock = server.socket('localhost')
def fake_select(rlist, wlist, xlist, timeout=None):
raise KeyboardInterrupt
self.stubs.Set(websocket.WebSocketServer, 'socket',
lambda *args, **kwargs: sock)
self.stubs.Set(websocket.WebSocketServer, 'daemonize', self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None) lambda *args, **kwargs: None)
self.stubs.Set(select, 'select', self._mock_select_keyboardinterrupt) self.stubs.Set(select, 'select', fake_select)
self.assertEqual(None, soc.start_server()) server.start_server()
def test_start_server_systemexit(self): def test_start_server_systemexit(self):
websocket.ssl = None server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) sock = server.socket('localhost')
def fake_select(rlist, wlist, xlist, timeout=None):
sys.exit()
self.stubs.Set(websocket.WebSocketServer, 'socket',
lambda *args, **kwargs: sock)
self.stubs.Set(websocket.WebSocketServer, 'daemonize', self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None) lambda *args, **kwargs: None)
self.stubs.Set(select, 'select', self._mock_select_systemexit) self.stubs.Set(select, 'select', fake_select)
soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1, server.start_server()
verbose=True)
self.assertEqual(None, soc.start_server())
def test_WSRequestHandle_do_GET_nofile(self): def test_socket_set_keepalive_options(self):
request = 'GET /tmp.txt HTTP/0.9'
with tempfile.NamedTemporaryFile() as test_file:
test_file.write(request)
test_file.flush()
test_file.seek(0)
con = MockConnection(test_file.name)
soc = websocket.WSRequestHandler(con, "127.0.0.1", file_only=True)
soc.path = ''
soc.headers = {'upgrade': ''}
self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
lambda *args: None)
soc.do_GET()
self.assertEqual(404, soc.last_code)
def test_WSRequestHandle_do_GET_hidden_resource(self):
request = 'GET /tmp.txt HTTP/0.9'
with tempfile.NamedTemporaryFile() as test_file:
test_file.write(request)
test_file.flush()
test_file.seek(0)
con = MockConnection(test_file.name)
soc = websocket.WSRequestHandler(con, '127.0.0.1', no_parent=True)
soc.path = test_file.name + '?'
soc.headers = {'upgrade': ''}
soc.webroot = 'no match startswith'
self.stubs.Set(SimpleHTTPRequestHandler,
'send_response',
lambda *args: None)
soc.do_GET()
self.assertEqual(403, soc.last_code)
def testsocket_set_keepalive_options(self):
keepcnt = 12 keepcnt = 12
keepidle = 34 keepidle = 34
keepintvl = 56 keepintvl = 56
sock = self.server.socket('localhost', server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
tcp_keepcnt=keepcnt, sock = server.socket('localhost',
tcp_keepidle=keepidle, tcp_keepcnt=keepcnt,
tcp_keepintvl=keepintvl) tcp_keepidle=keepidle,
tcp_keepintvl=keepintvl)
self.assertEqual(sock.getsockopt(socket.SOL_TCP, self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt) socket.TCP_KEEPCNT), keepcnt)
@ -229,11 +331,11 @@ class WebSocketTestCase(unittest.TestCase):
self.assertEqual(sock.getsockopt(socket.SOL_TCP, self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPINTVL), keepintvl) socket.TCP_KEEPINTVL), keepintvl)
sock = self.server.socket('localhost', sock = server.socket('localhost',
tcp_keepalive=False, tcp_keepalive=False,
tcp_keepcnt=keepcnt, tcp_keepcnt=keepcnt,
tcp_keepidle=keepidle, tcp_keepidle=keepidle,
tcp_keepintvl=keepintvl) tcp_keepintvl=keepintvl)
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt) socket.TCP_KEEPCNT), keepcnt)

View File

@ -1,6 +1,6 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4 # vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright(c)2013 NTT corp. All Rights Reserved. # Copyright(c) 2015 Red Hat, Inc All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
@ -15,113 +15,122 @@
# under the License. # under the License.
""" Unit tests for websocketproxy """ """ Unit tests for websocketproxy """
import os
import logging
import select
import shutil
import stubout
import subprocess
import tempfile
import time
import unittest import unittest
import unittest
import socket
import stubout
from websockify import websocket
from websockify import websocketproxy from websockify import websocketproxy
from websockify import token_plugins
from websockify import auth_plugins
try:
from StringIO import StringIO
BytesIO = StringIO
except ImportError:
from io import StringIO
from io import BytesIO
class MockSocket(object): class FakeSocket(object):
def __init__(*args, **kwargs): def __init__(self, data=''):
if isinstance(data, bytes):
self._data = data
else:
self._data = data.encode('latin_1')
def recv(self, amt, flags=None):
res = self._data[0:amt]
if not (flags & socket.MSG_PEEK):
self._data = self._data[amt:]
return res
def makefile(self, mode='r', buffsize=None):
if 'b' in mode:
return BytesIO(self._data)
else:
return StringIO(self._data.decode('latin_1'))
class FakeServer(object):
class EClose(Exception):
pass pass
def shutdown(*args): def __init__(self):
pass self.token_plugin = None
self.auth_plugin = None
def close(*args): self.wrap_cmd = None
pass self.ssl_target = None
self.unix_target = None
class WebSocketProxyTest(unittest.TestCase):
def _init_logger(self, tmpdir):
name = 'websocket-unittest'
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = True
filename = "%s.log" % (name)
handler = logging.FileHandler(filename)
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)
class ProxyRequestHandlerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
"""Called automatically before each test.""" super(ProxyRequestHandlerTestCase, self).setUp()
super(WebSocketProxyTest, self).setUp()
self.soc = ''
self.stubs = stubout.StubOutForTesting() self.stubs = stubout.StubOutForTesting()
# Temporary dir for test data self.handler = websocketproxy.ProxyRequestHandler(
self.tmpdir = tempfile.mkdtemp() FakeSocket(''), "127.0.0.1", FakeServer())
# Put log somewhere persistent self.handler.path = "https://localhost:6080/websockify?token=blah"
self._init_logger('./') self.handler.headers = None
# Mock this out cause it screws tests up self.stubs.Set(websocket.WebSocketServer, 'socket',
self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) staticmethod(lambda *args, **kwargs: None))
def tearDown(self): def tearDown(self):
"""Called automatically after each test."""
self.stubs.UnsetAll() self.stubs.UnsetAll()
shutil.rmtree(self.tmpdir) super(ProxyRequestHandlerTestCase, self).tearDown()
super(WebSocketProxyTest, self).tearDown()
def _get_websockproxy(self, **kwargs): def test_get_target(self):
return websocketproxy.WebSocketProxy(key=self.tmpdir, class TestPlugin(token_plugins.BasePlugin):
web=self.tmpdir, def lookup(self, token):
record=self.tmpdir, return ("some host", "some port")
**kwargs)
def test_run_wrap_cmd(self): host, port = self.handler.get_target(
web_socket_proxy = self._get_websockproxy() TestPlugin(None), self.handler.path)
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
def mock_Popen(*args, **kwargs): self.assertEqual(host, "some host")
return '_mock_cmd' self.assertEqual(port, "some port")
self.stubs.Set(subprocess, 'Popen', mock_Popen) def test_get_target_raises_error_on_unknown_token(self):
web_socket_proxy.run_wrap_cmd() class TestPlugin(token_plugins.BasePlugin):
self.assertEquals(web_socket_proxy.spawn_message, True) def lookup(self, token):
return None
def test_started(self): self.assertRaises(FakeServer.EClose, self.handler.get_target,
web_socket_proxy = self._get_websockproxy() TestPlugin(None), "https://localhost:6080/websockify?token=blah")
web_socket_proxy.__dict__["spawn_message"] = False
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
def mock_run_wrap_cmd(*args, **kwargs): def test_token_plugin(self):
web_socket_proxy.__dict__["spawn_message"] = True class TestPlugin(token_plugins.BasePlugin):
def lookup(self, token):
return (self.source + token).split(',')
self.stubs.Set(web_socket_proxy, 'run_wrap_cmd', mock_run_wrap_cmd) self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
web_socket_proxy.started() lambda *args, **kwargs: None)
self.assertEquals(web_socket_proxy.__dict__["spawn_message"], True)
def test_poll(self): self.handler.server.token_plugin = TestPlugin("somehost,")
web_socket_proxy = self._get_websockproxy() self.handler.new_websocket_client()
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
web_socket_proxy.__dict__["wrap_mode"] = "respawn"
web_socket_proxy.__dict__["wrap_times"] = [99999999]
web_socket_proxy.__dict__["spawn_message"] = True
web_socket_proxy.__dict__["cmd"] = None
self.stubs.Set(time, 'time', lambda: 100000000.000)
web_socket_proxy.poll()
self.assertEquals(web_socket_proxy.spawn_message, False)
def test_new_client(self): self.assertEqual(self.handler.server.target_host, "somehost")
web_socket_proxy = self._get_websockproxy() self.assertEqual(self.handler.server.target_port, "blah")
web_socket_proxy.__dict__["verbose"] = "verbose"
web_socket_proxy.__dict__["daemon"] = None
web_socket_proxy.__dict__["client"] = "client"
self.stubs.Set(web_socket_proxy, 'socket', MockSocket) def test_auth_plugin(self):
class TestPlugin(auth_plugins.BasePlugin):
def authenticate(self, headers, target_host, target_port):
if target_host == self.source:
raise auth_plugins.AuthenticationError("some error")
def mock_select(*args, **kwargs): self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
ins = None staticmethod(lambda *args, **kwargs: None))
outs = None
excepts = "excepts" self.handler.server.auth_plugin = TestPlugin("somehost")
return ins, outs, excepts self.handler.server.target_host = "somehost"
self.handler.server.target_port = "someport"
self.assertRaises(auth_plugins.AuthenticationError,
self.handler.new_websocket_client)
self.handler.server.target_host = "someotherhost"
self.handler.new_websocket_client()
self.stubs.Set(select, 'select', mock_select)
self.assertRaises(Exception, web_socket_proxy.new_websocket_client)

View File

@ -4,17 +4,14 @@
# and then run "tox" from this directory. # and then run "tox" from this directory.
[tox] [tox]
envlist = py24,py25,py26,py27,py30 envlist = py24,py26,py27,py33,py34
setupdir = ../
[testenv] [testenv]
commands = nosetests {posargs} commands = nosetests {posargs}
deps = deps = -r{toxinidir}/test-requirements.txt
mox
nose
# At some point we should enable this since tox epdctes it to exist but # At some point we should enable this since tox expects it to exist but
# the code will need pep8ising first. # the code will need pep8ising first.
#[testenv:pep8] #[testenv:pep8]
#commands = flake8 #commands = flake8
#dep = flake8 #dep = flake8

View File

@ -0,0 +1,33 @@
class BasePlugin(object):
def __init__(self, src=None):
self.source = src
def authenticate(self, headers, target_host, target_port):
pass
class AuthenticationError(Exception):
pass
class InvalidOriginError(AuthenticationError):
def __init__(self, expected, actual):
self.expected_origin = expected
self.actual_origin = actual
super(InvalidOriginError, self).__init__(
"Invalid Origin Header: Expected one of "
"%s, got '%s'" % (expected, actual))
class ExpectOrigin(object):
def __init__(self, src=None):
if src is None:
self.source = []
else:
self.source = src.split()
def authenticate(self, headers, target_host, target_port):
origin = headers.getheader('Origin', None)
if origin is None or origin not in self.source:
raise InvalidOriginError(expected=self.source, actual=origin)

View File

@ -12,6 +12,10 @@ class ReadOnlyTokenFile(BasePlugin):
# source is a token file with lines like # source is a token file with lines like
# token: host:port # token: host:port
# or a directory of such files # or a directory of such files
def __init__(self, *args, **kwargs):
super(ReadOnlyTokenFile, self).__init__(*args, **kwargs)
self._targets = None
def _load_targets(self): def _load_targets(self):
if os.path.isdir(self.source): if os.path.isdir(self.source):
cfg_files = [os.path.join(self.source, f) for cfg_files = [os.path.join(self.source, f) for

View File

@ -790,7 +790,7 @@ class WebSocketServer(object):
handshake = sock.recv(1024, socket.MSG_PEEK) handshake = sock.recv(1024, socket.MSG_PEEK)
#self.msg("Handshake [%s]" % handshake) #self.msg("Handshake [%s]" % handshake)
if handshake == "": if not handshake:
raise self.EClose("ignoring empty handshake") raise self.EClose("ignoring empty handshake")
elif handshake.startswith(s2b("<policy-file-request/>")): elif handshake.startswith(s2b("<policy-file-request/>")):

View File

@ -47,6 +47,11 @@ Traffic Legend:
if self.server.token_plugin: if self.server.token_plugin:
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path) (self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)
if self.server.auth_plugin:
self.server.auth_plugin.authenticate(
headers=self.headers, target_host=self.server.target_host,
target_port=self.server.target_port)
# Connect to the target # Connect to the target
if self.server.wrap_cmd: if self.server.wrap_cmd:
msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port) msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port)
@ -194,21 +199,8 @@ class WebSocketProxy(websocket.WebSocketServer):
self.ssl_target = kwargs.pop('ssl_target', None) self.ssl_target = kwargs.pop('ssl_target', None)
self.heartbeat = kwargs.pop('heartbeat', None) self.heartbeat = kwargs.pop('heartbeat', None)
token_plugin = kwargs.pop('token_plugin', None) self.token_plugin = kwargs.pop('token_plugin', None)
token_source = kwargs.pop('token_source', None) self.auth_plugin = kwargs.pop('auth_plugin', None)
if token_plugin is not None:
if '.' not in token_plugin:
token_plugin = 'websockify.token_plugins.%s' % token_plugin
token_plugin_module, token_plugin_cls = token_plugin.rsplit('.', 1)
__import__(token_plugin_module)
token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls)
self.token_plugin = token_plugin_cls(token_source)
else:
self.token_plugin = None
# Last 3 timestamps command was run # Last 3 timestamps command was run
self.wrap_times = [0, 0, 0] self.wrap_times = [0, 0, 0]
@ -381,6 +373,12 @@ def websockify_init():
parser.add_option("--token-source", default=None, metavar="ARG", parser.add_option("--token-source", default=None, metavar="ARG",
help="an argument to be passed to the token plugin" help="an argument to be passed to the token plugin"
"on instantiation") "on instantiation")
parser.add_option("--auth-plugin", default=None, metavar="PLUGIN",
help="use the given Python class to determine if "
"a connection is allowed")
parser.add_option("--auth-source", default=None, metavar="ARG",
help="an argument to be passed to the auth plugin"
"on instantiation")
parser.add_option("--auto-pong", action="store_true", parser.add_option("--auto-pong", action="store_true",
help="Automatically respond to ping frames with a pong") help="Automatically respond to ping frames with a pong")
parser.add_option("--heartbeat", type=int, default=0, parser.add_option("--heartbeat", type=int, default=0,
@ -394,6 +392,10 @@ def websockify_init():
if opts.token_source and not opts.token_plugin: if opts.token_source and not opts.token_plugin:
parser.error("You must use --token-plugin to use --token-source") parser.error("You must use --token-plugin to use --token-source")
if opts.auth_source and not opts.auth_plugin:
parser.error("You must use --auth-plugin to use --auth-source")
# Transform to absolute path as daemon may chdir # Transform to absolute path as daemon may chdir
if opts.target_cfg: if opts.target_cfg:
opts.target_cfg = os.path.abspath(opts.target_cfg) opts.target_cfg = os.path.abspath(opts.target_cfg)
@ -442,6 +444,31 @@ def websockify_init():
try: opts.target_port = int(opts.target_port) try: opts.target_port = int(opts.target_port)
except: parser.error("Error parsing target port") except: parser.error("Error parsing target port")
if opts.token_plugin is not None:
if '.' not in opts.token_plugin:
opts.token_plugin = (
'websockify.token_plugins.%s' % opts.token_plugin)
token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit('.', 1)
__import__(token_plugin_module)
token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls)
opts.token_plugin = token_plugin_cls(opts.token_source)
del opts.token_source
if opts.auth_plugin is not None:
if '.' not in opts.auth_plugin:
opts.auth_plugin = 'websockify.auth_plugins.%s' % opts.auth_plugin
auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit('.', 1)
__import__(auth_plugin_module)
auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls)
opts.auth_plugin = auth_plugin_cls(opts.auth_source)
del opts.auth_source
# Create and start the WebSockets proxy # Create and start the WebSockets proxy
libserver = opts.libserver libserver = opts.libserver
del opts.libserver del opts.libserver
@ -470,10 +497,11 @@ class LibProxyServer(ForkingMixIn, HTTPServer):
self.unix_target = kwargs.pop('unix_target', None) self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None) self.ssl_target = kwargs.pop('ssl_target', None)
self.token_plugin = kwargs.pop('token_plugin', None) self.token_plugin = kwargs.pop('token_plugin', None)
self.token_source = kwargs.pop('token_source', None) self.auth_plugin = kwargs.pop('auth_plugin', None)
self.heartbeat = kwargs.pop('heartbeat', None) self.heartbeat = kwargs.pop('heartbeat', None)
self.token_plugin = None self.token_plugin = None
self.auth_plugin = None
self.daemon = False self.daemon = False
# Server configuration # Server configuration