This commit is contained in:
Alexander Zeijlon 2025-06-23 15:55:38 +02:00
commit c44e432e14
2 changed files with 53 additions and 48 deletions

View File

@ -4,7 +4,7 @@
import sys
import unittest
from unittest.mock import patch, mock_open, MagicMock
from unittest.mock import patch, MagicMock
from jwcrypto import jwt, jwk
from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis
@ -32,49 +32,50 @@ class ParseSourceArgumentsTestCase(unittest.TestCase):
self.assertEqual(args, parse_source_args(src))
class ReadOnlyTokenFileTestCase(unittest.TestCase):
patch('os.path.isdir', MagicMock(return_value=False))
def test_empty(self):
mock_source_file = MagicMock()
mock_source_file.is_dir.return_value = False
mock_source_file.open.return_value.__enter__.return_value.readlines.return_value = [""]
with patch("websockify.token_plugins.Path") as mock_path:
mock_path.return_value = mock_source_file
plugin = ReadOnlyTokenFile('configfile')
config = ""
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost')
pyopen.assert_called_once_with('configfile')
mock_path.assert_called_once_with('configfile')
self.assertIsNone(result)
patch('os.path.isdir', MagicMock(return_value=False))
def test_simple(self):
mock_source_file = MagicMock()
mock_source_file.is_dir.return_value = False
mock_source_file.open.return_value.__enter__.return_value.readlines.return_value = ["testhost: remote_host:remote_port"]
with patch("websockify.token_plugins.Path") as mock_path:
mock_path.return_value = mock_source_file
plugin = ReadOnlyTokenFile('configfile')
config = "testhost: remote_host:remote_port"
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost')
pyopen.assert_called_once_with('configfile')
mock_path.assert_called_once_with('configfile')
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
patch('os.path.isdir', MagicMock(return_value=False))
def test_tabs(self):
mock_source_file = MagicMock()
mock_source_file.is_dir.return_value = False
mock_source_file.open.return_value.__enter__.return_value.readlines.return_value = ["testhost:\tremote_host:remote_port"]
with patch("websockify.token_plugins.Path") as mock_path:
mock_path.return_value = mock_source_file
plugin = ReadOnlyTokenFile('configfile')
config = "testhost:\tremote_host:remote_port"
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost')
pyopen.assert_called_once_with('configfile')
mock_path.assert_called_once_with('configfile')
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
class JWSTokenTestCase(unittest.TestCase):
def test_asymmetric_jws_token_plugin(self):
plugin = JWTTokenApi("./tests/fixtures/public.pem")

View File

@ -1,9 +1,9 @@
import logging
import os
import sys
import time
import re
import json
from pathlib import Path
logger = logging.getLogger(__name__)
@ -43,16 +43,17 @@ class ReadOnlyTokenFile(BasePlugin):
self._targets = None
def _load_targets(self):
if os.path.isdir(self.source):
cfg_files = [os.path.join(self.source, f) for
f in os.listdir(self.source)]
source = Path(self.source)
if source.is_dir():
cfg_files = [file for file in source if file.is_file()]
else:
cfg_files = [self.source]
cfg_files = [source]
self._targets = {}
index = 1
for f in cfg_files:
for line in [l.strip() for l in open(f).readlines()]:
with f.open() as file:
for line in file.readlines():
if line and not line.startswith('#'):
try:
tok, target = re.split(r':\s', line)
@ -89,14 +90,16 @@ class TokenFileName(BasePlugin):
# contents of file is host:port
def __init__(self, src):
super().__init__(src)
if not os.path.isdir(src):
if not Path(src).is_dir():
raise Exception("TokenFileName plugin requires a directory")
def lookup(self, token):
token = os.path.basename(token)
path = os.path.join(self.source, token)
if os.path.exists(path):
return open(path).read().strip().split(':')
token = Path(token).name
path = Path(self.source) / token
if path.exists():
with path.open() as f:
text = f.read().strip().split(':')
return text
else:
return None
@ -349,23 +352,24 @@ class TokenRedis(BasePlugin):
class UnixDomainSocketDirectory(BasePlugin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._dir_path = os.path.abspath(self.source)
self._dir_path = Path(self.source).absolute()
def lookup(self, token):
def lookup(self, token: str):
try:
import stat
if not os.path.isdir(self._dir_path):
if not self._dir_path.is_dir():
return None
uds_path = os.path.abspath(os.path.join(self._dir_path, token))
if not uds_path.startswith(self._dir_path):
uds_path = (self._dir_path / token).absolute()
if not str(uds_path).startswith(str(self._dir_path)):
return None
if not os.path.exists(uds_path):
if not uds_path.exists():
return None
if not stat.S_ISSOCK(os.stat(uds_path).st_mode):
if not stat.S_ISSOCK(uds_path.stat().st_mode):
return None
return [ 'unix_socket', uds_path ]