diff --git a/tests/test_token_plugins.py b/tests/test_token_plugins.py index a9fd256..e1b967b 100644 --- a/tests/test_token_plugins.py +++ b/tests/test_token_plugins.py @@ -4,7 +4,7 @@ import sys import unittest -from unittest.mock import patch, mock_open, MagicMock +from unittest.mock import patch, MagicMock from jwcrypto import jwt, jwk from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis @@ -32,49 +32,50 @@ class ParseSourceArgumentsTestCase(unittest.TestCase): self.assertEqual(args, parse_source_args(src)) class ReadOnlyTokenFileTestCase(unittest.TestCase): - patch('os.path.isdir', MagicMock(return_value=False)) def test_empty(self): - plugin = ReadOnlyTokenFile('configfile') + mock_source_file = MagicMock() + mock_source_file.is_dir.return_value = False + mock_source_file.open.return_value.__enter__.return_value.readlines.return_value = [""] - config = "" - pyopen = mock_open(read_data=config) - - with patch("websockify.token_plugins.open", pyopen, create=True): + with patch("websockify.token_plugins.Path") as mock_path: + mock_path.return_value = mock_source_file + plugin = ReadOnlyTokenFile('configfile') result = plugin.lookup('testhost') - pyopen.assert_called_once_with('configfile') + mock_path.assert_called_once_with('configfile') self.assertIsNone(result) - patch('os.path.isdir', MagicMock(return_value=False)) def test_simple(self): - plugin = ReadOnlyTokenFile('configfile') + mock_source_file = MagicMock() + mock_source_file.is_dir.return_value = False + mock_source_file.open.return_value.__enter__.return_value.readlines.return_value = ["testhost: remote_host:remote_port"] - config = "testhost: remote_host:remote_port" - pyopen = mock_open(read_data=config) - - with patch("websockify.token_plugins.open", pyopen, create=True): + with patch("websockify.token_plugins.Path") as mock_path: + mock_path.return_value = mock_source_file + plugin = ReadOnlyTokenFile('configfile') result = plugin.lookup('testhost') - pyopen.assert_called_once_with('configfile') + mock_path.assert_called_once_with('configfile') self.assertIsNotNone(result) self.assertEqual(result[0], "remote_host") self.assertEqual(result[1], "remote_port") - patch('os.path.isdir', MagicMock(return_value=False)) def test_tabs(self): - plugin = ReadOnlyTokenFile('configfile') + mock_source_file = MagicMock() + mock_source_file.is_dir.return_value = False + mock_source_file.open.return_value.__enter__.return_value.readlines.return_value = ["testhost:\tremote_host:remote_port"] - config = "testhost:\tremote_host:remote_port" - pyopen = mock_open(read_data=config) - - with patch("websockify.token_plugins.open", pyopen, create=True): + with patch("websockify.token_plugins.Path") as mock_path: + mock_path.return_value = mock_source_file + plugin = ReadOnlyTokenFile('configfile') result = plugin.lookup('testhost') - pyopen.assert_called_once_with('configfile') + mock_path.assert_called_once_with('configfile') self.assertIsNotNone(result) self.assertEqual(result[0], "remote_host") self.assertEqual(result[1], "remote_port") + class JWSTokenTestCase(unittest.TestCase): def test_asymmetric_jws_token_plugin(self): plugin = JWTTokenApi("./tests/fixtures/public.pem") diff --git a/websockify/token_plugins.py b/websockify/token_plugins.py index 5a95490..d582032 100644 --- a/websockify/token_plugins.py +++ b/websockify/token_plugins.py @@ -1,9 +1,9 @@ import logging -import os import sys import time import re import json +from pathlib import Path logger = logging.getLogger(__name__) @@ -43,23 +43,24 @@ class ReadOnlyTokenFile(BasePlugin): self._targets = None def _load_targets(self): - if os.path.isdir(self.source): - cfg_files = [os.path.join(self.source, f) for - f in os.listdir(self.source)] + source = Path(self.source) + if source.is_dir(): + cfg_files = [file for file in source if file.is_file()] else: - cfg_files = [self.source] + cfg_files = [source] self._targets = {} index = 1 for f in cfg_files: - for line in [l.strip() for l in open(f).readlines()]: - if line and not line.startswith('#'): - try: - tok, target = re.split(r':\s', line) - self._targets[tok] = target.strip().rsplit(':', 1) - except ValueError: - logger.error("Syntax error in %s on line %d" % (self.source, index)) - index += 1 + with f.open() as file: + for line in file.readlines(): + if line and not line.startswith('#'): + try: + tok, target = re.split(r':\s', line) + self._targets[tok] = target.strip().rsplit(':', 1) + except ValueError: + logger.error("Syntax error in %s on line %d" % (self.source, index)) + index += 1 def lookup(self, token): if self._targets is None: @@ -89,14 +90,16 @@ class TokenFileName(BasePlugin): # contents of file is host:port def __init__(self, src): super().__init__(src) - if not os.path.isdir(src): + if not Path(src).is_dir(): raise Exception("TokenFileName plugin requires a directory") - + def lookup(self, token): - token = os.path.basename(token) - path = os.path.join(self.source, token) - if os.path.exists(path): - return open(path).read().strip().split(':') + token = Path(token).name + path = Path(self.source) / token + if path.exists(): + with path.open() as f: + text = f.read().strip().split(':') + return text else: return None @@ -349,23 +352,24 @@ class TokenRedis(BasePlugin): class UnixDomainSocketDirectory(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._dir_path = os.path.abspath(self.source) + self._dir_path = Path(self.source).absolute() - def lookup(self, token): + def lookup(self, token: str): try: import stat - if not os.path.isdir(self._dir_path): + if not self._dir_path.is_dir(): return None - uds_path = os.path.abspath(os.path.join(self._dir_path, token)) - if not uds_path.startswith(self._dir_path): + uds_path = (self._dir_path / token).absolute() + + if not str(uds_path).startswith(str(self._dir_path)): return None - if not os.path.exists(uds_path): + if not uds_path.exists(): return None - if not stat.S_ISSOCK(os.stat(uds_path).st_mode): + if not stat.S_ISSOCK(uds_path.stat().st_mode): return None return [ 'unix_socket', uds_path ]