# vim: tabstop=4 shiftwidth=4 softtabstop=4

"""Unit tests for Token plugins"""

import sys
import unittest
from unittest.mock import patch, mock_open, MagicMock

from jwcrypto import jwt, jwk

from websockify.token_plugins import (
    parse_source_args,
    ReadOnlyTokenFile,
    JWTTokenApi,
    TokenRedis,
)

# Key material shipped with the test suite. The paths are relative, so the
# suite is expected to be run from the repository root.
FIXTURES_DIR = "./tests/fixtures"
PRIVATE_PEM = FIXTURES_DIR + "/private.pem"
PUBLIC_PEM = FIXTURES_DIR + "/public.pem"
SYMMETRIC_KEY = FIXTURES_DIR + "/symmetric.key"


def _read_bytes(path):
    """Return the raw contents of *path*, closing the file afterwards."""
    with open(path, "rb") as fh:
        return fh.read()


def _read_text(path):
    """Return the text contents of *path*, closing the file afterwards."""
    with open(path, encoding="utf-8") as fh:
        return fh.read()


def _pem_key(path):
    """Build a JWK from the PEM-encoded key stored at *path*."""
    key = jwk.JWK()
    key.import_from_pem(_read_bytes(path))
    return key


def _symmetric_key():
    """Build an octet-sequence ("oct") JWK from the symmetric-key fixture."""
    key = jwk.JWK()
    key.import_key(kty="oct", k=_read_text(SYMMETRIC_KEY))
    return key


def _signed_token(key, alg, **extra_claims):
    """Return a JWT signed with *key* using *alg*.

    The claim set always targets ``remote_host``/``remote_port``; any
    *extra_claims* (for example ``nbf``/``exp``) are added after those.
    """
    claims = {"host": "remote_host", "port": "remote_port"}
    claims.update(extra_claims)
    token = jwt.JWT({"alg": alg}, claims)
    token.make_signed_token(key)
    return token


class _TargetAssertions:
    """Mixin providing the shared "lookup resolved to the remote target" check."""

    def assertRemoteTarget(self, result):
        """Assert *result* is a (host, port) pair of remote_host/remote_port."""
        self.assertIsNotNone(result)
        self.assertEqual(result[0], "remote_host")
        self.assertEqual(result[1], "remote_port")


class ParseSourceArgumentsTestCase(unittest.TestCase):
    def test_parameterized(self):
        # (source string, expected list of colon-separated arguments)
        params = [
            ("", [""]),
            (":", ["", ""]),
            ("::", ["", "", ""]),
            ('"', ['"']),
            ('""', ['""']),
            ('"""', ['"""']),
            ('"localhost"', ["localhost"]),
            ('"localhost":', ["localhost", ""]),
            ('"localhost"::', ["localhost", "", ""]),
            ('"local:host"', ["local:host"]),
            ('"local:host:"pass"', ['"local', "host", "pass"]),
            ('"local":"host"', ["local", "host"]),
            ('"local":host"', ["local", 'host"']),
            (
                'localhost:6379:1:pass"word:"my-app-namespace:dev"',
                ["localhost", "6379", "1", 'pass"word', "my-app-namespace:dev"],
            ),
        ]
        for src, args in params:
            # subTest reports every failing case and names the offending input.
            with self.subTest(src=src):
                self.assertEqual(args, parse_source_args(src))


# Make ReadOnlyTokenFile treat "configfile" as a plain file rather than a
# directory, independent of the host filesystem. Used as a class decorator,
# patch() wraps every test* method of the class.
@patch("os.path.isdir", MagicMock(return_value=False))
class ReadOnlyTokenFileTestCase(_TargetAssertions, unittest.TestCase):
    def _lookup(self, config, token="testhost"):
        """Look up *token* in a ReadOnlyTokenFile whose file contents are *config*."""
        plugin = ReadOnlyTokenFile("configfile")
        pyopen = mock_open(read_data=config)
        with patch("websockify.token_plugins.open", pyopen, create=True):
            result = plugin.lookup(token)
        pyopen.assert_called_once_with("configfile")
        return result

    def test_empty(self):
        self.assertIsNone(self._lookup(""))

    def test_simple(self):
        self.assertRemoteTarget(self._lookup("testhost: remote_host:remote_port"))

    def test_tabs(self):
        self.assertRemoteTarget(self._lookup("testhost:\tremote_host:remote_port"))


class JWSTokenTestCase(_TargetAssertions, unittest.TestCase):
    def test_asymmetric_jws_token_plugin(self):
        plugin = JWTTokenApi(PUBLIC_PEM)
        token = _signed_token(_pem_key(PRIVATE_PEM), "RS256")
        self.assertRemoteTarget(plugin.lookup(token.serialize()))

    def test_asymmetric_jws_token_plugin_with_illigal_key_exception(self):
        # The plugin is pointed at a key path that is not one of the fixtures,
        # so it has no usable verification key and lookup must fail.
        plugin = JWTTokenApi("wrong.pub")
        token = _signed_token(_pem_key(PRIVATE_PEM), "RS256")
        self.assertIsNone(plugin.lookup(token.serialize()))

    @patch("time.time")
    def test_jwt_valid_time(self, mock_time):
        plugin = JWTTokenApi(PUBLIC_PEM)
        token = _signed_token(_pem_key(PRIVATE_PEM), "RS256", nbf=100, exp=200)
        mock_time.return_value = 150  # inside the [nbf, exp] window
        self.assertRemoteTarget(plugin.lookup(token.serialize()))

    @patch("time.time")
    def test_jwt_early_time(self, mock_time):
        plugin = JWTTokenApi(PUBLIC_PEM)
        token = _signed_token(_pem_key(PRIVATE_PEM), "RS256", nbf=100, exp=200)
        mock_time.return_value = 50  # before nbf
        self.assertIsNone(plugin.lookup(token.serialize()))

    @patch("time.time")
    def test_jwt_late_time(self, mock_time):
        plugin = JWTTokenApi(PUBLIC_PEM)
        token = _signed_token(_pem_key(PRIVATE_PEM), "RS256", nbf=100, exp=200)
        mock_time.return_value = 250  # after exp
        self.assertIsNone(plugin.lookup(token.serialize()))

    def test_symmetric_jws_token_plugin(self):
        plugin = JWTTokenApi(SYMMETRIC_KEY)
        token = _signed_token(_symmetric_key(), "HS256")
        self.assertRemoteTarget(plugin.lookup(token.serialize()))

    def test_symmetric_jws_token_plugin_with_illigal_key_exception(self):
        # The plugin is given a key source that is not the fixture secret,
        # so the HS256 signature cannot be verified.
        plugin = JWTTokenApi("wrong_sauce")
        token = _signed_token(_symmetric_key(), "HS256")
        self.assertIsNone(plugin.lookup(token.serialize()))

    def test_asymmetric_jwe_token_plugin(self):
        plugin = JWTTokenApi(PRIVATE_PEM)
        # Sign with the private key, then wrap the signed token in a JWE
        # encrypted to the public key (nested sign-then-encrypt).
        jws_token = _signed_token(_pem_key(PRIVATE_PEM), "RS256")
        jwe_token = jwt.JWT(
            header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"},
            claims=jws_token.serialize(),
        )
        jwe_token.make_encrypted_token(_pem_key(PUBLIC_PEM))

        # Look up the *encrypted* token so the plugin's decrypt-then-verify
        # path is exercised; passing the inner JWS would only re-test plain
        # signature verification.
        result = plugin.lookup(jwe_token.serialize())
        self.assertRemoteTarget(result)


class TokenRedisTestCase(_TargetAssertions, unittest.TestCase):
    def setUp(self):
        # redis is an optional dependency. When it is missing, install a stub
        # module so patch("redis.Redis") still has a target to replace.
        try:
            import redis  # noqa: F401
        except ImportError:
            patcher = patch.dict(sys.modules, {"redis": MagicMock()})
            patcher.start()
            self.addCleanup(patcher.stop)

    def _lookup(
        self,
        mock_redis,
        stored,
        *,
        src="127.0.0.1:1234",
        token="testhost",
        expected_key="testhost",
    ):
        """Look up *token* via a TokenRedis whose mocked client returns *stored*.

        Also asserts the client was queried exactly once with *expected_key*.
        """
        plugin = TokenRedis(src)
        client = mock_redis.return_value
        client.get.return_value = stored
        result = plugin.lookup(token)
        client.get.assert_called_once_with(expected_key)
        return result

    def _assert_src(self, src, *, port=6379, db=0, password=None, namespace=""):
        """Assert the connection settings TokenRedis parses from *src*."""
        plugin = TokenRedis(src)
        self.assertEqual(plugin._server, "127.0.0.1")
        self.assertEqual(plugin._port, port)
        self.assertEqual(plugin._db, db)
        self.assertEqual(plugin._password, password)
        self.assertEqual(plugin._namespace, namespace)

    @patch("redis.Redis")
    def test_empty(self, mock_redis):
        self.assertIsNone(self._lookup(mock_redis, None))

    @patch("redis.Redis")
    def test_simple(self, mock_redis):
        result = self._lookup(mock_redis, b'{"host": "remote_host:remote_port"}')
        self.assertRemoteTarget(result)

    @patch("redis.Redis")
    def test_json_token_with_spaces(self, mock_redis):
        result = self._lookup(mock_redis, b' {"host": "remote_host:remote_port"} ')
        self.assertRemoteTarget(result)

    @patch("redis.Redis")
    def test_text_token(self, mock_redis):
        result = self._lookup(mock_redis, b"remote_host:remote_port")
        self.assertRemoteTarget(result)

    @patch("redis.Redis")
    def test_text_token_with_spaces(self, mock_redis):
        result = self._lookup(mock_redis, b" remote_host:remote_port ")
        self.assertRemoteTarget(result)

    @patch("redis.Redis")
    def test_invalid_token(self, mock_redis):
        # Truncated JSON (no closing brace) must not resolve to a target.
        result = self._lookup(mock_redis, b'{"host": "remote_host:remote_port" ')
        self.assertIsNone(result)

    @patch("redis.Redis")
    def test_token_without_namespace(self, mock_redis):
        result = self._lookup(
            mock_redis, b"remote_host:remote_port", expected_key="testhost"
        )
        self.assertRemoteTarget(result)

    @patch("redis.Redis")
    def test_token_with_namespace(self, mock_redis):
        # With a namespace configured, the Redis key is "<namespace>:<token>".
        result = self._lookup(
            mock_redis,
            b"remote_host:remote_port",
            src="127.0.0.1:1234:::namespace",
            expected_key="namespace:testhost",
        )
        self.assertRemoteTarget(result)

    def test_src_only_host(self):
        self._assert_src("127.0.0.1")

    def test_src_with_host_port(self):
        self._assert_src("127.0.0.1:1234", port=1234)

    def test_src_with_host_port_db(self):
        self._assert_src("127.0.0.1:1234:2", port=1234, db=2)

    def test_src_with_host_port_db_pass(self):
        self._assert_src(
            "127.0.0.1:1234:2:verysecret", port=1234, db=2, password="verysecret"
        )

    def test_src_with_host_port_db_pass_namespace(self):
        self._assert_src(
            "127.0.0.1:1234:2:verysecret:namespace",
            port=1234,
            db=2,
            password="verysecret",
            namespace="namespace:",
        )

    def test_src_with_host_empty_port_empty_db_pass_no_namespace(self):
        self._assert_src("127.0.0.1:::verysecret", password="verysecret")

    def test_src_with_host_empty_port_empty_db_empty_pass_empty_namespace(self):
        self._assert_src("127.0.0.1::::")

    def test_src_with_host_empty_port_empty_db_empty_pass_no_namespace(self):
        self._assert_src("127.0.0.1:::")

    def test_src_with_host_empty_port_empty_db_no_pass_no_namespace(self):
        self._assert_src("127.0.0.1::")

    def test_src_with_host_empty_port_no_db_no_pass_no_namespace(self):
        self._assert_src("127.0.0.1:")

    def test_src_with_host_empty_port_empty_db_empty_pass_namespace(self):
        self._assert_src("127.0.0.1::::namespace", namespace="namespace:")

    def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self):
        # A quoted namespace may itself contain ':' separators.
        self._assert_src('127.0.0.1::::"ns1:ns2"', namespace="ns1:ns2:")

    def test_src_with_host_empty_port_db_no_pass_no_namespace(self):
        self._assert_src("127.0.0.1::2", db=2)

    def test_src_with_host_port_empty_db_pass_no_namespace(self):
        self._assert_src("127.0.0.1:1234::verysecret", port=1234, password="verysecret")

    def test_src_with_host_empty_port_db_pass_no_namespace(self):
        self._assert_src("127.0.0.1::2:verysecret", db=2, password="verysecret")

    def test_src_with_host_empty_port_db_empty_pass_no_namespace(self):
        self._assert_src("127.0.0.1::2:", db=2)