run `ruff format .`

This commit is contained in:
Doctor 2025-03-12 20:55:57 +03:00
parent 4fc7e52352
commit 3eca8ad195
20 changed files with 1488 additions and 1027 deletions

View File

@ -1,11 +1,11 @@
from setuptools import setup, find_packages
version = '0.13.0'
name = 'websockify'
long_description = open("README.md").read() + "\n" + \
open("CHANGES.txt").read() + "\n"
version = "0.13.0"
name = "websockify"
long_description = open("README.md").read() + "\n" + open("CHANGES.txt").read() + "\n"
setup(name=name,
setup(
name=name,
version=version,
description="Websockify.",
long_description=long_description,
@ -22,23 +22,23 @@ setup(name=name,
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
keywords='noVNC websockify',
license='LGPLv3',
keywords="noVNC websockify",
license="LGPLv3",
url="https://github.com/novnc/websockify",
author="Joel Martin",
author_email="github@martintribe.org",
packages=['websockify'],
packages=["websockify"],
include_package_data=True,
install_requires=[
'numpy', 'requests',
'jwcrypto',
'redis',
"numpy",
"requests",
"jwcrypto",
"redis",
],
zip_safe=False,
entry_points={
'console_scripts': [
'websockify = websockify.websocketproxy:websockify_init',
"console_scripts": [
"websockify = websockify.websocketproxy:websockify_init",
]
},
)

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
'''
"""
A WebSocket server that echos back whatever it receives from the client.
Copyright 2010 Joel Martin
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
@ -8,16 +8,19 @@ Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
"""
import os, sys, select, optparse, logging
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler
class WebSocketEcho(WebSockifyRequestHandler):
"""
WebSockets server that echos back whatever is received from the
client."""
buffer_size = 8096
def new_websocket_client(self):
@ -33,9 +36,11 @@ class WebSocketEcho(WebSockifyRequestHandler):
while True:
wlist = []
if cqueue or c_pend: wlist.append(self.request)
if cqueue or c_pend:
wlist.append(self.request)
ins, outs, excepts = select.select(rlist, wlist, [], 1)
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
if self.request in outs:
# Send queued target data to the client
@ -50,20 +55,27 @@ class WebSocketEcho(WebSockifyRequestHandler):
if closed:
break
if __name__ == '__main__':
if __name__ == "__main__":
parser = optparse.OptionParser(usage="%prog [options] listen_port")
parser.add_option("--verbose", "-v", action="store_true",
help="verbose messages and per frame traffic")
parser.add_option("--cert", default="self.pem",
help="SSL certificate file")
parser.add_option("--key", default=None,
help="SSL key file (if separate from cert)")
parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted connections")
parser.add_option(
"--verbose",
"-v",
action="store_true",
help="verbose messages and per frame traffic",
)
parser.add_option("--cert", default="self.pem", help="SSL certificate file")
parser.add_option(
"--key", default=None, help="SSL key file (if separate from cert)"
)
parser.add_option(
"--ssl-only", action="store_true", help="disallow non-encrypted connections"
)
(opts, args) = parser.parse_args()
try:
if len(args) != 1: raise ValueError
if len(args) != 1:
raise ValueError
opts.listen_port = int(args[0])
except ValueError:
parser.error("Invalid arguments")
@ -73,4 +85,3 @@ if __name__ == '__main__':
opts.web = "."
server = WebSockifyServer(WebSocketEcho, **opts.__dict__)
server.start_server()

View File

@ -6,8 +6,11 @@ import optparse
import select
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websocket import WebSocket, \
WebSocketWantReadError, WebSocketWantWriteError
from websockify.websocket import (
WebSocket,
WebSocketWantReadError,
WebSocketWantWriteError,
)
parser = optparse.OptionParser(usage="%prog URL")
(opts, args) = parser.parse_args()
@ -22,19 +25,23 @@ print("Connecting to %s..." % URL)
sock.connect(URL)
print("Connected.")
def send(msg):
while True:
try:
sock.sendmsg(msg)
break
except WebSocketWantReadError:
msg = ''
msg = ""
ins, outs, excepts = select.select([sock], [], [])
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
except WebSocketWantWriteError:
msg = ''
msg = ""
ins, outs, excepts = select.select([], [sock], [])
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
def read():
while True:
@ -42,10 +49,13 @@ def read():
return sock.recvmsg()
except WebSocketWantReadError:
ins, outs, excepts = select.select([sock], [], [])
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
except WebSocketWantWriteError:
ins, outs, excepts = select.select([], [sock], [])
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
counter = 1
while True:
@ -56,7 +66,8 @@ while True:
while True:
ins, outs, excepts = select.select([sock], [], [], 1.0)
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
if ins == []:
break

View File

@ -1,32 +1,32 @@
#!/usr/bin/env python
'''
"""
WebSocket server-side load test program. Sends and receives traffic
that has a random payload (length and content) that is checksummed and
given a sequence number. Any errors are reported and counted.
'''
"""
import sys, os, select, random, time, optparse, logging
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler
class WebSocketLoadServer(WebSockifyServer):
class WebSocketLoadServer(WebSockifyServer):
recv_cnt = 0
send_cnt = 0
def __init__(self, *args, **kwargs):
self.delay = kwargs.pop('delay')
self.delay = kwargs.pop("delay")
WebSockifyServer.__init__(self, *args, **kwargs)
class WebSocketLoad(WebSockifyRequestHandler):
max_packet_size = 10000
def new_websocket_client(self):
print "Prepopulating random array"
print("Prepopulating random array")
self.rand_array = []
for i in range(0, self.max_packet_size):
self.rand_array.append(random.randint(0, 9))
@ -37,7 +37,7 @@ class WebSocketLoad(WebSockifyRequestHandler):
self.responder(self.request)
print "accumulated errors:", self.errors
print("accumulated errors:", self.errors)
self.errors = 0
def responder(self, client):
@ -49,7 +49,8 @@ class WebSocketLoad(WebSockifyRequestHandler):
while True:
ins, outs, excepts = select.select(socks, socks, socks, 1)
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
if client in ins:
frames, closed = self.recv_frames()
@ -57,7 +58,7 @@ class WebSocketLoad(WebSockifyRequestHandler):
err = self.check(frames)
if err:
self.errors = self.errors + 1
print err
print(err)
if closed:
break
@ -85,12 +86,10 @@ class WebSocketLoad(WebSockifyRequestHandler):
return data
def check(self, frames):
err = ""
for data in frames:
if data.count('$') > 1:
if data.count("$") > 1:
raise Exception("Multiple parts within single packet")
if len(data) == 0:
self.traffic("_")
@ -101,12 +100,12 @@ class WebSocketLoad(WebSockifyRequestHandler):
continue
try:
cnt, length, chksum, nums = data[1:-1].split(':')
cnt, length, chksum, nums = data[1:-1].split(":")
cnt = int(cnt)
length = int(length)
chksum = int(chksum)
except ValueError:
print "\n<BOF>" + repr(data) + "<EOF>"
print("\n<BOF>" + repr(data) + "<EOF>")
err += "Invalid data format\n"
continue
@ -131,27 +130,37 @@ class WebSocketLoad(WebSockifyRequestHandler):
real_chksum += int(num)
if real_chksum != chksum:
err += "Expected checksum %d but real chksum is %d\n" % (chksum, real_chksum)
err += "Expected checksum %d but real chksum is %d\n" % (
chksum,
real_chksum,
)
return err
if __name__ == '__main__':
if __name__ == "__main__":
parser = optparse.OptionParser(usage="%prog [options] listen_port")
parser.add_option("--verbose", "-v", action="store_true",
help="verbose messages and per frame traffic")
parser.add_option("--cert", default="self.pem",
help="SSL certificate file")
parser.add_option("--key", default=None,
help="SSL key file (if separate from cert)")
parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted connections")
parser.add_option(
"--verbose",
"-v",
action="store_true",
help="verbose messages and per frame traffic",
)
parser.add_option("--cert", default="self.pem", help="SSL certificate file")
parser.add_option(
"--key", default=None, help="SSL key file (if separate from cert)"
)
parser.add_option(
"--ssl-only", action="store_true", help="disallow non-encrypted connections"
)
(opts, args) = parser.parse_args()
try:
if len(args) != 1: raise ValueError
if len(args) != 1:
raise ValueError
opts.listen_port = int(args[0])
if len(args) not in [1,2]: raise ValueError
if len(args) not in [1, 2]:
raise ValueError
opts.listen_port = int(args[0])
if len(args) == 2:
opts.delay = int(args[1])
@ -165,4 +174,3 @@ if __name__ == '__main__':
opts.web = "."
server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__)
server.start_server()

View File

@ -7,22 +7,27 @@ import unittest
class BasicHTTPAuthTestCase(unittest.TestCase):
def setUp(self):
self.plugin = BasicHTTPAuth('Aladdin:open sesame')
self.plugin = BasicHTTPAuth("Aladdin:open sesame")
def test_no_auth(self):
headers = {}
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234')
self.assertRaises(
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
)
def test_invalid_password(self):
headers = {'Authorization': 'Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0'}
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234')
headers = {"Authorization": "Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0"}
self.assertRaises(
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
)
def test_valid_password(self):
headers = {'Authorization': 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='}
self.plugin.authenticate(headers, 'localhost', '1234')
headers = {"Authorization": "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="}
self.plugin.authenticate(headers, "localhost", "1234")
def test_garbage_auth(self):
headers = {'Authorization': 'Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx'}
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234')
headers = {"Authorization": "Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx"}
self.assertRaises(
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
)

View File

@ -7,74 +7,87 @@ import unittest
from unittest.mock import patch, mock_open, MagicMock
from jwcrypto import jwt, jwk
from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis
from websockify.token_plugins import (
parse_source_args,
ReadOnlyTokenFile,
JWTTokenApi,
TokenRedis,
)
class ParseSourceArgumentsTestCase(unittest.TestCase):
def test_parameterized(self):
params = [
('', ['']),
(':', ['', '']),
('::', ['', '', '']),
("", [""]),
(":", ["", ""]),
("::", ["", "", ""]),
('"', ['"']),
('""', ['""']),
('"""', ['"""']),
('"localhost"', ['localhost']),
('"localhost":', ['localhost', '']),
('"localhost"::', ['localhost', '', '']),
('"local:host"', ['local:host']),
('"local:host:"pass"', ['"local', 'host', "pass"]),
('"local":"host"', ['local', 'host']),
('"local":host"', ['local', 'host"']),
('localhost:6379:1:pass"word:"my-app-namespace:dev"',
['localhost', '6379', '1', 'pass"word', 'my-app-namespace:dev']),
('"localhost"', ["localhost"]),
('"localhost":', ["localhost", ""]),
('"localhost"::', ["localhost", "", ""]),
('"local:host"', ["local:host"]),
('"local:host:"pass"', ['"local', "host", "pass"]),
('"local":"host"', ["local", "host"]),
('"local":host"', ["local", 'host"']),
(
'localhost:6379:1:pass"word:"my-app-namespace:dev"',
["localhost", "6379", "1", 'pass"word', "my-app-namespace:dev"],
),
]
for src, args in params:
self.assertEqual(args, parse_source_args(src))
class ReadOnlyTokenFileTestCase(unittest.TestCase):
patch('os.path.isdir', MagicMock(return_value=False))
patch("os.path.isdir", MagicMock(return_value=False))
def test_empty(self):
plugin = ReadOnlyTokenFile('configfile')
plugin = ReadOnlyTokenFile("configfile")
config = ""
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
pyopen.assert_called_once_with('configfile')
pyopen.assert_called_once_with("configfile")
self.assertIsNone(result)
patch('os.path.isdir', MagicMock(return_value=False))
patch("os.path.isdir", MagicMock(return_value=False))
def test_simple(self):
plugin = ReadOnlyTokenFile('configfile')
plugin = ReadOnlyTokenFile("configfile")
config = "testhost: remote_host:remote_port"
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
pyopen.assert_called_once_with('configfile')
pyopen.assert_called_once_with("configfile")
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
patch('os.path.isdir', MagicMock(return_value=False))
patch("os.path.isdir", MagicMock(return_value=False))
def test_tabs(self):
plugin = ReadOnlyTokenFile('configfile')
plugin = ReadOnlyTokenFile("configfile")
config = "testhost:\tremote_host:remote_port"
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
pyopen.assert_called_once_with('configfile')
pyopen.assert_called_once_with("configfile")
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
class JWSTokenTestCase(unittest.TestCase):
def test_asymmetric_jws_token_plugin(self):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
@ -82,7 +95,9 @@ class JWSTokenTestCase(unittest.TestCase):
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"})
jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
@ -97,21 +112,26 @@ class JWSTokenTestCase(unittest.TestCase):
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"})
jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
self.assertIsNone(result)
@patch('time.time')
@patch("time.time")
def test_jwt_valid_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 })
jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key)
mock_time.return_value = 150
@ -121,14 +141,17 @@ class JWSTokenTestCase(unittest.TestCase):
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('time.time')
@patch("time.time")
def test_jwt_early_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 })
jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key)
mock_time.return_value = 50
@ -136,14 +159,17 @@ class JWSTokenTestCase(unittest.TestCase):
self.assertIsNone(result)
@patch('time.time')
@patch("time.time")
def test_jwt_late_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 })
jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key)
mock_time.return_value = 250
@ -157,7 +183,9 @@ class JWSTokenTestCase(unittest.TestCase):
secret = open("./tests/fixtures/symmetric.key").read()
key = jwk.JWK()
key.import_key(kty="oct", k=secret)
jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"})
jwt_token = jwt.JWT(
{"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
@ -172,7 +200,9 @@ class JWSTokenTestCase(unittest.TestCase):
secret = open("./tests/fixtures/symmetric.key").read()
key = jwk.JWK()
key.import_key(kty="oct", k=secret)
jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"})
jwt_token = jwt.JWT(
{"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
@ -188,10 +218,14 @@ class JWSTokenTestCase(unittest.TestCase):
public_key_data = open("./tests/fixtures/public.pem", "rb").read()
private_key.import_from_pem(private_key_data)
public_key.import_from_pem(public_key_data)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"})
jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(private_key)
jwe_token = jwt.JWT(header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"},
claims=jwt_token.serialize())
jwe_token = jwt.JWT(
header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"},
claims=jwt_token.serialize(),
)
jwe_token.make_encrypted_token(public_key)
result = plugin.lookup(jwt_token.serialize())
@ -200,103 +234,104 @@ class JWSTokenTestCase(unittest.TestCase):
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
class TokenRedisTestCase(unittest.TestCase):
def setUp(self):
try:
import redis
except ImportError:
patcher = patch.dict(sys.modules, {'redis': MagicMock()})
patcher = patch.dict(sys.modules, {"redis": MagicMock()})
patcher.start()
self.addCleanup(patcher.stop)
@patch('redis.Redis')
@patch("redis.Redis")
def test_empty(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = None
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNone(result)
@patch('redis.Redis')
@patch("redis.Redis")
def test_simple(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b'{"host": "remote_host:remote_port"}'
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('redis.Redis')
@patch("redis.Redis")
def test_json_token_with_spaces(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b' {"host": "remote_host:remote_port"} '
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('redis.Redis')
@patch("redis.Redis")
def test_text_token(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b'remote_host:remote_port'
instance.get.return_value = b"remote_host:remote_port"
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('redis.Redis')
@patch("redis.Redis")
def test_text_token_with_spaces(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b' remote_host:remote_port '
instance.get.return_value = b" remote_host:remote_port "
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('redis.Redis')
@patch("redis.Redis")
def test_invalid_token(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b'{"host": "remote_host:remote_port" '
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNone(result)
@patch('redis.Redis')
@patch("redis.Redis")
def test_token_without_namespace(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
token = 'testhost'
plugin = TokenRedis("127.0.0.1:1234")
token = "testhost"
def mock_redis_get(key):
self.assertEqual(key, token)
return b'remote_host:remote_port'
return b"remote_host:remote_port"
instance = mock_redis.return_value
instance.get = mock_redis_get
@ -304,17 +339,17 @@ class TokenRedisTestCase(unittest.TestCase):
result = plugin.lookup(token)
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('redis.Redis')
@patch("redis.Redis")
def test_token_with_namespace(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234:::namespace')
token = 'testhost'
plugin = TokenRedis("127.0.0.1:1234:::namespace")
token = "testhost"
def mock_redis_get(key):
self.assertEqual(key, "namespace:" + token)
return b'remote_host:remote_port'
return b"remote_host:remote_port"
instance = mock_redis.return_value
instance.get = mock_redis_get
@ -322,103 +357,103 @@ class TokenRedisTestCase(unittest.TestCase):
result = plugin.lookup(token)
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
def test_src_only_host(self):
plugin = TokenRedis('127.0.0.1')
plugin = TokenRedis("127.0.0.1")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port(self):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db(self):
plugin = TokenRedis('127.0.0.1:1234:2')
plugin = TokenRedis("127.0.0.1:1234:2")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db_pass(self):
plugin = TokenRedis('127.0.0.1:1234:2:verysecret')
plugin = TokenRedis("127.0.0.1:1234:2:verysecret")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db_pass_namespace(self):
plugin = TokenRedis('127.0.0.1:1234:2:verysecret:namespace')
plugin = TokenRedis("127.0.0.1:1234:2:verysecret:namespace")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "namespace:")
def test_src_with_host_empty_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:::verysecret')
plugin = TokenRedis("127.0.0.1:::verysecret")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_empty_namespace(self):
plugin = TokenRedis('127.0.0.1::::')
plugin = TokenRedis("127.0.0.1::::")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:::')
plugin = TokenRedis("127.0.0.1:::")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::')
plugin = TokenRedis("127.0.0.1::")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_no_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:')
plugin = TokenRedis("127.0.0.1:")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_namespace(self):
plugin = TokenRedis('127.0.0.1::::namespace')
plugin = TokenRedis("127.0.0.1::::namespace")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
@ -427,43 +462,43 @@ class TokenRedisTestCase(unittest.TestCase):
def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self):
plugin = TokenRedis('127.0.0.1::::"ns1:ns2"')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "ns1:ns2:")
def test_src_with_host_empty_port_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2')
plugin = TokenRedis("127.0.0.1::2")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:1234::verysecret')
plugin = TokenRedis("127.0.0.1:1234::verysecret")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2:verysecret')
plugin = TokenRedis("127.0.0.1::2:verysecret")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_db_empty_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2:')
plugin = TokenRedis("127.0.0.1::2:")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)

View File

@ -15,198 +15,267 @@
# under the License.
"""Unit tests for websocket"""
import unittest
from websockify import websocket
class FakeSocket:
def __init__(self):
self.data = b''
self.data = b""
def send(self, buf):
self.data += buf
return len(buf)
class AcceptTestCase(unittest.TestCase):
def test_success(self):
ws = websocket.WebSocket()
sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
self.assertTrue(b'\r\nUpgrade: websocket\r\n' in sock.data)
self.assertTrue(b'\r\nConnection: Upgrade\r\n' in sock.data)
self.assertTrue(b'\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n' in sock.data)
ws.accept(
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
self.assertTrue(b"\r\nUpgrade: websocket\r\n" in sock.data)
self.assertTrue(b"\r\nConnection: Upgrade\r\n" in sock.data)
self.assertTrue(
b"\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n" in sock.data
)
def test_bad_version(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '5',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '20',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(
Exception,
ws.accept,
sock,
{"upgrade": "websocket", "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q=="},
)
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "5",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "20",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
def test_bad_upgrade(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket2',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(
Exception,
ws.accept,
sock,
{
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket2",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
def test_missing_key(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13'})
self.assertRaises(
Exception,
ws.accept,
sock,
{"upgrade": "websocket", "Sec-WebSocket-Version": "13"},
)
def test_protocol(self):
class ProtoSocket(websocket.WebSocket):
def select_subprotocol(self, protocol):
return 'gazonk'
return "gazonk"
ws = ProtoSocket()
sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'})
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
self.assertTrue(b'\r\nSec-WebSocket-Protocol: gazonk\r\n' in sock.data)
ws.accept(
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
"Sec-WebSocket-Protocol": "foobar gazonk",
},
)
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
self.assertTrue(b"\r\nSec-WebSocket-Protocol: gazonk\r\n" in sock.data)
def test_no_protocol(self):
ws = websocket.WebSocket()
sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
self.assertFalse(b'\r\nSec-WebSocket-Protocol:' in sock.data)
ws.accept(
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
self.assertFalse(b"\r\nSec-WebSocket-Protocol:" in sock.data)
def test_missing_protocol(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'})
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
"Sec-WebSocket-Protocol": "foobar gazonk",
},
)
def test_protocol(self):
class ProtoSocket(websocket.WebSocket):
def select_subprotocol(self, protocol):
return 'oddball'
return "oddball"
ws = ProtoSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'})
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
"Sec-WebSocket-Protocol": "foobar gazonk",
},
)
class PingPongTest(unittest.TestCase):
def setUp(self):
self.ws = websocket.WebSocket()
self.sock = FakeSocket()
self.ws.accept(self.sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertEqual(self.sock.data[:13], b'HTTP/1.1 101 ')
self.sock.data = b''
self.ws.accept(
self.sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertEqual(self.sock.data[:13], b"HTTP/1.1 101 ")
self.sock.data = b""
def test_ping(self):
self.ws.ping()
self.assertEqual(self.sock.data, b'\x89\x00')
self.assertEqual(self.sock.data, b"\x89\x00")
def test_pong(self):
self.ws.pong()
self.assertEqual(self.sock.data, b'\x8a\x00')
self.assertEqual(self.sock.data, b"\x8a\x00")
def test_ping_data(self):
self.ws.ping(b'foo')
self.assertEqual(self.sock.data, b'\x89\x03foo')
self.ws.ping(b"foo")
self.assertEqual(self.sock.data, b"\x89\x03foo")
def test_pong_data(self):
self.ws.pong(b'foo')
self.assertEqual(self.sock.data, b'\x8a\x03foo')
self.ws.pong(b"foo")
self.assertEqual(self.sock.data, b"\x8a\x03foo")
class HyBiEncodeDecodeTestCase(unittest.TestCase):
def test_decode_hybi_text(self):
buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58'
buf = b"\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58"
ws = websocket.WebSocket()
res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x1)
self.assertEqual(res['masked'], True)
self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], b'Hello')
self.assertEqual(res["fin"], 1)
self.assertEqual(res["opcode"], 0x1)
self.assertEqual(res["masked"], True)
self.assertEqual(res["length"], len(buf))
self.assertEqual(res["payload"], b"Hello")
def test_decode_hybi_binary(self):
buf = b'\x82\x04\x01\x02\x03\x04'
buf = b"\x82\x04\x01\x02\x03\x04"
ws = websocket.WebSocket()
res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], b'\x01\x02\x03\x04')
self.assertEqual(res["fin"], 1)
self.assertEqual(res["opcode"], 0x2)
self.assertEqual(res["length"], len(buf))
self.assertEqual(res["payload"], b"\x01\x02\x03\x04")
def test_decode_hybi_extended_16bit_binary(self):
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
buf = b'\x82\x7e\x01\x04' + data
data = b"\x01\x02\x03\x04" * 65 # len > 126 -- len == 260
buf = b"\x82\x7e\x01\x04" + data
ws = websocket.WebSocket()
res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], data)
self.assertEqual(res["fin"], 1)
self.assertEqual(res["opcode"], 0x2)
self.assertEqual(res["length"], len(buf))
self.assertEqual(res["payload"], data)
def test_decode_hybi_extended_64bit_binary(self):
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
buf = b'\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04' + data
data = b"\x01\x02\x03\x04" * 65 # len > 126 -- len == 260
buf = b"\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04" + data
ws = websocket.WebSocket()
res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], data)
self.assertEqual(res["fin"], 1)
self.assertEqual(res["opcode"], 0x2)
self.assertEqual(res["length"], len(buf))
self.assertEqual(res["payload"], data)
def test_decode_hybi_multi(self):
buf1 = b'\x01\x03\x48\x65\x6c'
buf2 = b'\x80\x02\x6c\x6f'
buf1 = b"\x01\x03\x48\x65\x6c"
buf2 = b"\x80\x02\x6c\x6f"
ws = websocket.WebSocket()
res1 = ws._decode_hybi(buf1)
self.assertEqual(res1['fin'], 0)
self.assertEqual(res1['opcode'], 0x1)
self.assertEqual(res1['length'], len(buf1))
self.assertEqual(res1['payload'], b'Hel')
self.assertEqual(res1["fin"], 0)
self.assertEqual(res1["opcode"], 0x1)
self.assertEqual(res1["length"], len(buf1))
self.assertEqual(res1["payload"], b"Hel")
res2 = ws._decode_hybi(buf2)
self.assertEqual(res2['fin'], 1)
self.assertEqual(res2['opcode'], 0x0)
self.assertEqual(res2['length'], len(buf2))
self.assertEqual(res2['payload'], b'lo')
self.assertEqual(res2["fin"], 1)
self.assertEqual(res2["opcode"], 0x0)
self.assertEqual(res2["length"], len(buf2))
self.assertEqual(res2["payload"], b"lo")
def test_encode_hybi_basic(self):
ws = websocket.WebSocket()
res = ws._encode_hybi(0x1, b'Hello')
expected = b'\x81\x05\x48\x65\x6c\x6c\x6f'
res = ws._encode_hybi(0x1, b"Hello")
expected = b"\x81\x05\x48\x65\x6c\x6c\x6f"
self.assertEqual(res, expected)

View File

@ -30,7 +30,7 @@ from websockify import auth_plugins
class FakeSocket:
def __init__(self, data=b''):
def __init__(self, data=b""):
self._data = data
def recv(self, amt, flags=None):
@ -40,11 +40,11 @@ class FakeSocket:
return res
def makefile(self, mode='r', buffsize=None):
if 'b' in mode:
def makefile(self, mode="r", buffsize=None):
if "b" in mode:
return BytesIO(self._data)
else:
return StringIO(self._data.decode('latin_1'))
return StringIO(self._data.decode("latin_1"))
class FakeServer:
@ -58,14 +58,16 @@ class FakeServer:
self.ssl_target = None
self.unix_target = None
class ProxyRequestHandlerTestCase(unittest.TestCase):
def setUp(self):
super().setUp()
self.handler = websocketproxy.ProxyRequestHandler(
FakeSocket(), "127.0.0.1", FakeServer())
FakeSocket(), "127.0.0.1", FakeServer()
)
self.handler.path = "https://localhost:6080/websockify?token=blah"
self.handler.headers = {}
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
def tearDown(self):
patch.stopall()
@ -76,8 +78,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
def lookup(self, token):
return ("some host", "some port")
host, port = self.handler.get_target(
TestPlugin(None))
host, port = self.handler.get_target(TestPlugin(None))
self.assertEqual(host, "some host")
self.assertEqual(port, "some port")
@ -87,8 +88,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
def lookup(self, token):
return ("unix_socket", "/tmp/socket")
_, socket = self.handler.get_target(
TestPlugin(None))
_, socket = self.handler.get_target(TestPlugin(None))
self.assertEqual(socket, "/tmp/socket")
@ -100,11 +100,11 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
with self.assertRaises(FakeServer.EClose):
self.handler.get_target(TestPlugin(None))
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
@patch("websockify.websocketproxy.ProxyRequestHandler.send_auth_error", MagicMock())
def test_token_plugin(self):
class TestPlugin(token_plugins.BasePlugin):
def lookup(self, token):
return (self.source + token).split(',')
return (self.source + token).split(",")
self.handler.server.token_plugin = TestPlugin("somehost,")
self.handler.validate_connection()
@ -112,7 +112,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
self.assertEqual(self.handler.server.target_host, "somehost")
self.assertEqual(self.handler.server.target_port, "blah")
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
@patch("websockify.websocketproxy.ProxyRequestHandler.send_auth_error", MagicMock())
def test_auth_plugin(self):
class TestPlugin(auth_plugins.BasePlugin):
def authenticate(self, headers, target_host, target_port):
@ -128,4 +128,3 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
self.handler.server.target_host = "someotherhost"
self.handler.auth_connection()

View File

@ -1,5 +1,5 @@
"""Unit tests for websocketserver"""
import unittest
from unittest.mock import patch, MagicMock
@ -66,4 +66,3 @@ class HttpWebSocketTest(unittest.TestCase):
# Then
req_obj.end_headers.assert_called_once_with()

View File

@ -15,6 +15,7 @@
# under the License.
"""Unit tests for websockifyserver"""
import errno
import os
import logging
@ -36,11 +37,11 @@ from websockify import websockifyserver
def raise_oserror(*args, **kwargs):
raise OSError('fake error')
raise OSError("fake error")
class FakeSocket:
def __init__(self, data=b''):
def __init__(self, data=b""):
self._data = data
def recv(self, amt, flags=None):
@ -50,19 +51,19 @@ class FakeSocket:
return res
def makefile(self, mode='r', buffsize=None):
if 'b' in mode:
def makefile(self, mode="r", buffsize=None):
if "b" in mode:
return BytesIO(self._data)
else:
return StringIO(self._data.decode('latin_1'))
return StringIO(self._data.decode("latin_1"))
class WebSockifyRequestHandlerTestCase(unittest.TestCase):
def setUp(self):
super().setUp()
self.tmpdir = tempfile.mkdtemp('-websockify-tests')
self.tmpdir = tempfile.mkdtemp("-websockify-tests")
# Mock this out cause it screws tests up
patch('os.chdir').start()
patch("os.chdir").start()
def tearDown(self):
"""Called automatically after each test."""
@ -70,31 +71,41 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase):
os.rmdir(self.tmpdir)
super().tearDown()
def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
**kwargs):
web = kwargs.pop('web', self.tmpdir)
def _get_server(
self, handler_class=websockifyserver.WebSockifyRequestHandler, **kwargs
):
web = kwargs.pop("web", self.tmpdir)
return websockifyserver.WebSockifyServer(
handler_class, listen_host='localhost',
listen_port=80, key=self.tmpdir, web=web,
record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
**kwargs)
handler_class,
listen_host="localhost",
listen_port=80,
key=self.tmpdir,
web=web,
record=self.tmpdir,
daemon=False,
ssl_only=0,
idle_timeout=1,
**kwargs,
)
@patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error')
@patch("websockify.websockifyserver.WebSockifyRequestHandler.send_error")
def test_normal_get_with_only_upgrade_returns_error(self, send_error):
server = self._get_server(web=None)
handler = websockifyserver.WebSockifyRequestHandler(
FakeSocket(b'GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)
FakeSocket(b"GET /tmp.txt HTTP/1.1"), "127.0.0.1", server
)
handler.do_GET()
send_error.assert_called_with(405)
@patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error')
@patch("websockify.websockifyserver.WebSockifyRequestHandler.send_error")
def test_list_dir_with_file_only_returns_error(self, send_error):
server = self._get_server(file_only=True)
handler = websockifyserver.WebSockifyRequestHandler(
FakeSocket(b'GET / HTTP/1.1'), '127.0.0.1', server)
FakeSocket(b"GET / HTTP/1.1"), "127.0.0.1", server
)
handler.path = '/'
handler.path = "/"
handler.do_GET()
send_error.assert_called_with(404)
@ -102,9 +113,9 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase):
class WebSockifyServerTestCase(unittest.TestCase):
def setUp(self):
super().setUp()
self.tmpdir = tempfile.mkdtemp('-websockify-tests')
self.tmpdir = tempfile.mkdtemp("-websockify-tests")
# Mock this out cause it screws tests up
patch('os.chdir').start()
patch("os.chdir").start()
def tearDown(self):
"""Called automatically after each test."""
@ -112,32 +123,38 @@ class WebSockifyServerTestCase(unittest.TestCase):
os.rmdir(self.tmpdir)
super().tearDown()
def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
**kwargs):
def _get_server(
self, handler_class=websockifyserver.WebSockifyRequestHandler, **kwargs
):
return websockifyserver.WebSockifyServer(
handler_class, listen_host='localhost',
listen_port=80, key=self.tmpdir, web=self.tmpdir,
record=self.tmpdir, **kwargs)
handler_class,
listen_host="localhost",
listen_port=80,
key=self.tmpdir,
web=self.tmpdir,
record=self.tmpdir,
**kwargs,
)
def test_daemonize_raises_error_while_closing_fds(self):
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
patch('os.fork').start().return_value = 0
patch('signal.signal').start()
patch('os.setsid').start()
patch('os.close').start().side_effect = raise_oserror
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
patch("os.fork").start().return_value = 0
patch("signal.signal").start()
patch("os.setsid").start()
patch("os.close").start().side_effect = raise_oserror
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir="./")
def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
def raise_oserror_ebadf(fd):
raise OSError(errno.EBADF, 'fake error')
raise OSError(errno.EBADF, "fake error")
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
patch('os.fork').start().return_value = 0
patch('signal.signal').start()
patch('os.setsid').start()
patch('os.close').start().side_effect = raise_oserror_ebadf
patch('os.open').start().side_effect = raise_oserror
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
patch("os.fork").start().return_value = 0
patch("signal.signal").start()
patch("os.setsid").start()
patch("os.close").start().side_effect = raise_oserror_ebadf
patch("os.open").start().side_effect = raise_oserror
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir="./")
def test_handshake_fails_on_not_ready(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
@ -145,23 +162,29 @@ class WebSockifyServerTestCase(unittest.TestCase):
def fake_select(rlist, wlist, xlist, timeout=None):
return ([], [], [])
patch('select.select').start().side_effect = fake_select
patch("select.select").start().side_effect = fake_select
self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
FakeSocket(), '127.0.0.1')
websockifyserver.WebSockifyServer.EClose,
server.do_handshake,
FakeSocket(),
"127.0.0.1",
)
def test_empty_handshake_fails(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
sock = FakeSocket('')
sock = FakeSocket("")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
patch('select.select').start().side_effect = fake_select
patch("select.select").start().side_effect = fake_select
self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
sock, '127.0.0.1')
websockifyserver.WebSockifyServer.EClose,
server.do_handshake,
sock,
"127.0.0.1",
)
def test_handshake_policy_request(self):
# TODO(directxman12): implement
@ -170,35 +193,39 @@ class WebSockifyServerTestCase(unittest.TestCase):
def test_handshake_ssl_only_without_ssl_raises_error(self):
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
sock = FakeSocket(b'some initial data')
sock = FakeSocket(b"some initial data")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
patch('select.select').start().side_effect = fake_select
patch("select.select").start().side_effect = fake_select
self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
sock, '127.0.0.1')
websockifyserver.WebSockifyServer.EClose,
server.do_handshake,
sock,
"127.0.0.1",
)
def test_do_handshake_no_ssl(self):
class FakeHandler:
CALLED = False
def __init__(self, *args, **kwargs):
type(self).CALLED = True
FakeHandler.CALLED = False
server = self._get_server(
handler_class=FakeHandler, daemon=True,
ssl_only=0, idle_timeout=1)
handler_class=FakeHandler, daemon=True, ssl_only=0, idle_timeout=1
)
sock = FakeSocket(b'some initial data')
sock = FakeSocket(b"some initial data")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
patch('select.select').start().side_effect = fake_select
self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
patch("select.select").start().side_effect = fake_select
self.assertEqual(server.do_handshake(sock, "127.0.0.1"), sock)
self.assertTrue(FakeHandler.CALLED, True)
def test_do_handshake_ssl(self):
@ -210,18 +237,22 @@ class WebSockifyServerTestCase(unittest.TestCase):
pass
def test_do_handshake_ssl_without_cert_raises_error(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
cert='afdsfasdafdsafdsafdsafdas')
server = self._get_server(
daemon=True, ssl_only=0, idle_timeout=1, cert="afdsfasdafdsafdsafdsafdas"
)
sock = FakeSocket(b"\x16some ssl data")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
patch('select.select').start().side_effect = fake_select
patch("select.select").start().side_effect = fake_select
self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
sock, '127.0.0.1')
websockifyserver.WebSockifyServer.EClose,
server.do_handshake,
sock,
"127.0.0.1",
)
def test_do_handshake_ssl_error_eof_raises_close_error(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
@ -234,58 +265,79 @@ class WebSockifyServerTestCase(unittest.TestCase):
def fake_wrap_socket(*args, **kwargs):
raise ssl.SSLError(ssl.SSL_ERROR_EOF)
class fake_create_default_context():
class fake_create_default_context:
def __init__(self, purpose):
self.verify_mode = None
self.options = 0
def load_cert_chain(self, certfile, keyfile, password):
pass
def set_default_verify_paths(self):
pass
def load_verify_locations(self, cafile):
pass
def wrap_socket(self, *args, **kwargs):
raise ssl.SSLError(ssl.SSL_ERROR_EOF)
patch('select.select').start().side_effect = fake_select
patch('ssl.create_default_context').start().side_effect = fake_create_default_context
patch("select.select").start().side_effect = fake_select
patch(
"ssl.create_default_context"
).start().side_effect = fake_create_default_context
self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
sock, '127.0.0.1')
websockifyserver.WebSockifyServer.EClose,
server.do_handshake,
sock,
"127.0.0.1",
)
def test_do_handshake_ssl_sets_ciphers(self):
test_ciphers = 'TEST-CIPHERS-1:TEST-CIPHER-2'
test_ciphers = "TEST-CIPHERS-1:TEST-CIPHER-2"
class FakeHandler:
def __init__(self, *args, **kwargs):
pass
server = self._get_server(handler_class=FakeHandler, daemon=True,
idle_timeout=1, ssl_ciphers=test_ciphers)
server = self._get_server(
handler_class=FakeHandler,
daemon=True,
idle_timeout=1,
ssl_ciphers=test_ciphers,
)
sock = FakeSocket(b"\x16some ssl data")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
class fake_create_default_context():
CIPHERS = ''
class fake_create_default_context:
CIPHERS = ""
def __init__(self, purpose):
self.verify_mode = None
self.options = 0
def load_cert_chain(self, certfile, keyfile, password):
pass
def set_default_verify_paths(self):
pass
def load_verify_locations(self, cafile):
pass
def wrap_socket(self, *args, **kwargs):
pass
def set_ciphers(self, ciphers_to_set):
fake_create_default_context.CIPHERS = ciphers_to_set
patch('select.select').start().side_effect = fake_select
patch('ssl.create_default_context').start().side_effect = fake_create_default_context
server.do_handshake(sock, '127.0.0.1')
patch("select.select").start().side_effect = fake_select
patch(
"ssl.create_default_context"
).start().side_effect = fake_create_default_context
server.do_handshake(sock, "127.0.0.1")
self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers)
def test_do_handshake_ssl_sets_opions(self):
@ -295,8 +347,12 @@ class WebSockifyServerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
pass
server = self._get_server(handler_class=FakeHandler, daemon=True,
idle_timeout=1, ssl_options=test_options)
server = self._get_server(
handler_class=FakeHandler,
daemon=True,
idle_timeout=1,
ssl_options=test_options,
)
sock = FakeSocket(b"\x16some ssl data")
def fake_select(rlist, wlist, xlist, timeout=None):
@ -304,26 +360,36 @@ class WebSockifyServerTestCase(unittest.TestCase):
class fake_create_default_context:
OPTIONS = 0
def __init__(self, purpose):
self.verify_mode = None
self._options = 0
def load_cert_chain(self, certfile, keyfile, password):
pass
def set_default_verify_paths(self):
pass
def load_verify_locations(self, cafile):
pass
def wrap_socket(self, *args, **kwargs):
pass
def get_options(self):
return self._options
def set_options(self, val):
fake_create_default_context.OPTIONS = val
options = property(get_options, set_options)
patch('select.select').start().side_effect = fake_select
patch('ssl.create_default_context').start().side_effect = fake_create_default_context
server.do_handshake(sock, '127.0.0.1')
patch("select.select").start().side_effect = fake_select
patch(
"ssl.create_default_context"
).start().side_effect = fake_create_default_context
server.do_handshake(sock, "127.0.0.1")
self.assertEqual(fake_create_default_context.OPTIONS, test_options)
def test_fallback_sigchld_handler(self):
@ -332,38 +398,38 @@ class WebSockifyServerTestCase(unittest.TestCase):
def test_start_server_error(self):
server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
sock = server.socket('localhost')
sock = server.socket("localhost")
def fake_select(rlist, wlist, xlist, timeout=None):
raise Exception("fake error")
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
patch('select.select').start().side_effect = fake_select
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
patch("select.select").start().side_effect = fake_select
server.start_server()
def test_start_server_keyboardinterrupt(self):
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
sock = server.socket('localhost')
sock = server.socket("localhost")
def fake_select(rlist, wlist, xlist, timeout=None):
raise KeyboardInterrupt
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
patch('select.select').start().side_effect = fake_select
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
patch("select.select").start().side_effect = fake_select
server.start_server()
def test_start_server_systemexit(self):
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
sock = server.socket('localhost')
sock = server.socket("localhost")
def fake_select(rlist, wlist, xlist, timeout=None):
sys.exit()
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
patch('select.select').start().side_effect = fake_select
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
patch("select.select").start().side_effect = fake_select
server.start_server()
def test_socket_set_keepalive_options(self):
@ -372,29 +438,37 @@ class WebSockifyServerTestCase(unittest.TestCase):
keepintvl = 56
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
sock = server.socket('localhost',
sock = server.socket(
"localhost",
tcp_keepcnt=keepcnt,
tcp_keepidle=keepidle,
tcp_keepintvl=keepintvl)
tcp_keepintvl=keepintvl,
)
if hasattr(socket, 'TCP_KEEPCNT'):
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt)
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPIDLE), keepidle)
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPINTVL), keepintvl)
if hasattr(socket, "TCP_KEEPCNT"):
self.assertEqual(
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt
)
self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE), keepidle)
self.assertEqual(
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl
)
sock = server.socket('localhost',
sock = server.socket(
"localhost",
tcp_keepalive=False,
tcp_keepcnt=keepcnt,
tcp_keepidle=keepidle,
tcp_keepintvl=keepintvl)
tcp_keepintvl=keepintvl,
)
if hasattr(socket, 'TCP_KEEPCNT'):
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt)
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPIDLE), keepidle)
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPINTVL), keepintvl)
if hasattr(socket, "TCP_KEEPCNT"):
self.assertNotEqual(
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt
)
self.assertNotEqual(
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE), keepidle
)
self.assertNotEqual(
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl
)

View File

@ -1,4 +1,4 @@
import websockify
if __name__ == '__main__':
if __name__ == "__main__":
websockify.websocketproxy.websockify_init()

View File

@ -1,4 +1,4 @@
class BasePlugin():
class BasePlugin:
def __init__(self, src=None):
self.source = src
@ -7,7 +7,9 @@ class BasePlugin():
class AuthenticationError(Exception):
def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None):
def __init__(
self, log_msg=None, response_code=403, response_headers={}, response_msg=None
):
self.code = response_code
self.headers = response_headers
self.msg = response_msg
@ -15,7 +17,7 @@ class AuthenticationError(Exception):
if log_msg is None:
log_msg = response_msg
super().__init__('%s %s' % (self.code, log_msg))
super().__init__("%s %s" % (self.code, log_msg))
class InvalidOriginError(AuthenticationError):
@ -24,12 +26,13 @@ class InvalidOriginError(AuthenticationError):
self.actual_origin = actual
super().__init__(
response_msg='Invalid Origin',
response_msg="Invalid Origin",
log_msg="Invalid Origin Header: Expected one of "
"%s, got '%s'" % (expected, actual))
"%s, got '%s'" % (expected, actual),
)
class BasicHTTPAuth():
class BasicHTTPAuth:
"""Verifies Basic Auth headers. Specify src as username:password"""
def __init__(self, src=None):
@ -37,9 +40,10 @@ class BasicHTTPAuth():
def authenticate(self, headers, target_host, target_port):
import base64
auth_header = headers.get('Authorization')
auth_header = headers.get("Authorization")
if auth_header:
if not auth_header.startswith('Basic '):
if not auth_header.startswith("Basic "):
self.auth_error()
try:
@ -49,11 +53,11 @@ class BasicHTTPAuth():
try:
# http://stackoverflow.com/questions/7242316/what-encoding-should-i-use-for-http-basic-authentication
user_pass_as_text = user_pass_raw.decode('ISO-8859-1')
user_pass_as_text = user_pass_raw.decode("ISO-8859-1")
except UnicodeDecodeError:
self.auth_error()
user_pass = user_pass_as_text.split(':', 1)
user_pass = user_pass_as_text.split(":", 1)
if len(user_pass) != 2:
self.auth_error()
@ -64,7 +68,7 @@ class BasicHTTPAuth():
self.demand_auth()
def validate_creds(self, username, password):
if '%s:%s' % (username, password) == self.src:
if "%s:%s" % (username, password) == self.src:
return True
else:
return False
@ -73,10 +77,13 @@ class BasicHTTPAuth():
raise AuthenticationError(response_code=403)
def demand_auth(self):
raise AuthenticationError(response_code=401,
response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'})
raise AuthenticationError(
response_code=401,
response_headers={"WWW-Authenticate": 'Basic realm="Websockify"'},
)
class ExpectOrigin():
class ExpectOrigin:
def __init__(self, src=None):
if src is None:
self.source = []
@ -84,11 +91,12 @@ class ExpectOrigin():
self.source = src.split()
def authenticate(self, headers, target_host, target_port):
origin = headers.get('Origin', None)
origin = headers.get("Origin", None)
if origin is None or origin not in self.source:
raise InvalidOriginError(expected=self.source, actual=origin)
class ClientCertCNAuth():
class ClientCertCNAuth:
"""Verifies client by SSL certificate. Specify src as whitespace separated list of common names."""
def __init__(self, src=None):
@ -98,5 +106,5 @@ class ClientCertCNAuth():
self.source = src.split()
def authenticate(self, headers, target_host, target_port):
if headers.get('SSL_CLIENT_S_DN_CN', None) not in self.source:
if headers.get("SSL_CLIENT_S_DN_CN", None) not in self.source:
raise AuthenticationError(response_code=403)

View File

@ -7,23 +7,26 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
as defined by RFC 5424.
"""
_legacy_head_fmt = '<{pri}>{ident}[{pid}]: '
_rfc5424_head_fmt = '<{pri}>1 {timestamp} {hostname} {ident} {pid} - - '
_legacy_head_fmt = "<{pri}>{ident}[{pid}]: "
_rfc5424_head_fmt = "<{pri}>1 {timestamp} {hostname} {ident} {pid} - - "
_head_fmt = _rfc5424_head_fmt
_legacy = False
_timestamp_fmt = '%Y-%m-%dT%H:%M:%SZ'
_timestamp_fmt = "%Y-%m-%dT%H:%M:%SZ"
_max_hostname = 255
_max_ident = 24 # safer for old daemons
_send_length = False
_tail = '\n'
_tail = "\n"
ident = None
def __init__(self, address=('localhost', handlers.SYSLOG_UDP_PORT),
def __init__(
self,
address=("localhost", handlers.SYSLOG_UDP_PORT),
facility=handlers.SysLogHandler.LOG_USER,
socktype=None, ident=None, legacy=False):
socktype=None,
ident=None,
legacy=False,
):
"""
Initialize a handler.
@ -46,7 +49,6 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
super().__init__(address, facility, socktype)
def emit(self, record):
"""
Emit a record.
@ -57,46 +59,44 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
try:
# Gather info.
text = self.format(record).replace(self._tail, ' ')
text = self.format(record).replace(self._tail, " ")
if not text: # nothing to log
return
pri = self.encodePriority(self.facility,
self.mapPriority(record.levelname))
timestamp = time.strftime(self._timestamp_fmt, time.gmtime());
pri = self.encodePriority(self.facility, self.mapPriority(record.levelname))
timestamp = time.strftime(self._timestamp_fmt, time.gmtime())
hostname = socket.gethostname()[: self._max_hostname]
if self.ident:
ident = self.ident[: self._max_ident]
else:
ident = ''
ident = ""
pid = os.getpid() # shouldn't need truncation
# Format the header.
head = {
'pri': pri,
'timestamp': timestamp,
'hostname': hostname,
'ident': ident,
'pid': pid,
"pri": pri,
"timestamp": timestamp,
"hostname": hostname,
"ident": ident,
"pid": pid,
}
msg = self._head_fmt.format(**head).encode('ascii', 'ignore')
msg = self._head_fmt.format(**head).encode("ascii", "ignore")
# Encode text as plain ASCII if possible, else use UTF-8 with BOM.
try:
msg += text.encode('ascii')
msg += text.encode("ascii")
except UnicodeEncodeError:
msg += text.encode('utf-8-sig')
msg += text.encode("utf-8-sig")
# Add length or tail character, if necessary.
if self.socktype != socket.SOCK_DGRAM:
if self._send_length:
msg = ('%d ' % len(msg)).encode('ascii') + msg
msg = ("%d " % len(msg)).encode("ascii") + msg
else:
msg += self._tail.encode('ascii')
msg += self._tail.encode("ascii")
# Send the message.
if self.unixsocket:

View File

@ -10,8 +10,8 @@ logger = logging.getLogger(__name__)
_SOURCE_SPLIT_REGEX = re.compile(
r'(?<=^)"([^"]+)"(?=:|$)'
r'|(?<=:)"([^"]+)"(?=:|$)'
r'|(?<=^)([^:]*)(?=:|$)'
r'|(?<=:)([^:]*)(?=:|$)',
r"|(?<=^)([^:]*)(?=:|$)"
r"|(?<=:)([^:]*)(?=:|$)",
)
@ -26,7 +26,7 @@ def parse_source_args(src):
return [m[0] or m[1] or m[2] or m[3] for m in matches]
class BasePlugin():
class BasePlugin:
def __init__(self, src):
self.source = src
@ -44,8 +44,7 @@ class ReadOnlyTokenFile(BasePlugin):
def _load_targets(self):
if os.path.isdir(self.source):
cfg_files = [os.path.join(self.source, f) for
f in os.listdir(self.source)]
cfg_files = [os.path.join(self.source, f) for f in os.listdir(self.source)]
else:
cfg_files = [self.source]
@ -53,12 +52,14 @@ class ReadOnlyTokenFile(BasePlugin):
index = 1
for f in cfg_files:
for line in [l.strip() for l in open(f).readlines()]:
if line and not line.startswith('#'):
if line and not line.startswith("#"):
try:
tok, target = re.split(r':\s', line)
self._targets[tok] = target.strip().rsplit(':', 1)
tok, target = re.split(r":\s", line)
self._targets[tok] = target.strip().rsplit(":", 1)
except ValueError:
logger.error("Syntax error in %s on line %d" % (self.source, index))
logger.error(
"Syntax error in %s on line %d" % (self.source, index)
)
index += 1
def lookup(self, token):
@ -83,6 +84,7 @@ class TokenFile(ReadOnlyTokenFile):
return super().lookup(token)
class TokenFileName(BasePlugin):
# source is a directory
# token is filename
@ -96,7 +98,7 @@ class TokenFileName(BasePlugin):
token = os.path.basename(token)
path = os.path.join(self.source, token)
if os.path.exists(path):
return open(path).read().strip().split(':')
return open(path).read().strip().split(":")
else:
return None
@ -109,8 +111,8 @@ class BaseTokenAPI(BasePlugin):
# in this file can be used w/o unnecessary dependencies
def process_result(self, resp):
host, port = resp.text.split(':')
port = port.encode('ascii','ignore')
host, port = resp.text.split(":")
port = port.encode("ascii", "ignore")
return [host, port]
def lookup(self, token):
@ -130,7 +132,7 @@ class JSONTokenApi(BaseTokenAPI):
def process_result(self, resp):
resp_json = resp.json()
return (resp_json['host'], resp_json['port'])
return (resp_json["host"], resp_json["port"])
class JWTTokenApi(BasePlugin):
@ -145,7 +147,7 @@ class JWTTokenApi(BasePlugin):
key = jwk.JWK()
try:
with open(self.source, 'rb') as key_file:
with open(self.source, "rb") as key_file:
key_data = key_file.read()
except Exception as e:
logger.error("Error loading key file: %s" % str(e))
@ -155,39 +157,41 @@ class JWTTokenApi(BasePlugin):
key.import_from_pem(key_data)
except:
try:
key.import_key(k=key_data.decode('utf-8'),kty='oct')
key.import_key(k=key_data.decode("utf-8"), kty="oct")
except:
logger.error('Failed to correctly parse key data!')
logger.error("Failed to correctly parse key data!")
return None
try:
token = jwt.JWT(key=key, jwt=token)
parsed_header = json.loads(token.header)
if 'enc' in parsed_header:
if "enc" in parsed_header:
# Token is encrypted, so we need to decrypt by passing the claims to a new instance
token = jwt.JWT(key=key, jwt=token.claims)
parsed = json.loads(token.claims)
if 'nbf' in parsed:
if "nbf" in parsed:
# Not Before is present, so we need to check it
if time.time() < parsed['nbf']:
logger.warning('Token can not be used yet!')
if time.time() < parsed["nbf"]:
logger.warning("Token can not be used yet!")
return None
if 'exp' in parsed:
if "exp" in parsed:
# Expiration time is present, so we need to check it
if time.time() > parsed['exp']:
logger.warning('Token has expired!')
if time.time() > parsed["exp"]:
logger.warning("Token has expired!")
return None
return (parsed['host'], parsed['port'])
return (parsed["host"], parsed["port"])
except Exception as e:
logger.error("Failed to parse token: %s" % str(e))
return None
except ImportError:
logger.error("package jwcrypto not found, are you sure you've installed it correctly?")
logger.error(
"package jwcrypto not found, are you sure you've installed it correctly?"
)
return None
@ -251,6 +255,7 @@ class TokenRedis(BasePlugin):
pip install redis
"""
def __init__(self, src):
try:
import redis
@ -285,7 +290,9 @@ class TokenRedis(BasePlugin):
if not self._password:
self._password = None
elif len(fields) == 5:
self._server, self._port, self._db, self._password, self._namespace = fields
self._server, self._port, self._db, self._password, self._namespace = (
fields
)
if not self._port:
self._port = 6379
if not self._db:
@ -301,24 +308,30 @@ class TokenRedis(BasePlugin):
if self._namespace:
self._namespace += ":"
logger.info("TokenRedis backend initialized (%s:%s)" %
(self._server, self._port))
logger.info(
"TokenRedis backend initialized (%s:%s)" % (self._server, self._port)
)
except ValueError:
logger.error("The provided --token-source='%s' is not in the "
"expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]" %
src)
logger.error(
"The provided --token-source='%s' is not in the "
"expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]"
% src
)
sys.exit()
def lookup(self, token):
try:
import redis
except ImportError:
logger.error("package redis not found, are you sure you've installed them correctly?")
logger.error(
"package redis not found, are you sure you've installed them correctly?"
)
sys.exit()
logger.info("resolving token '%s'" % token)
client = redis.Redis(host=self._server, port=self._port,
db=self._db, password=self._password)
client = redis.Redis(
host=self._server, port=self._port, db=self._db, password=self._password
)
stuff = client.get(self._namespace + token)
if stuff is None:
return None
@ -330,14 +343,14 @@ class TokenRedis(BasePlugin):
combo = json.loads(responseStr)
host, port = combo["host"].split(":")
except ValueError:
logger.error("Unable to decode JSON token: %s" %
responseStr)
logger.error("Unable to decode JSON token: %s" % responseStr)
return None
except KeyError:
logger.error("Unable to find 'host' key in JSON token: %s" %
responseStr)
logger.error(
"Unable to find 'host' key in JSON token: %s" % responseStr
)
return None
elif re.match(r'\S+:\S+', responseStr):
elif re.match(r"\S+:\S+", responseStr):
host, port = responseStr.split(":")
else:
logger.error("Unable to parse token: %s" % responseStr)
@ -368,7 +381,7 @@ class UnixDomainSocketDirectory(BasePlugin):
if not stat.S_ISSOCK(os.stat(uds_path).st_mode):
return None
return [ 'unix_socket', uds_path ]
return ["unix_socket", uds_path]
except Exception as e:
logger.error("Error finding unix domain socket: %s" % str(e))
return None

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
'''
"""
Python WebSocket library
Copyright 2011 Joel Martin
Copyright 2016 Pierre Ossman
@ -10,7 +10,7 @@ Supports following protocol versions:
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-07
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10
- http://tools.ietf.org/html/rfc6455
'''
"""
import sys
import array
@ -28,14 +28,19 @@ try:
import numpy
except ImportError:
import warnings
warnings.warn("no 'numpy' module, HyBi protocol will be slower")
numpy = None
class WebSocketWantReadError(ssl.SSLWantReadError):
pass
class WebSocketWantWriteError(ssl.SSLWantWriteError):
pass
class WebSocket:
"""WebSocket protocol socket like class.
@ -73,11 +78,11 @@ class WebSocket:
self._state = "new"
self._partial_msg = b''
self._partial_msg = b""
self._recv_buffer = b''
self._recv_buffer = b""
self._recv_queue = []
self._send_buffer = b''
self._send_buffer = b""
self._previous_sendmsg = None
@ -91,16 +96,22 @@ class WebSocket:
def __getattr__(self, name):
# These methods are just redirected to the underlying socket
if name in ["fileno",
"getpeername", "getsockname",
"getsockopt", "setsockopt",
"gettimeout", "settimeout",
"setblocking"]:
if name in [
"fileno",
"getpeername",
"getsockname",
"getsockopt",
"setsockopt",
"gettimeout",
"settimeout",
"setblocking",
]:
assert self.socket is not None
return getattr(self.socket, name)
else:
raise AttributeError("%s instance has no attribute '%s'" %
(self.__class__.__name__, name))
raise AttributeError(
"%s instance has no attribute '%s'" % (self.__class__.__name__, name)
)
def connect(self, uri, origin=None, protocols=[]):
"""Establishes a new connection to a WebSocket server.
@ -118,8 +129,7 @@ class WebSocket:
connect() must retain the same arguments.
"""
self.client = True;
self.client = True
uri = urlparse(uri)
port = uri.port
@ -140,8 +150,9 @@ class WebSocket:
if uri.scheme in ("wss", "https"):
context = ssl.create_default_context()
self.socket = context.wrap_socket(self.socket,
server_hostname=uri.hostname)
self.socket = context.wrap_socket(
self.socket, server_hostname=uri.hostname
)
self._state = "ssl_handshake"
else:
self._state = "headers"
@ -151,7 +162,7 @@ class WebSocket:
self._state = "headers"
if self._state == "headers":
self._key = ''
self._key = ""
for i in range(16):
self._key += chr(random.randrange(256))
self._key = b64encode(self._key.encode("latin-1")).decode("ascii")
@ -184,10 +195,10 @@ class WebSocket:
if not self._recv():
raise Exception("Socket closed unexpectedly")
if self._recv_buffer.find(b'\r\n\r\n') == -1:
if self._recv_buffer.find(b"\r\n\r\n") == -1:
raise WebSocketWantReadError
(request, self._recv_buffer) = self._recv_buffer.split(b'\r\n', 1)
(request, self._recv_buffer) = self._recv_buffer.split(b"\r\n", 1)
request = request.decode("latin-1")
words = request.split()
@ -196,17 +207,17 @@ class WebSocket:
if words[1] != "101":
raise Exception("WebSocket request denied: %s" % " ".join(words[1:]))
(headers, self._recv_buffer) = self._recv_buffer.split(b'\r\n\r\n', 1)
headers = headers.decode('latin-1') + '\r\n'
(headers, self._recv_buffer) = self._recv_buffer.split(b"\r\n\r\n", 1)
headers = headers.decode("latin-1") + "\r\n"
headers = email.message_from_string(headers)
if headers.get("Upgrade", "").lower() != "websocket":
print(type(headers))
raise Exception("Missing or incorrect upgrade header")
accept = headers.get('Sec-WebSocket-Accept')
accept = headers.get("Sec-WebSocket-Accept")
if accept is None:
raise Exception("Missing Sec-WebSocket-Accept header");
raise Exception("Missing Sec-WebSocket-Accept header")
expected = sha1((self._key + self.GUID).encode("ascii")).digest()
expected = b64encode(expected).decode("ascii")
@ -214,9 +225,9 @@ class WebSocket:
del self._key
if accept != expected:
raise Exception("Invalid Sec-WebSocket-Accept header");
raise Exception("Invalid Sec-WebSocket-Accept header")
self.protocol = headers.get('Sec-WebSocket-Protocol')
self.protocol = headers.get("Sec-WebSocket-Protocol")
if len(protocols) == 0:
if self.protocol is not None:
raise Exception("Unexpected Sec-WebSocket-Protocol header")
@ -256,34 +267,34 @@ class WebSocket:
if headers.get("upgrade", "").lower() != "websocket":
raise Exception("Missing or incorrect upgrade header")
ver = headers.get('Sec-WebSocket-Version')
ver = headers.get("Sec-WebSocket-Version")
if ver is None:
raise Exception("Missing Sec-WebSocket-Version header");
raise Exception("Missing Sec-WebSocket-Version header")
# HyBi-07 report version 7
# HyBi-08 - HyBi-12 report version 8
# HyBi-13 reports version 13
if ver in ['7', '8', '13']:
if ver in ["7", "8", "13"]:
self.version = "hybi-%02d" % int(ver)
else:
raise Exception("Unsupported protocol version %s" % ver)
key = headers.get('Sec-WebSocket-Key')
key = headers.get("Sec-WebSocket-Key")
if key is None:
raise Exception("Missing Sec-WebSocket-Key header");
raise Exception("Missing Sec-WebSocket-Key header")
# Generate the hash value for the accept header
accept = sha1((key + self.GUID).encode("ascii")).digest()
accept = b64encode(accept).decode("ascii")
self.protocol = ''
protocols = headers.get('Sec-WebSocket-Protocol', '').split(',')
self.protocol = ""
protocols = headers.get("Sec-WebSocket-Protocol", "").split(",")
if protocols:
self.protocol = self.select_subprotocol(protocols)
# We are required to choose one of the protocols
# presented by the client
if self.protocol not in protocols:
raise Exception('Invalid protocol selected')
raise Exception("Invalid protocol selected")
self.send_response(101, "Switching Protocols")
self.send_header("Upgrade", "websocket")
@ -461,7 +472,7 @@ class WebSocket:
def send_request(self, type, path):
self._queue_str("%s %s HTTP/1.1\r\n" % (type.upper(), path))
def ping(self, data=b''):
def ping(self, data=b""):
"""Write a ping message to the WebSocket
WebSocketWantWriteError can be raised if there is insufficient
@ -486,7 +497,7 @@ class WebSocket:
self._previous_sendmsg = data
raise
def pong(self, data=b''):
def pong(self, data=b""):
"""Write a pong message to the WebSocket
WebSocketWantWriteError can be raised if there is insufficient
@ -540,7 +551,7 @@ class WebSocket:
self._sent_close = True
msg = b''
msg = b""
if code is not None:
msg += struct.pack(">H", code)
if reason is not None:
@ -602,7 +613,7 @@ class WebSocket:
frame = self._decode_hybi(self._recv_buffer)
if frame is None:
break
self._recv_buffer = self._recv_buffer[frame['length']:]
self._recv_buffer = self._recv_buffer[frame["length"] :]
self._recv_queue.append(frame)
return True
@ -612,29 +623,39 @@ class WebSocket:
while self._recv_queue:
frame = self._recv_queue.pop(0)
if not self.client and not frame['masked']:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked")
if not self.client and not frame["masked"]:
self.shutdown(
socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked"
)
continue
if self.client and frame['masked']:
if self.client and frame["masked"]:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame masked")
continue
if frame["opcode"] == 0x0:
if not self._partial_msg:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected continuation frame")
self.shutdown(
socket.SHUT_RDWR,
1002,
"Procotol error: Unexpected continuation frame",
)
continue
self._partial_msg += frame["payload"]
if frame["fin"]:
msg = self._partial_msg
self._partial_msg = b''
self._partial_msg = b""
return msg
elif frame["opcode"] == 0x1:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported")
self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported"
)
elif frame["opcode"] == 0x2:
if self._partial_msg:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame")
self.shutdown(
socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame"
)
continue
if frame["fin"]:
@ -652,7 +673,9 @@ class WebSocket:
return None
if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close")
self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close"
)
continue
code = None
@ -664,7 +687,11 @@ class WebSocket:
try:
reason = reason.decode("UTF-8")
except UnicodeDecodeError:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Invalid UTF-8 in close")
self.shutdown(
socket.SHUT_RDWR,
1002,
"Procotol error: Invalid UTF-8 in close",
)
continue
if code is None:
@ -679,18 +706,26 @@ class WebSocket:
return None
elif frame["opcode"] == 0x9:
if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping")
self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping"
)
continue
self.handle_ping(frame["payload"])
elif frame["opcode"] == 0xA:
if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong")
self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong"
)
continue
self.handle_pong(frame["payload"])
else:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Unknown opcode 0x%02x" % frame["opcode"])
self.shutdown(
socket.SHUT_RDWR,
1003,
"Unsupported: Unknown opcode 0x%02x" % frame["opcode"],
)
raise WebSocketWantReadError
@ -731,7 +766,7 @@ class WebSocket:
def _sendmsg(self, opcode, msg):
# Sends a standard data message
if self.client:
mask = b''
mask = b""
for i in range(4):
mask += random.randrange(256).to_bytes()
frame = self._encode_hybi(opcode, msg, mask)
@ -755,28 +790,29 @@ class WebSocket:
plen = len(buf)
pstart = 0
pend = plen
b = c = b''
b = c = b""
if plen >= 4:
dtype=numpy.dtype('<u4')
if sys.byteorder == 'big':
dtype = dtype.newbyteorder('>')
dtype = numpy.dtype("<u4")
if sys.byteorder == "big":
dtype = dtype.newbyteorder(">")
mask = numpy.frombuffer(mask, dtype, count=1)
data = numpy.frombuffer(buf, dtype, count=int(plen / 4))
# b = numpy.bitwise_xor(data, mask).data
b = numpy.bitwise_xor(data, mask).tobytes()
if plen % 4:
dtype=numpy.dtype('B')
if sys.byteorder == 'big':
dtype = dtype.newbyteorder('>')
dtype = numpy.dtype("B")
if sys.byteorder == "big":
dtype = dtype.newbyteorder(">")
mask = numpy.frombuffer(mask, dtype, count=(plen % 4))
data = numpy.frombuffer(buf, dtype,
offset=plen - (plen % 4), count=(plen % 4))
data = numpy.frombuffer(
buf, dtype, offset=plen - (plen % 4), count=(plen % 4)
)
c = numpy.bitwise_xor(data, mask).tobytes()
return b + c
else:
# Slower fallback
data = array.array('B')
data = array.array("B")
data.frombytes(buf)
for i in range(len(data)):
data[i] ^= mask[i % 4]
@ -793,7 +829,7 @@ class WebSocket:
0xA - pong
"""
b1 = opcode & 0x0f
b1 = opcode & 0x0F
if fin:
b1 |= 0x80
@ -804,11 +840,11 @@ class WebSocket:
payload_len = len(buf)
if payload_len <= 125:
header = struct.pack('>BB', b1, payload_len | mask_bit)
header = struct.pack(">BB", b1, payload_len | mask_bit)
elif payload_len > 125 and payload_len < 65536:
header = struct.pack('>BBH', b1, 126 | mask_bit, payload_len)
header = struct.pack(">BBH", b1, 126 | mask_bit, payload_len)
elif payload_len >= 65536:
header = struct.pack('>BBQ', b1, 127 | mask_bit, payload_len)
header = struct.pack(">BBQ", b1, 127 | mask_bit, payload_len)
if mask_key is not None:
return header + mask_key + buf
@ -825,11 +861,7 @@ class WebSocket:
'payload' : decoded_buffer}
"""
f = {'fin' : 0,
'opcode' : 0,
'masked' : False,
'length' : 0,
'payload' : None}
f = {"fin": 0, "opcode": 0, "masked": False, "length": 0, "payload": None}
blen = len(buf)
hlen = 2
@ -838,39 +870,38 @@ class WebSocket:
return None
b1, b2 = struct.unpack(">BB", buf[:2])
f['opcode'] = b1 & 0x0f
f['fin'] = not not (b1 & 0x80)
f['masked'] = not not (b2 & 0x80)
f["opcode"] = b1 & 0x0F
f["fin"] = not not (b1 & 0x80)
f["masked"] = not not (b2 & 0x80)
if f['masked']:
if f["masked"]:
hlen += 4
if blen < hlen:
return None
length = b2 & 0x7f
length = b2 & 0x7F
if length == 126:
hlen += 2
if blen < hlen:
return None
length, = struct.unpack('>H', buf[2:4])
(length,) = struct.unpack(">H", buf[2:4])
elif length == 127:
hlen += 8
if blen < hlen:
return None
length, = struct.unpack('>Q', buf[2:10])
(length,) = struct.unpack(">Q", buf[2:10])
f['length'] = hlen + length
f["length"] = hlen + length
if blen < f['length']:
if blen < f["length"]:
return None
if f['masked']:
if f["masked"]:
# unmask payload
mask_key = buf[hlen - 4 : hlen]
f['payload'] = self._unmask(buf[hlen:(hlen+length)], mask_key)
f["payload"] = self._unmask(buf[hlen : (hlen + length)], mask_key)
else:
f['payload'] = buf[hlen:(hlen+length)]
f["payload"] = buf[hlen : (hlen + length)]
return f

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
'''
"""
A WebSocket to TCP socket proxy with support for "wss://" encryption.
Copyright 2011 Joel Martin
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
@ -9,7 +9,7 @@ You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
"""
import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl, stat
from socketserver import ThreadingMixIn
@ -20,8 +20,8 @@ from websockify import websockifyserver
from websockify import auth_plugins as auth
from urllib.parse import parse_qs, urlparse
class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler):
class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler):
buffer_size = 65536
traffic_legend = """
@ -38,7 +38,7 @@ Traffic Legend:
def send_auth_error(self, ex):
self.send_response(ex.code, ex.msg)
self.send_header('Content-Type', 'text/html')
self.send_header("Content-Type", "text/html")
for name, val in ex.headers.items():
self.send_header(name, val)
@ -49,7 +49,7 @@ Traffic Legend:
return
host, port = self.get_target(self.server.token_plugin)
if host == 'unix_socket':
if host == "unix_socket":
self.server.unix_target = port
else:
@ -62,7 +62,7 @@ Traffic Legend:
# clear out any existing SSL_ headers that the client might
# have maliciously set
ssl_headers = [ h for h in self.headers if h.startswith('SSL_') ]
ssl_headers = [h for h in self.headers if h.startswith("SSL_")]
for h in ssl_headers:
del self.headers[h]
@ -70,19 +70,21 @@ Traffic Legend:
# get client certificate data
client_cert_data = self.request.getpeercert()
# extract subject information
client_cert_subject = client_cert_data['subject']
client_cert_subject = client_cert_data["subject"]
# flatten data structure
client_cert_subject = dict([x[0] for x in client_cert_subject])
# add common name to headers (apache +StdEnvVars style)
self.headers['SSL_CLIENT_S_DN_CN'] = client_cert_subject['commonName']
self.headers["SSL_CLIENT_S_DN_CN"] = client_cert_subject["commonName"]
except (TypeError, AttributeError, KeyError):
# not a SSL connection or client presented no certificate with valid data
pass
try:
self.server.auth_plugin.authenticate(
headers=self.headers, target_host=self.server.target_host,
target_port=self.server.target_port)
headers=self.headers,
target_host=self.server.target_host,
target_port=self.server.target_port,
)
except auth.AuthenticationError:
ex = sys.exc_info()[1]
self.send_auth_error(ex)
@ -96,26 +98,37 @@ Traffic Legend:
# Connect to the target
if self.server.wrap_cmd:
msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port)
msg = "connecting to command: '%s' (port %s)" % (
" ".join(self.server.wrap_cmd),
self.server.target_port,
)
elif self.server.unix_target:
msg = "connecting to unix socket: %s" % self.server.unix_target
else:
msg = "connecting to: %s:%s" % (
self.server.target_host, self.server.target_port)
self.server.target_host,
self.server.target_port,
)
if self.server.ssl_target:
msg += " (using SSL)"
self.log_message(msg)
try:
tsock = websockifyserver.WebSockifyServer.socket(self.server.target_host,
tsock = websockifyserver.WebSockifyServer.socket(
self.server.target_host,
self.server.target_port,
connect=True,
use_ssl=self.server.ssl_target,
unix_socket=self.server.unix_target)
unix_socket=self.server.unix_target,
)
except Exception as e:
self.log_message("Failed to connect to %s:%s: %s",
self.server.target_host, self.server.target_port, e)
self.log_message(
"Failed to connect to %s:%s: %s",
self.server.target_host,
self.server.target_port,
e,
)
raise self.CClose(1011, "Failed to connect to downstream server")
# Option unavailable when listening to unix socket
@ -134,8 +147,11 @@ Traffic Legend:
tsock.shutdown(socket.SHUT_RDWR)
tsock.close()
if self.verbose:
self.log_message("%s:%s: Closed target",
self.server.target_host, self.server.target_port)
self.log_message(
"%s:%s: Closed target",
self.server.target_host,
self.server.target_port,
)
def get_target(self, target_plugin):
"""
@ -149,20 +165,20 @@ Traffic Legend:
if self.host_token:
# Use hostname as token
token = self.headers.get('Host')
token = self.headers.get("Host")
# Remove port from hostname, as it'll always be the one where
# websockify listens (unless something between the client and
# websockify is redirecting traffic, but that's beside the point)
if token:
token = token.partition(':')[0]
token = token.partition(":")[0]
else:
# Extract the token parameter from url
args = parse_qs(urlparse(self.path)[4]) # 4 is the query from url
if 'token' in args and len(args['token']):
token = args['token'][0].rstrip('\n')
if "token" in args and len(args["token"]):
token = args["token"][0].rstrip("\n")
else:
token = None
@ -200,13 +216,15 @@ Traffic Legend:
self.heartbeat = now + self.server.heartbeat
self.send_ping()
if tqueue: wlist.append(target)
if cqueue or c_pend: wlist.append(self.request)
if tqueue:
wlist.append(target)
if cqueue or c_pend:
wlist.append(self.request)
try:
ins, outs, excepts = select.select(rlist, wlist, [], 1)
except OSError:
exc = sys.exc_info()[1]
if hasattr(exc, 'errno'):
if hasattr(exc, "errno"):
err = exc.errno
else:
err = exc[0]
@ -216,7 +234,8 @@ Traffic Legend:
else:
continue
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
if self.request in outs:
# Send queued target data to the client
@ -230,8 +249,7 @@ Traffic Legend:
tqueue.extend(bufs)
if closed:
while (len(tqueue) != 0):
while len(tqueue) != 0:
# Send queued client data to the target
dat = tqueue.pop(0)
sent = target.send(dat)
@ -244,10 +262,12 @@ Traffic Legend:
# TODO: What about blocking on client socket?
if self.verbose:
self.log_message("%s:%s: Client closed connection",
self.server.target_host, self.server.target_port)
raise self.CClose(closed['code'], closed['reason'])
self.log_message(
"%s:%s: Client closed connection",
self.server.target_host,
self.server.target_port,
)
raise self.CClose(closed["code"], closed["reason"])
if target in outs:
# Send queued client data to the target
@ -260,29 +280,31 @@ Traffic Legend:
tqueue.insert(0, dat[sent:])
self.print_traffic(".>")
if target in ins:
# Receive target data, encode it and queue for client
buf = target.recv(self.buffer_size)
if len(buf) == 0:
# Target socket closed, flushing queues and closing client-side websocket
# Send queued target data to the client
if len(cqueue) != 0:
c_pend = True
while(c_pend):
while c_pend:
c_pend = self.send_frames(cqueue)
cqueue = []
if self.verbose:
self.log_message("%s:%s: Target closed connection",
self.server.target_host, self.server.target_port)
self.log_message(
"%s:%s: Target closed connection",
self.server.target_host,
self.server.target_port,
)
raise self.CClose(1000, "Target closed")
cqueue.append(buf)
self.print_traffic("{")
class WebSocketProxy(websockifyserver.WebSockifyServer):
"""
Proxy traffic to and from a WebSockets client to a normal TCP
@ -293,27 +315,29 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs):
# Save off proxy specific options
self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None)
self.heartbeat = kwargs.pop('heartbeat', None)
self.target_host = kwargs.pop("target_host", None)
self.target_port = kwargs.pop("target_port", None)
self.wrap_cmd = kwargs.pop("wrap_cmd", None)
self.wrap_mode = kwargs.pop("wrap_mode", None)
self.unix_target = kwargs.pop("unix_target", None)
self.ssl_target = kwargs.pop("ssl_target", None)
self.heartbeat = kwargs.pop("heartbeat", None)
self.token_plugin = kwargs.pop('token_plugin', None)
self.host_token = kwargs.pop('host_token', None)
self.auth_plugin = kwargs.pop('auth_plugin', None)
self.token_plugin = kwargs.pop("token_plugin", None)
self.host_token = kwargs.pop("host_token", None)
self.auth_plugin = kwargs.pop("auth_plugin", None)
# Last 3 timestamps command was run
self.wrap_times = [0, 0, 0]
if self.wrap_cmd:
wsdir = os.path.dirname(sys.argv[0])
rebinder_path = [os.path.join(wsdir, "..", "lib"),
rebinder_path = [
os.path.join(wsdir, "..", "lib"),
os.path.join(wsdir, "..", "lib", "websockify"),
os.path.join(wsdir, ".."),
wsdir]
wsdir,
]
self.rebinder = None
for rdir in rebinder_path:
@ -329,17 +353,22 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
self.target_host = "127.0.0.1" # Loopback
# Find a free high port
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(('', 0))
sock.bind(("", 0))
self.target_port = sock.getsockname()[1]
sock.close()
# Insert rebinder at the head of the (possibly empty) LD_PRELOAD pathlist
ld_preloads = filter(None, [ self.rebinder, os.environ.get("LD_PRELOAD", None) ])
ld_preloads = filter(
None, [self.rebinder, os.environ.get("LD_PRELOAD", None)]
)
os.environ.update({
os.environ.update(
{
"LD_PRELOAD": os.pathsep.join(ld_preloads),
"REBIND_OLD_PORT": str(kwargs['listen_port']),
"REBIND_NEW_PORT": str(self.target_port)})
"REBIND_OLD_PORT": str(kwargs["listen_port"]),
"REBIND_NEW_PORT": str(self.target_port),
}
)
super().__init__(RequestHandlerClass, *args, **kwargs)
@ -348,7 +377,8 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
self.wrap_times.append(time.time())
self.wrap_times.pop(0)
self.cmd = subprocess.Popen(
self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup)
self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup
)
self.spawn_message = True
def started(self):
@ -371,10 +401,11 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
if self.token_plugin:
msg = " - proxying from %s to targets generated by %s" % (
src_string, type(self.token_plugin).__name__)
src_string,
type(self.token_plugin).__name__,
)
else:
msg = " - proxying from %s to %s" % (
src_string, dst_string)
msg = " - proxying from %s to %s" % (src_string, dst_string)
if self.ssl_target:
msg += " (using SSL)"
@ -418,15 +449,25 @@ def _subprocess_setup():
SSL_OPTIONS = {
'default': ssl.OP_ALL,
'tlsv1_1': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
ssl.OP_NO_TLSv1,
'tlsv1_2': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1,
'tlsv1_3': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2,
"default": ssl.OP_ALL,
"tlsv1_1": ssl.PROTOCOL_SSLv23
| ssl.OP_NO_SSLv2
| ssl.OP_NO_SSLv3
| ssl.OP_NO_TLSv1,
"tlsv1_2": ssl.PROTOCOL_SSLv23
| ssl.OP_NO_SSLv2
| ssl.OP_NO_SSLv3
| ssl.OP_NO_TLSv1
| ssl.OP_NO_TLSv1_1,
"tlsv1_3": ssl.PROTOCOL_SSLv23
| ssl.OP_NO_SSLv2
| ssl.OP_NO_SSLv3
| ssl.OP_NO_TLSv1
| ssl.OP_NO_TLSv1_1
| ssl.OP_NO_TLSv1_2,
}
def select_ssl_version(version):
"""Returns SSL options for the most secure TSL version available on this
Python version"""
@ -439,11 +480,11 @@ def select_ssl_version(version):
keys.sort()
fallback = keys[-1]
logger = logging.getLogger(WebSocketProxy.log_prefix)
logger.warn("TLS version %s unsupported. Falling back to %s",
version, fallback)
logger.warn("TLS version %s unsupported. Falling back to %s", version, fallback)
return SSL_OPTIONS[fallback]
def websockify_init():
# Setup basic logging to stderr.
stderr_handler = logging.StreamHandler()
@ -464,105 +505,198 @@ def websockify_init():
usage += "\n %prog [options]"
usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE"
parser = optparse.OptionParser(usage=usage)
parser.add_option("--verbose", "-v", action="store_true",
help="verbose messages")
parser.add_option("--traffic", action="store_true",
help="per frame traffic")
parser.add_option("--record",
help="record sessions to FILE.[session_number]", metavar="FILE")
parser.add_option("--daemon", "-D",
dest="daemon", action="store_true",
help="become a daemon (background process)")
parser.add_option("--run-once", action="store_true",
help="handle a single WebSocket connection and exit")
parser.add_option("--timeout", type=int, default=0,
help="after TIMEOUT seconds exit when not connected")
parser.add_option("--idle-timeout", type=int, default=0,
help="server exits after TIMEOUT seconds if there are no "
"active connections")
parser.add_option("--cert", default="self.pem",
help="SSL certificate file")
parser.add_option("--key", default=None,
help="SSL key file (if separate from cert)")
parser.add_option("--key-password", default=None,
help="SSL key password")
parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted client connections")
parser.add_option("--ssl-target", action="store_true",
help="connect to SSL target as SSL client")
parser.add_option("--verify-client", action="store_true",
parser.add_option("--verbose", "-v", action="store_true", help="verbose messages")
parser.add_option("--traffic", action="store_true", help="per frame traffic")
parser.add_option(
"--record", help="record sessions to FILE.[session_number]", metavar="FILE"
)
parser.add_option(
"--daemon",
"-D",
dest="daemon",
action="store_true",
help="become a daemon (background process)",
)
parser.add_option(
"--run-once",
action="store_true",
help="handle a single WebSocket connection and exit",
)
parser.add_option(
"--timeout",
type=int,
default=0,
help="after TIMEOUT seconds exit when not connected",
)
parser.add_option(
"--idle-timeout",
type=int,
default=0,
help="server exits after TIMEOUT seconds if there are no active connections",
)
parser.add_option("--cert", default="self.pem", help="SSL certificate file")
parser.add_option(
"--key", default=None, help="SSL key file (if separate from cert)"
)
parser.add_option("--key-password", default=None, help="SSL key password")
parser.add_option(
"--ssl-only",
action="store_true",
help="disallow non-encrypted client connections",
)
parser.add_option(
"--ssl-target", action="store_true", help="connect to SSL target as SSL client"
)
parser.add_option(
"--verify-client",
action="store_true",
help="require encrypted client to present a valid certificate "
"(needs Python 2.7.9 or newer or Python 3.4 or newer)")
parser.add_option("--cafile", metavar="FILE",
"(needs Python 2.7.9 or newer or Python 3.4 or newer)",
)
parser.add_option(
"--cafile",
metavar="FILE",
help="file of concatenated certificates of authorities trusted "
"for validating clients (only effective with --verify-client). "
"If omitted, system default list of CAs is used.")
parser.add_option("--ssl-version", type="choice", default="default",
choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"], action="store",
help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)")
parser.add_option("--ssl-ciphers", action="store",
"If omitted, system default list of CAs is used.",
)
parser.add_option(
"--ssl-version",
type="choice",
default="default",
choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"],
action="store",
help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)",
)
parser.add_option(
"--ssl-ciphers",
action="store",
help="list of ciphers allowed for connection. For a list of "
"supported ciphers run `openssl ciphers`")
parser.add_option("--unix-listen",
help="listen to unix socket", metavar="FILE", default=None)
parser.add_option("--unix-listen-mode", default=None,
help="specify mode for unix socket (defaults to 0600)")
parser.add_option("--unix-target",
help="connect to unix socket target", metavar="FILE")
parser.add_option("--inetd",
help="inetd mode, receive listening socket from stdin", action="store_true")
parser.add_option("--web", default=None, metavar="DIR",
help="run webserver on same port. Serve files from DIR.")
parser.add_option("--web-auth", action="store_true",
help="require authentication to access webserver.")
parser.add_option("--wrap-mode", default="exit", metavar="MODE",
"supported ciphers run `openssl ciphers`",
)
parser.add_option(
"--unix-listen", help="listen to unix socket", metavar="FILE", default=None
)
parser.add_option(
"--unix-listen-mode",
default=None,
help="specify mode for unix socket (defaults to 0600)",
)
parser.add_option(
"--unix-target", help="connect to unix socket target", metavar="FILE"
)
parser.add_option(
"--inetd",
help="inetd mode, receive listening socket from stdin",
action="store_true",
)
parser.add_option(
"--web",
default=None,
metavar="DIR",
help="run webserver on same port. Serve files from DIR.",
)
parser.add_option(
"--web-auth",
action="store_true",
help="require authentication to access webserver.",
)
parser.add_option(
"--wrap-mode",
default="exit",
metavar="MODE",
choices=["exit", "ignore", "respawn"],
help="action to take when the wrapped program exits "
"or daemonizes: exit (default), ignore, respawn")
parser.add_option("--prefer-ipv6", "-6",
action="store_true", dest="source_is_ipv6",
help="prefer IPv6 when resolving source_addr")
parser.add_option("--libserver", action="store_true",
help="use Python library SocketServer engine")
parser.add_option("--target-config", metavar="FILE",
"or daemonizes: exit (default), ignore, respawn",
)
parser.add_option(
"--prefer-ipv6",
"-6",
action="store_true",
dest="source_is_ipv6",
help="prefer IPv6 when resolving source_addr",
)
parser.add_option(
"--libserver",
action="store_true",
help="use Python library SocketServer engine",
)
parser.add_option(
"--target-config",
metavar="FILE",
dest="target_cfg",
help="Configuration file containing valid targets "
"in the form 'token: host:port' or, alternatively, a "
"directory containing configuration files of this form "
"(DEPRECATED: use `--token-plugin TokenFile --token-source "
" path/to/token/file` instead)")
parser.add_option("--token-plugin", default=None, metavar="CLASS",
" path/to/token/file` instead)",
)
parser.add_option(
"--token-plugin",
default=None,
metavar="CLASS",
help="use a Python class, usually one from websockify.token_plugins, "
"such as TokenFile, to process tokens into host:port pairs")
parser.add_option("--token-source", default=None, metavar="ARG",
help="an argument to be passed to the token plugin "
"on instantiation")
parser.add_option("--host-token", action="store_true",
"such as TokenFile, to process tokens into host:port pairs",
)
parser.add_option(
"--token-source",
default=None,
metavar="ARG",
help="an argument to be passed to the token plugin on instantiation",
)
parser.add_option(
"--host-token",
action="store_true",
help="use the host HTTP header as token instead of the "
"token URL query parameter")
parser.add_option("--auth-plugin", default=None, metavar="CLASS",
"token URL query parameter",
)
parser.add_option(
"--auth-plugin",
default=None,
metavar="CLASS",
help="use a Python class, usually one from websockify.auth_plugins, "
"such as BasicHTTPAuth, to determine if a connection is allowed")
parser.add_option("--auth-source", default=None, metavar="ARG",
help="an argument to be passed to the auth plugin "
"on instantiation")
parser.add_option("--heartbeat", type=int, default=0, metavar="INTERVAL",
help="send a ping to the client every INTERVAL seconds")
parser.add_option("--log-file", metavar="FILE",
"such as BasicHTTPAuth, to determine if a connection is allowed",
)
parser.add_option(
"--auth-source",
default=None,
metavar="ARG",
help="an argument to be passed to the auth plugin on instantiation",
)
parser.add_option(
"--heartbeat",
type=int,
default=0,
metavar="INTERVAL",
help="send a ping to the client every INTERVAL seconds",
)
parser.add_option(
"--log-file",
metavar="FILE",
dest="log_file",
help="File where logs will be saved")
parser.add_option("--syslog", default=None, metavar="SERVER",
help="File where logs will be saved",
)
parser.add_option(
"--syslog",
default=None,
metavar="SERVER",
help="Log to syslog server. SERVER can be local socket, "
"such as /dev/log, or a UDP host:port pair.")
parser.add_option("--legacy-syslog", action="store_true",
"such as /dev/log, or a UDP host:port pair.",
)
parser.add_option(
"--legacy-syslog",
action="store_true",
help="Use the old syslog protocol instead of RFC 5424. "
"Use this if the messages produced by websockify seem abnormal.")
parser.add_option("--file-only", action="store_true",
help="use this to disable directory listings in web server.")
"Use this if the messages produced by websockify seem abnormal.",
)
parser.add_option(
"--file-only",
action="store_true",
help="use this to disable directory listings in web server.",
)
(opts, args) = parser.parse_args()
# Validate options.
if opts.token_source and not opts.token_plugin:
@ -583,11 +717,9 @@ def websockify_init():
if opts.legacy_syslog and not opts.syslog:
parser.error("You must use --syslog to use --legacy-syslog")
opts.ssl_options = select_ssl_version(opts.ssl_version)
del opts.ssl_version
if opts.log_file:
# Setup logging to user-specified file.
opts.log_file = os.path.abspath(opts.log_file)
@ -601,9 +733,9 @@ def websockify_init():
if opts.syslog:
# Determine how to connect to syslog...
if opts.syslog.count(':'):
if opts.syslog.count(":"):
# User supplied a host:port pair.
syslog_host, syslog_port = opts.syslog.rsplit(':', 1)
syslog_host, syslog_port = opts.syslog.rsplit(":", 1)
try:
syslog_port = int(syslog_port)
except ValueError:
@ -622,10 +754,12 @@ def websockify_init():
syslog_facility = WebsockifySysLogHandler.LOG_USER
# Start logging to syslog.
syslog_handler = WebsockifySysLogHandler(address=syslog_dest,
syslog_handler = WebsockifySysLogHandler(
address=syslog_dest,
facility=syslog_facility,
ident='websockify',
legacy=opts.legacy_syslog)
ident="websockify",
legacy=opts.legacy_syslog,
)
syslog_handler.setLevel(logging.DEBUG)
syslog_handler.setFormatter(log_formatter)
root = logging.getLogger()
@ -638,24 +772,23 @@ def websockify_init():
root = logging.getLogger()
root.setLevel(logging.DEBUG)
# Transform to absolute path as daemon may chdir
if opts.target_cfg:
opts.target_cfg = os.path.abspath(opts.target_cfg)
if opts.target_cfg:
opts.token_plugin = 'TokenFile'
opts.token_plugin = "TokenFile"
opts.token_source = opts.target_cfg
del opts.target_cfg
if sys.argv.count('--'):
if sys.argv.count("--"):
opts.wrap_cmd = args[1:]
else:
opts.wrap_cmd = None
if not websockifyserver.ssl and opts.ssl_target:
parser.error("SSL target requested and Python SSL module not loaded.");
parser.error("SSL target requested and Python SSL module not loaded.")
if opts.ssl_only and not os.path.exists(opts.cert):
parser.error("SSL only and %s not found" % opts.cert)
@ -677,11 +810,11 @@ def websockify_init():
parser.error("Too few arguments")
arg = args.pop(0)
# Parse host:port and convert ports to numbers
if arg.count(':') > 0:
opts.listen_host, opts.listen_port = arg.rsplit(':', 1)
opts.listen_host = opts.listen_host.strip('[]')
if arg.count(":") > 0:
opts.listen_host, opts.listen_port = arg.rsplit(":", 1)
opts.listen_host = opts.listen_host.strip("[]")
else:
opts.listen_host, opts.listen_port = '', arg
opts.listen_host, opts.listen_port = "", arg
try:
opts.listen_port = int(opts.listen_port)
@ -697,9 +830,9 @@ def websockify_init():
if len(args) < 1:
parser.error("Too few arguments")
arg = args.pop(0)
if arg.count(':') > 0:
opts.target_host, opts.target_port = arg.rsplit(':', 1)
opts.target_host = opts.target_host.strip('[]')
if arg.count(":") > 0:
opts.target_host, opts.target_port = arg.rsplit(":", 1)
opts.target_host = opts.target_host.strip("[]")
else:
parser.error("Error parsing target")
@ -712,11 +845,10 @@ def websockify_init():
parser.error("Too many arguments")
if opts.token_plugin is not None:
if '.' not in opts.token_plugin:
opts.token_plugin = (
'websockify.token_plugins.%s' % opts.token_plugin)
if "." not in opts.token_plugin:
opts.token_plugin = "websockify.token_plugins.%s" % opts.token_plugin
token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit('.', 1)
token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit(".", 1)
__import__(token_plugin_module)
token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls)
@ -726,10 +858,10 @@ def websockify_init():
del opts.token_source
if opts.auth_plugin is not None:
if '.' not in opts.auth_plugin:
opts.auth_plugin = 'websockify.auth_plugins.%s' % opts.auth_plugin
if "." not in opts.auth_plugin:
opts.auth_plugin = "websockify.auth_plugins.%s" % opts.auth_plugin
auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit('.', 1)
auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit(".", 1)
__import__(auth_plugin_module)
auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls)
@ -759,32 +891,32 @@ class LibProxyServer(ThreadingMixIn, HTTPServer):
def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs):
# Save off proxy specific options
self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None)
self.token_plugin = kwargs.pop('token_plugin', None)
self.auth_plugin = kwargs.pop('auth_plugin', None)
self.heartbeat = kwargs.pop('heartbeat', None)
self.target_host = kwargs.pop("target_host", None)
self.target_port = kwargs.pop("target_port", None)
self.wrap_cmd = kwargs.pop("wrap_cmd", None)
self.wrap_mode = kwargs.pop("wrap_mode", None)
self.unix_target = kwargs.pop("unix_target", None)
self.ssl_target = kwargs.pop("ssl_target", None)
self.token_plugin = kwargs.pop("token_plugin", None)
self.auth_plugin = kwargs.pop("auth_plugin", None)
self.heartbeat = kwargs.pop("heartbeat", None)
self.token_plugin = None
self.auth_plugin = None
self.daemon = False
# Server configuration
listen_host = kwargs.pop('listen_host', '')
listen_port = kwargs.pop('listen_port', None)
web = kwargs.pop('web', '')
listen_host = kwargs.pop("listen_host", "")
listen_port = kwargs.pop("listen_port", None)
web = kwargs.pop("web", "")
# Configuration affecting base request handler
self.only_upgrade = not web
self.verbose = kwargs.pop('verbose', False)
record = kwargs.pop('record', '')
self.verbose = kwargs.pop("verbose", False)
record = kwargs.pop("record", "")
if record:
self.record = os.path.abspath(record)
self.run_once = kwargs.pop('run_once', False)
self.run_once = kwargs.pop("run_once", False)
self.handler_id = 0
for arg in kwargs.keys():
@ -795,12 +927,11 @@ class LibProxyServer(ThreadingMixIn, HTTPServer):
super().__init__((listen_host, listen_port), RequestHandlerClass)
def process_request(self, request, client_address):
"""Override process_request to implement a counter"""
self.handler_id += 1
super().process_request(request, client_address)
if __name__ == '__main__':
if __name__ == "__main__":
websockify_init()

View File

@ -1,19 +1,25 @@
#!/usr/bin/env python
'''
"""
Python WebSocket server base
Copyright 2011 Joel Martin
Copyright 2016-2018 Pierre Ossman
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
'''
"""
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer
from websockify.websocket import WebSocket, WebSocketWantReadError, WebSocketWantWriteError
from websockify.websocket import (
WebSocket,
WebSocketWantReadError,
WebSocketWantWriteError,
)
class HttpWebSocket(WebSocket):
"""Class to glue websocket and http request functionality together"""
def __init__(self, request_handler):
super().__init__()
@ -62,8 +68,10 @@ class WebSocketRequestHandlerMixIn:
# Checks if it is a websocket request and redirects
self.do_GET = self._real_do_GET
if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):
if (
self.headers.get("upgrade")
and self.headers.get("upgrade").lower() == "websocket"
):
self.handle_upgrade()
else:
self.do_GET()
@ -100,11 +108,13 @@ class WebSocketRequestHandlerMixIn:
"""
pass
# Convenient ready made classes
class WebSocketRequestHandler(WebSocketRequestHandlerMixIn,
BaseHTTPRequestHandler):
class WebSocketRequestHandler(WebSocketRequestHandlerMixIn, BaseHTTPRequestHandler):
pass
class WebSocketServer(HTTPServer):
pass

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
'''
"""
Python WebSocket server base with support for "wss://" encryption.
Copyright 2011 Joel Martin
Copyright 2016 Pierre Ossman
@ -10,35 +10,39 @@ You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
"""
import os, sys, time, errno, signal, socket, select, logging
import multiprocessing
from http.server import SimpleHTTPRequestHandler
# Degraded functionality if these imports are missing
for mod, msg in [('ssl', 'TLS/SSL/wss is disabled'),
('resource', 'daemonizing is disabled')]:
for mod, msg in [
("ssl", "TLS/SSL/wss is disabled"),
("resource", "daemonizing is disabled"),
]:
try:
globals()[mod] = __import__(mod)
except ImportError:
globals()[mod] = None
print("WARNING: no '%s' module, %s" % (mod, msg))
if sys.platform == 'win32':
if sys.platform == "win32":
# make sockets pickle-able/inheritable
import multiprocessing.reduction
from websockify.websocket import WebSocketWantReadError, WebSocketWantWriteError
from websockify.websocketserver import WebSocketRequestHandlerMixIn
class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass):
def select_subprotocol(self, protocols):
# Handle old websockify clients that still specify a sub-protocol
if 'binary' in protocols:
return 'binary'
if "binary" in protocols:
return "binary"
else:
return ''
return ""
# HTTP handler with WebSocket upgrade support
class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHandler):
@ -56,6 +60,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
* run_once: Handle a single request
* handler_id: A sequence number for this connection, appended to record filename
"""
server_version = "WebSockify"
protocol_version = "HTTP/1.1"
@ -87,7 +92,10 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
super().__init__(req, addr, server)
def log_message(self, format, *args):
self.logger.info("%s - - [%s] %s" % (self.client_address[0], self.log_date_time_string(), format % args))
self.logger.info(
"%s - - [%s] %s"
% (self.client_address[0], self.log_date_time_string(), format % args)
)
#
# WebSocketRequestHandler logging/output functions
@ -130,7 +138,12 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
for buf in bufs:
if self.rec:
# Python 3 compatible conversion
bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'")
bufstr = (
buf.decode("latin1")
.encode("unicode_escape")
.decode("ascii")
.replace("'", "\\'")
)
self.rec.write("'{{{0}{{{1}',\n".format(tdelta, bufstr))
self.send_parts.append(buf)
@ -165,15 +178,22 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
break
if buf is None:
closed = {'code': self.request.close_code,
'reason': self.request.close_reason}
closed = {
"code": self.request.close_code,
"reason": self.request.close_reason,
}
return bufs, closed
self.print_traffic("}")
if self.rec:
# Python 3 compatible conversion
bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'")
bufstr = (
buf.decode("latin1")
.encode("unicode_escape")
.decode("ascii")
.replace("'", "\\'")
)
self.rec.write("'}}{0}}}{1}',\n".format(tdelta, bufstr))
bufs.append(buf)
@ -183,15 +203,15 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
return bufs, closed
def send_close(self, code=1000, reason=''):
def send_close(self, code=1000, reason=""):
"""Send a WebSocket orderly close frame."""
self.request.shutdown(socket.SHUT_RDWR, code, reason)
def send_pong(self, data=b''):
def send_pong(self, data=b""):
"""Send a WebSocket pong frame."""
self.request.pong(data)
def send_ping(self, data=b''):
def send_ping(self, data=b""):
"""Send a WebSocket ping frame."""
self.request.ping(data)
@ -224,17 +244,15 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
else:
self.stype = "Plain non-SSL (ws://)"
self.log_message("%s: %s WebSocket connection", client_addr,
self.stype)
if self.path != '/':
self.log_message("%s: %s WebSocket connection", client_addr, self.stype)
if self.path != "/":
self.log_message("%s: Path: '%s'", client_addr, self.path)
if self.record:
# Record raw frame data as JavaScript array
fname = "%s.%s" % (self.record,
self.handler_id)
fname = "%s.%s" % (self.record, self.handler_id)
self.log_message("opening record file: %s", fname)
self.rec = open(fname, 'w+')
self.rec = open(fname, "w+")
self.rec.write("var VNC_frame_data = [\n")
try:
@ -262,7 +280,9 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
def new_websocket_client(self):
"""Do something with a WebSockets client connection."""
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
raise Exception(
"WebSocketRequestHandler.new_websocket_client() must be overloaded"
)
def validate_connection(self):
"""Ensure that the connection has a valid token, and set the target."""
@ -296,12 +316,12 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
else:
super().handle()
def log_request(self, code='-', size='-'):
def log_request(self, code="-", size="-"):
if self.verbose:
super().log_request(code, size)
class WebSockifyServer():
class WebSockifyServer:
"""
WebSockets server class.
As an alternative, the standard library SocketServer can be used
@ -317,17 +337,38 @@ class WebSockifyServer():
class Terminate(Exception):
pass
def __init__(self, RequestHandlerClass, listen_fd=None,
listen_host='', listen_port=None, source_is_ipv6=False,
verbose=False, cert='', key='', key_password=None, ssl_only=None,
verify_client=False, cafile=None,
daemon=False, record='', web='', web_auth=False,
def __init__(
self,
RequestHandlerClass,
listen_fd=None,
listen_host="",
listen_port=None,
source_is_ipv6=False,
verbose=False,
cert="",
key="",
key_password=None,
ssl_only=None,
verify_client=False,
cafile=None,
daemon=False,
record="",
web="",
web_auth=False,
file_only=False,
run_once=False, timeout=0, idle_timeout=0, traffic=False,
tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None,
tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0,
unix_listen=None, unix_listen_mode=None):
run_once=False,
timeout=0,
idle_timeout=0,
traffic=False,
tcp_keepalive=True,
tcp_keepcnt=None,
tcp_keepidle=None,
tcp_keepintvl=None,
ssl_ciphers=None,
ssl_options=0,
unix_listen=None,
unix_listen_mode=None,
):
# settings
self.RequestHandlerClass = RequestHandlerClass
self.verbose = verbose
@ -366,7 +407,7 @@ class WebSockifyServer():
# Make paths settings absolute
self.cert = os.path.abspath(cert)
self.web = self.record = self.cafile = ''
self.web = self.record = self.cafile = ""
if key:
self.key = os.path.abspath(key)
if web:
@ -393,11 +434,12 @@ class WebSockifyServer():
elif self.unix_listen != None:
self.msg(" - Listen on unix socket %s", self.unix_listen)
else:
self.msg(" - Listen on %s:%s",
self.listen_host, self.listen_port)
self.msg(" - Listen on %s:%s", self.listen_host, self.listen_port)
if self.web:
if self.file_only:
self.msg(" - Web server (no directory listings). Web root: %s", self.web)
self.msg(
" - Web server (no directory listings). Web root: %s", self.web
)
else:
self.msg(" - Web server. Web root: %s", self.web)
if ssl:
@ -420,34 +462,45 @@ class WebSockifyServer():
@staticmethod
def get_logger():
return logging.getLogger("%s.%s" % (
WebSockifyServer.log_prefix,
WebSockifyServer.__class__.__name__))
return logging.getLogger(
"%s.%s" % (WebSockifyServer.log_prefix, WebSockifyServer.__class__.__name__)
)
@staticmethod
def socket(host, port=None, connect=False, prefer_ipv6=False,
unix_socket=None, unix_socket_mode=None, unix_socket_listen=False,
use_ssl=False, tcp_keepalive=True, tcp_keepcnt=None,
tcp_keepidle=None, tcp_keepintvl=None):
def socket(
host,
port=None,
connect=False,
prefer_ipv6=False,
unix_socket=None,
unix_socket_mode=None,
unix_socket_listen=False,
use_ssl=False,
tcp_keepalive=True,
tcp_keepcnt=None,
tcp_keepidle=None,
tcp_keepintvl=None,
):
"""Resolve a host (and optional port) to an IPv4 or IPv6
address. Create a socket. Bind to it if listen is set,
otherwise connect to it. Return the socket.
"""
flags = 0
if host == '':
if host == "":
host = None
if connect and not (port or unix_socket):
raise Exception("Connect mode requires a port")
if use_ssl and not ssl:
raise Exception("SSL socket requested but Python SSL module not loaded.");
raise Exception("SSL socket requested but Python SSL module not loaded.")
if not connect and use_ssl:
raise Exception("SSL only supported in connect mode (for now)")
if not connect:
flags = flags | socket.AI_PASSIVE
if not unix_socket:
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM,
socket.IPPROTO_TCP, flags)
addrs = socket.getaddrinfo(
host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, flags
)
if not addrs:
raise Exception("Could not resolve host '%s'" % host)
addrs.sort(key=lambda x: x[0])
@ -458,14 +511,11 @@ class WebSockifyServer():
if tcp_keepalive:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if tcp_keepcnt:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT,
tcp_keepcnt)
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, tcp_keepcnt)
if tcp_keepidle:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE,
tcp_keepidle)
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, tcp_keepidle)
if tcp_keepintvl:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL,
tcp_keepintvl)
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, tcp_keepintvl)
if connect:
sock.connect(addrs[0][4])
@ -497,8 +547,7 @@ class WebSockifyServer():
return sock
@staticmethod
def daemonize(keepfd=None, chdir='/'):
def daemonize(keepfd=None, chdir="/"):
if keepfd is None:
keepfd = []
@ -506,14 +555,16 @@ class WebSockifyServer():
if chdir:
os.chdir(chdir)
else:
os.chdir('/')
os.chdir("/")
os.setgid(os.getgid()) # relinquish elevations
os.setuid(os.getuid()) # relinquish elevations
# Double fork to daemonize
if os.fork() > 0: os._exit(0) # Parent exits
if os.fork() > 0:
os._exit(0) # Parent exits
os.setsid() # Obtain new process group
if os.fork() > 0: os._exit(0) # Parent exits
if os.fork() > 0:
os._exit(0) # Parent exits
# Signal handling
signal.signal(signal.SIGTERM, signal.SIG_IGN)
@ -521,14 +572,16 @@ class WebSockifyServer():
# Close open files
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
if maxfd == resource.RLIM_INFINITY: maxfd = 256
if maxfd == resource.RLIM_INFINITY:
maxfd = 256
for fd in reversed(range(maxfd)):
try:
if fd not in keepfd:
os.close(fd)
except OSError:
_, exc, _ = sys.exc_info()
if exc.errno != errno.EBADF: raise
if exc.errno != errno.EBADF:
raise
# Redirect I/O to /dev/null
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
@ -567,8 +620,7 @@ class WebSockifyServer():
if not ssl:
raise self.EClose("SSL connection but no 'ssl' module")
if not os.path.exists(self.cert):
raise self.EClose("SSL connection but '%s' not found"
% self.cert)
raise self.EClose("SSL connection but '%s' not found" % self.cert)
retsock = None
try:
# create new-style SSL wrapping for extended features
@ -576,16 +628,16 @@ class WebSockifyServer():
if self.ssl_ciphers is not None:
context.set_ciphers(self.ssl_ciphers)
context.options = self.ssl_options
context.load_cert_chain(certfile=self.cert, keyfile=self.key, password=self.key_password)
context.load_cert_chain(
certfile=self.cert, keyfile=self.key, password=self.key_password
)
if self.verify_client:
context.verify_mode = ssl.CERT_REQUIRED
if self.cafile:
context.load_verify_locations(cafile=self.cafile)
else:
context.set_default_verify_paths()
retsock = context.wrap_socket(
sock,
server_side=True)
retsock = context.wrap_socket(sock, server_side=True)
except ssl.SSLError:
_, x, _ = sys.exc_info()
if x.args[0] == ssl.SSL_ERROR_EOF:
@ -629,7 +681,6 @@ class WebSockifyServer():
"""Same as msg() but as warning."""
self.logger.log(logging.WARN, *args, **kwargs)
#
# Events that can/should be overridden in sub-classes
#
@ -661,7 +712,7 @@ class WebSockifyServer():
while result[0]:
self.vmsg("Reaped child process %s" % result[0])
result = os.waitpid(-1, os.WNOHANG)
except (OSError):
except OSError:
pass
def do_SIGINT(self, sig, stack):
@ -693,7 +744,6 @@ class WebSockifyServer():
self.msg("handler exception: %s" % str(exc))
self.vmsg("exception", exc_info=True)
finally:
if client and client != startsock:
# Close the SSL wrapped socket
# Original socket closed by caller
@ -721,19 +771,27 @@ class WebSockifyServer():
try:
if self.listen_fd != None:
lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM)
lsock = socket.fromfd(
self.listen_fd, socket.AF_INET, socket.SOCK_STREAM
)
elif self.unix_listen != None:
lsock = self.socket(host=None,
lsock = self.socket(
host=None,
unix_socket=self.unix_listen,
unix_socket_mode=self.unix_listen_mode,
unix_socket_listen=True)
unix_socket_listen=True,
)
else:
lsock = self.socket(self.listen_host, self.listen_port, False,
lsock = self.socket(
self.listen_host,
self.listen_port,
False,
self.prefer_ipv6,
tcp_keepalive=self.tcp_keepalive,
tcp_keepcnt=self.tcp_keepcnt,
tcp_keepidle=self.tcp_keepidle,
tcp_keepintvl=self.tcp_keepintvl)
tcp_keepintvl=self.tcp_keepintvl,
)
except OSError as e:
self.msg("Openening socket failed: %s", str(e))
self.vmsg("exception", exc_info=True)
@ -751,13 +809,13 @@ class WebSockifyServer():
signal.SIGINT: signal.getsignal(signal.SIGINT),
signal.SIGTERM: signal.getsignal(signal.SIGTERM),
}
if getattr(signal, 'SIGCHLD', None) is not None:
if getattr(signal, "SIGCHLD", None) is not None:
original_signals[signal.SIGCHLD] = signal.getsignal(signal.SIGCHLD)
signal.signal(signal.SIGINT, self.do_SIGINT)
signal.signal(signal.SIGTERM, self.do_SIGTERM)
# make sure that _cleanup is called when children die
# by calling active_children on SIGCHLD
if getattr(signal, 'SIGCHLD', None) is not None:
if getattr(signal, "SIGCHLD", None) is not None:
signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD)
last_active_time = self.launch_time
@ -774,8 +832,7 @@ class WebSockifyServer():
time_elapsed = time.time() - self.launch_time
if self.timeout and time_elapsed > self.timeout:
self.msg('listener exit due to --timeout %s'
% self.timeout)
self.msg("listener exit due to --timeout %s" % self.timeout)
break
if self.idle_timeout:
@ -787,8 +844,10 @@ class WebSockifyServer():
last_active_time = time.time()
if idle_time > self.idle_timeout and child_count == 0:
self.msg('listener exit due to --idle-timeout %s'
% self.idle_timeout)
self.msg(
"listener exit due to --idle-timeout %s"
% self.idle_timeout
)
break
try:
@ -806,9 +865,9 @@ class WebSockifyServer():
raise
except Exception:
_, exc, _ = sys.exc_info()
if hasattr(exc, 'errno'):
if hasattr(exc, "errno"):
err = exc.errno
elif hasattr(exc, 'args'):
elif hasattr(exc, "args"):
err = exc.args[0]
else:
err = exc[0]
@ -822,14 +881,13 @@ class WebSockifyServer():
# Run in same process if run_once
self.top_new_client(startsock, address)
if self.ws_connection:
self.msg('%s: exiting due to --run-once'
% address[0])
self.msg("%s: exiting due to --run-once" % address[0])
break
else:
self.vmsg('%s: new handler Process' % address[0])
self.vmsg("%s: new handler Process" % address[0])
p = multiprocessing.Process(
target=self.top_new_client,
args=(startsock, address))
target=self.top_new_client, args=(startsock, address)
)
p.start()
# child will not return
@ -857,12 +915,11 @@ class WebSockifyServer():
startsock.close()
finally:
# Close listen port
self.vmsg("Closing socket listening at %s:%s",
self.listen_host, self.listen_port)
self.vmsg(
"Closing socket listening at %s:%s", self.listen_host, self.listen_port
)
lsock.close()
# Restore signals
for sig, func in original_signals.items():
signal.signal(sig, func)