diff --git a/tests/test_token_plugins.py b/tests/test_token_plugins.py index a63e11b..a9fd256 100644 --- a/tests/test_token_plugins.py +++ b/tests/test_token_plugins.py @@ -7,7 +7,29 @@ import unittest from unittest.mock import patch, mock_open, MagicMock from jwcrypto import jwt, jwk -from websockify.token_plugins import ReadOnlyTokenFile, JWTTokenApi, TokenRedis +from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis + +class ParseSourceArgumentsTestCase(unittest.TestCase): + def test_parameterized(self): + params = [ + ('', ['']), + (':', ['', '']), + ('::', ['', '', '']), + ('"', ['"']), + ('""', ['""']), + ('"""', ['"""']), + ('"localhost"', ['localhost']), + ('"localhost":', ['localhost', '']), + ('"localhost"::', ['localhost', '', '']), + ('"local:host"', ['local:host']), + ('"local:host:"pass"', ['"local', 'host', "pass"]), + ('"local":"host"', ['local', 'host']), + ('"local":host"', ['local', 'host"']), + ('localhost:6379:1:pass"word:"my-app-namespace:dev"', + ['localhost', '6379', '1', 'pass"word', 'my-app-namespace:dev']), + ] + for src, args in params: + self.assertEqual(args, parse_source_args(src)) class ReadOnlyTokenFileTestCase(unittest.TestCase): patch('os.path.isdir', MagicMock(return_value=False)) @@ -402,6 +424,15 @@ class TokenRedisTestCase(unittest.TestCase): self.assertEqual(plugin._password, None) self.assertEqual(plugin._namespace, "namespace:") + def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self): + plugin = TokenRedis('127.0.0.1::::"ns1:ns2"') + + self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._port, 6379) + self.assertEqual(plugin._db, 0) + self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "ns1:ns2:") + def test_src_with_host_empty_port_db_no_pass_no_namespace(self): plugin = TokenRedis('127.0.0.1::2') diff --git a/websockify/token_plugins.py b/websockify/token_plugins.py index bb9abe5..ac4a88f 100644 --- a/websockify/token_plugins.py +++ b/websockify/token_plugins.py @@ -7,6 +7,24 @@ import json logger = logging.getLogger(__name__) +_SOURCE_SPLIT_REGEX = re.compile( + r'(?<=^)"([^"]+)"(?=:|$)' + r'|(?<=:)"([^"]+)"(?=:|$)' + r'|(?<=^)([^:]*)(?=:|$)' + r'|(?<=:)([^:]*)(?=:|$)', +) + + +def parse_source_args(src): + """It works like src.split(":") but with the ability to use a colon + if you wrap the word in quotation marks. + + a:b:c:d -> ['a', 'b', 'c', 'd' + a:"b:c":c -> ['a', 'b:c', 'd'] + """ + matches = _SOURCE_SPLIT_REGEX.findall(src) + return [m[0] or m[1] or m[2] or m[3] for m in matches] + class BasePlugin(): def __init__(self, src): @@ -197,6 +215,10 @@ class TokenRedis(BasePlugin): my-redis-host::::my-app-namespace + Or if your namespace is nested, you can wrap it in quotes: + + my-redis-host::::"first-ns:second-ns" + In the more general case you will use: my-redis-host:6380:1:verysecretpass:my-app-namespace @@ -241,7 +263,7 @@ class TokenRedis(BasePlugin): self._password = None self._namespace = "" try: - fields = src.split(":") + fields = parse_source_args(src) if len(fields) == 1: self._server = fields[0] elif len(fields) == 2: