This commit is contained in:
Pierre Ossman 2024-08-29 16:35:49 +02:00
commit 417210f2cf
2 changed files with 163 additions and 16 deletions

View File

@ -7,7 +7,29 @@ import unittest
from unittest.mock import patch, mock_open, MagicMock
from jwcrypto import jwt, jwk
from websockify.token_plugins import ReadOnlyTokenFile, JWTTokenApi, TokenRedis
from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis
class ParseSourceArgumentsTestCase(unittest.TestCase):
def test_parameterized(self):
params = [
('', ['']),
(':', ['', '']),
('::', ['', '', '']),
('"', ['"']),
('""', ['""']),
('"""', ['"""']),
('"localhost"', ['localhost']),
('"localhost":', ['localhost', '']),
('"localhost"::', ['localhost', '', '']),
('"local:host"', ['local:host']),
('"local:host:"pass"', ['"local', 'host', "pass"]),
('"local":"host"', ['local', 'host']),
('"local":host"', ['local', 'host"']),
('localhost:6379:1:pass"word:"my-app-namespace:dev"',
['localhost', '6379', '1', 'pass"word', 'my-app-namespace:dev']),
]
for src, args in params:
self.assertEqual(args, parse_source_args(src))
class ReadOnlyTokenFileTestCase(unittest.TestCase):
patch('os.path.isdir', MagicMock(return_value=False))
@ -267,6 +289,42 @@ class TokenRedisTestCase(unittest.TestCase):
instance.get.assert_called_once_with('testhost')
self.assertIsNone(result)
@patch('redis.Redis')
def test_token_without_namespace(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
token = 'testhost'
def mock_redis_get(key):
self.assertEqual(key, token)
return b'remote_host:remote_port'
instance = mock_redis.return_value
instance.get = mock_redis_get
result = plugin.lookup(token)
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
@patch('redis.Redis')
def test_token_with_namespace(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234:::namespace')
token = 'testhost'
def mock_redis_get(key):
self.assertEqual(key, "namespace:" + token)
return b'remote_host:remote_port'
instance = mock_redis.return_value
instance.get = mock_redis_get
result = plugin.lookup(token)
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
def test_src_only_host(self):
plugin = TokenRedis('127.0.0.1')
@ -274,6 +332,7 @@ class TokenRedisTestCase(unittest.TestCase):
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port(self):
plugin = TokenRedis('127.0.0.1:1234')
@ -282,6 +341,7 @@ class TokenRedisTestCase(unittest.TestCase):
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db(self):
plugin = TokenRedis('127.0.0.1:1234:2')
@ -290,6 +350,7 @@ class TokenRedisTestCase(unittest.TestCase):
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db_pass(self):
plugin = TokenRedis('127.0.0.1:1234:2:verysecret')
@ -298,67 +359,112 @@ class TokenRedisTestCase(unittest.TestCase):
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_pass(self):
def test_src_with_host_port_db_pass_namespace(self):
plugin = TokenRedis('127.0.0.1:1234:2:verysecret:namespace')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._namespace, "namespace:")
def test_src_with_host_empty_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:::verysecret')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass(self):
def test_src_with_host_empty_port_empty_db_empty_pass_empty_namespace(self):
plugin = TokenRedis('127.0.0.1::::')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:::')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_no_pass(self):
def test_src_with_host_empty_port_empty_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_no_db_no_pass(self):
def test_src_with_host_empty_port_no_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_db_no_pass(self):
def test_src_with_host_empty_port_empty_db_empty_pass_namespace(self):
plugin = TokenRedis('127.0.0.1::::namespace')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "namespace:")
def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self):
plugin = TokenRedis('127.0.0.1::::"ns1:ns2"')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "ns1:ns2:")
def test_src_with_host_empty_port_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_empty_db_pass(self):
def test_src_with_host_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:1234::verysecret')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_db_pass(self):
def test_src_with_host_empty_port_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2:verysecret')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_db_empty_pass(self):
def test_src_with_host_empty_port_db_empty_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2:')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")

View File

@ -7,6 +7,24 @@ import json
logger = logging.getLogger(__name__)
_SOURCE_SPLIT_REGEX = re.compile(
r'(?<=^)"([^"]+)"(?=:|$)'
r'|(?<=:)"([^"]+)"(?=:|$)'
r'|(?<=^)([^:]*)(?=:|$)'
r'|(?<=:)([^:]*)(?=:|$)',
)
def parse_source_args(src):
"""It works like src.split(":") but with the ability to use a colon
if you wrap the word in quotation marks.
a:b:c:d -> ['a', 'b', 'c', 'd'
a:"b:c":c -> ['a', 'b:c', 'd']
"""
matches = _SOURCE_SPLIT_REGEX.findall(src)
return [m[0] or m[1] or m[2] or m[3] for m in matches]
class BasePlugin():
def __init__(self, src):
@ -178,9 +196,9 @@ class TokenRedis(BasePlugin):
The token source is in the format:
host[:port[:db[:password]]]
host[:port[:db[:password[:namespace]]]]
where port, db and password are optional. If port or db are left empty
where port, db, password and namespace are optional. If port or db are left empty
they will take its default value, ie. 6379 and 0 respectively.
If your redis server is using the default port (6379) then you can use:
@ -192,9 +210,18 @@ class TokenRedis(BasePlugin):
my-redis-host:::verysecretpass
You can also specify a namespace. In this case, the tokens
will be stored in the format '{namespace}:{token}'
my-redis-host::::my-app-namespace
Or if your namespace is nested, you can wrap it in quotes:
my-redis-host::::"first-ns:second-ns"
In the more general case you will use:
my-redis-host:6380:1:verysecretpass
my-redis-host:6380:1:verysecretpass:my-app-namespace
The TokenRedis plugin expects the format of the target in one of these two
formats:
@ -234,8 +261,9 @@ class TokenRedis(BasePlugin):
self._port = 6379
self._db = 0
self._password = None
self._namespace = ""
try:
fields = src.split(":")
fields = parse_source_args(src)
if len(fields) == 1:
self._server = fields[0]
elif len(fields) == 2:
@ -256,15 +284,28 @@ class TokenRedis(BasePlugin):
self._db = 0
if not self._password:
self._password = None
elif len(fields) == 5:
self._server, self._port, self._db, self._password, self._namespace = fields
if not self._port:
self._port = 6379
if not self._db:
self._db = 0
if not self._password:
self._password = None
if not self._namespace:
self._namespace = ""
else:
raise ValueError
self._port = int(self._port)
self._db = int(self._db)
logger.info("TokenRedis backend initilized (%s:%s)" %
if self._namespace:
self._namespace += ":"
logger.info("TokenRedis backend initialized (%s:%s)" %
(self._server, self._port))
except ValueError:
logger.error("The provided --token-source='%s' is not in the "
"expected format <host>[:<port>[:<db>[:<password>]]]" %
"expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]" %
src)
sys.exit()
@ -278,7 +319,7 @@ class TokenRedis(BasePlugin):
logger.info("resolving token '%s'" % token)
client = redis.Redis(host=self._server, port=self._port,
db=self._db, password=self._password)
stuff = client.get(token)
stuff = client.get(self._namespace + token)
if stuff is None:
return None
else: