websockify/tests/test_token_plugins.py

506 lines
17 KiB
Python

# vim: tabstop=4 shiftwidth=4 softtabstop=4
"""Unit tests for Token plugins"""
import sys
import unittest
from unittest.mock import patch, mock_open, MagicMock
from jwcrypto import jwt, jwk
from websockify.token_plugins import (
parse_source_args,
ReadOnlyTokenFile,
JWTTokenApi,
TokenRedis,
)
class ParseSourceArgumentsTestCase(unittest.TestCase):
def test_parameterized(self):
params = [
("", [""]),
(":", ["", ""]),
("::", ["", "", ""]),
('"', ['"']),
('""', ['""']),
('"""', ['"""']),
('"localhost"', ["localhost"]),
('"localhost":', ["localhost", ""]),
('"localhost"::', ["localhost", "", ""]),
('"local:host"', ["local:host"]),
('"local:host:"pass"', ['"local', "host", "pass"]),
('"local":"host"', ["local", "host"]),
('"local":host"', ["local", 'host"']),
(
'localhost:6379:1:pass"word:"my-app-namespace:dev"',
["localhost", "6379", "1", 'pass"word', "my-app-namespace:dev"],
),
]
for src, args in params:
self.assertEqual(args, parse_source_args(src))
class ReadOnlyTokenFileTestCase(unittest.TestCase):
patch("os.path.isdir", MagicMock(return_value=False))
def test_empty(self):
plugin = ReadOnlyTokenFile("configfile")
config = ""
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup("testhost")
pyopen.assert_called_once_with("configfile")
self.assertIsNone(result)
patch("os.path.isdir", MagicMock(return_value=False))
def test_simple(self):
plugin = ReadOnlyTokenFile("configfile")
config = "testhost: remote_host:remote_port"
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup("testhost")
pyopen.assert_called_once_with("configfile")
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
patch("os.path.isdir", MagicMock(return_value=False))
def test_tabs(self):
plugin = ReadOnlyTokenFile("configfile")
config = "testhost:\tremote_host:remote_port"
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup("testhost")
pyopen.assert_called_once_with("configfile")
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
class JWSTokenTestCase(unittest.TestCase):
def test_asymmetric_jws_token_plugin(self):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
def test_asymmetric_jws_token_plugin_with_illigal_key_exception(self):
plugin = JWTTokenApi("wrong.pub")
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
self.assertIsNone(result)
@patch("time.time")
def test_jwt_valid_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key)
mock_time.return_value = 150
result = plugin.lookup(jwt_token.serialize())
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch("time.time")
def test_jwt_early_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key)
mock_time.return_value = 50
result = plugin.lookup(jwt_token.serialize())
self.assertIsNone(result)
@patch("time.time")
def test_jwt_late_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key)
mock_time.return_value = 250
result = plugin.lookup(jwt_token.serialize())
self.assertIsNone(result)
def test_symmetric_jws_token_plugin(self):
plugin = JWTTokenApi("./tests/fixtures/symmetric.key")
secret = open("./tests/fixtures/symmetric.key").read()
key = jwk.JWK()
key.import_key(kty="oct", k=secret)
jwt_token = jwt.JWT(
{"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
def test_symmetric_jws_token_plugin_with_illigal_key_exception(self):
plugin = JWTTokenApi("wrong_sauce")
secret = open("./tests/fixtures/symmetric.key").read()
key = jwk.JWK()
key.import_key(kty="oct", k=secret)
jwt_token = jwt.JWT(
{"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
self.assertIsNone(result)
def test_asymmetric_jwe_token_plugin(self):
plugin = JWTTokenApi("./tests/fixtures/private.pem")
private_key = jwk.JWK()
public_key = jwk.JWK()
private_key_data = open("./tests/fixtures/private.pem", "rb").read()
public_key_data = open("./tests/fixtures/public.pem", "rb").read()
private_key.import_from_pem(private_key_data)
public_key.import_from_pem(public_key_data)
jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(private_key)
jwe_token = jwt.JWT(
header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"},
claims=jwt_token.serialize(),
)
jwe_token.make_encrypted_token(public_key)
result = plugin.lookup(jwt_token.serialize())
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
class TokenRedisTestCase(unittest.TestCase):
def setUp(self):
try:
import redis
except ImportError:
patcher = patch.dict(sys.modules, {"redis": MagicMock()})
patcher.start()
self.addCleanup(patcher.stop)
@patch("redis.Redis")
def test_empty(self, mock_redis):
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = None
result = plugin.lookup("testhost")
instance.get.assert_called_once_with("testhost")
self.assertIsNone(result)
@patch("redis.Redis")
def test_simple(self, mock_redis):
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b'{"host": "remote_host:remote_port"}'
result = plugin.lookup("testhost")
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch("redis.Redis")
def test_json_token_with_spaces(self, mock_redis):
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b' {"host": "remote_host:remote_port"} '
result = plugin.lookup("testhost")
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch("redis.Redis")
def test_text_token(self, mock_redis):
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b"remote_host:remote_port"
result = plugin.lookup("testhost")
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch("redis.Redis")
def test_text_token_with_spaces(self, mock_redis):
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b" remote_host:remote_port "
result = plugin.lookup("testhost")
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch("redis.Redis")
def test_invalid_token(self, mock_redis):
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b'{"host": "remote_host:remote_port" '
result = plugin.lookup("testhost")
instance.get.assert_called_once_with("testhost")
self.assertIsNone(result)
@patch("redis.Redis")
def test_token_without_namespace(self, mock_redis):
plugin = TokenRedis("127.0.0.1:1234")
token = "testhost"
def mock_redis_get(key):
self.assertEqual(key, token)
return b"remote_host:remote_port"
instance = mock_redis.return_value
instance.get = mock_redis_get
result = plugin.lookup(token)
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch("redis.Redis")
def test_token_with_namespace(self, mock_redis):
plugin = TokenRedis("127.0.0.1:1234:::namespace")
token = "testhost"
def mock_redis_get(key):
self.assertEqual(key, "namespace:" + token)
return b"remote_host:remote_port"
instance = mock_redis.return_value
instance.get = mock_redis_get
result = plugin.lookup(token)
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
def test_src_only_host(self):
plugin = TokenRedis("127.0.0.1")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port(self):
plugin = TokenRedis("127.0.0.1:1234")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db(self):
plugin = TokenRedis("127.0.0.1:1234:2")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db_pass(self):
plugin = TokenRedis("127.0.0.1:1234:2:verysecret")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db_pass_namespace(self):
plugin = TokenRedis("127.0.0.1:1234:2:verysecret:namespace")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "namespace:")
def test_src_with_host_empty_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis("127.0.0.1:::verysecret")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_empty_namespace(self):
plugin = TokenRedis("127.0.0.1::::")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_no_namespace(self):
plugin = TokenRedis("127.0.0.1:::")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_no_pass_no_namespace(self):
plugin = TokenRedis("127.0.0.1::")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_no_db_no_pass_no_namespace(self):
plugin = TokenRedis("127.0.0.1:")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_namespace(self):
plugin = TokenRedis("127.0.0.1::::namespace")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "namespace:")
def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self):
plugin = TokenRedis('127.0.0.1::::"ns1:ns2"')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "ns1:ns2:")
def test_src_with_host_empty_port_db_no_pass_no_namespace(self):
plugin = TokenRedis("127.0.0.1::2")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis("127.0.0.1:1234::verysecret")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_db_pass_no_namespace(self):
plugin = TokenRedis("127.0.0.1::2:verysecret")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_db_empty_pass_no_namespace(self):
plugin = TokenRedis("127.0.0.1::2:")
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")