run `ruff format .`

This commit is contained in:
Doctor 2025-03-12 20:55:57 +03:00
parent 4fc7e52352
commit 3eca8ad195
20 changed files with 1488 additions and 1027 deletions

View File

@ -1,11 +1,11 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
version = '0.13.0' version = "0.13.0"
name = 'websockify' name = "websockify"
long_description = open("README.md").read() + "\n" + \ long_description = open("README.md").read() + "\n" + open("CHANGES.txt").read() + "\n"
open("CHANGES.txt").read() + "\n"
setup(name=name, setup(
name=name,
version=version, version=version,
description="Websockify.", description="Websockify.",
long_description=long_description, long_description=long_description,
@ -22,23 +22,23 @@ setup(name=name,
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
], ],
keywords='noVNC websockify', keywords="noVNC websockify",
license='LGPLv3', license="LGPLv3",
url="https://github.com/novnc/websockify", url="https://github.com/novnc/websockify",
author="Joel Martin", author="Joel Martin",
author_email="github@martintribe.org", author_email="github@martintribe.org",
packages=["websockify"],
packages=['websockify'],
include_package_data=True, include_package_data=True,
install_requires=[ install_requires=[
'numpy', 'requests', "numpy",
'jwcrypto', "requests",
'redis', "jwcrypto",
"redis",
], ],
zip_safe=False, zip_safe=False,
entry_points={ entry_points={
'console_scripts': [ "console_scripts": [
'websockify = websockify.websocketproxy:websockify_init', "websockify = websockify.websocketproxy:websockify_init",
] ]
}, },
) )

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
A WebSocket server that echos back whatever it receives from the client. A WebSocket server that echos back whatever it receives from the client.
Copyright 2010 Joel Martin Copyright 2010 Joel Martin
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
@ -8,16 +8,19 @@ Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
You can make a cert/key with openssl using: You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' """
import os, sys, select, optparse, logging import os, sys, select, optparse, logging
sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler
class WebSocketEcho(WebSockifyRequestHandler): class WebSocketEcho(WebSockifyRequestHandler):
""" """
WebSockets server that echos back whatever is received from the WebSockets server that echos back whatever is received from the
client. """ client."""
buffer_size = 8096 buffer_size = 8096
def new_websocket_client(self): def new_websocket_client(self):
@ -33,9 +36,11 @@ class WebSocketEcho(WebSockifyRequestHandler):
while True: while True:
wlist = [] wlist = []
if cqueue or c_pend: wlist.append(self.request) if cqueue or c_pend:
wlist.append(self.request)
ins, outs, excepts = select.select(rlist, wlist, [], 1) ins, outs, excepts = select.select(rlist, wlist, [], 1)
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
if self.request in outs: if self.request in outs:
# Send queued target data to the client # Send queued target data to the client
@ -50,20 +55,27 @@ class WebSocketEcho(WebSockifyRequestHandler):
if closed: if closed:
break break
if __name__ == '__main__':
if __name__ == "__main__":
parser = optparse.OptionParser(usage="%prog [options] listen_port") parser = optparse.OptionParser(usage="%prog [options] listen_port")
parser.add_option("--verbose", "-v", action="store_true", parser.add_option(
help="verbose messages and per frame traffic") "--verbose",
parser.add_option("--cert", default="self.pem", "-v",
help="SSL certificate file") action="store_true",
parser.add_option("--key", default=None, help="verbose messages and per frame traffic",
help="SSL key file (if separate from cert)") )
parser.add_option("--ssl-only", action="store_true", parser.add_option("--cert", default="self.pem", help="SSL certificate file")
help="disallow non-encrypted connections") parser.add_option(
"--key", default=None, help="SSL key file (if separate from cert)"
)
parser.add_option(
"--ssl-only", action="store_true", help="disallow non-encrypted connections"
)
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
try: try:
if len(args) != 1: raise ValueError if len(args) != 1:
raise ValueError
opts.listen_port = int(args[0]) opts.listen_port = int(args[0])
except ValueError: except ValueError:
parser.error("Invalid arguments") parser.error("Invalid arguments")
@ -73,4 +85,3 @@ if __name__ == '__main__':
opts.web = "." opts.web = "."
server = WebSockifyServer(WebSocketEcho, **opts.__dict__) server = WebSockifyServer(WebSocketEcho, **opts.__dict__)
server.start_server() server.start_server()

View File

@ -5,9 +5,12 @@ import sys
import optparse import optparse
import select import select
sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websocket import WebSocket, \ from websockify.websocket import (
WebSocketWantReadError, WebSocketWantWriteError WebSocket,
WebSocketWantReadError,
WebSocketWantWriteError,
)
parser = optparse.OptionParser(usage="%prog URL") parser = optparse.OptionParser(usage="%prog URL")
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
@ -22,19 +25,23 @@ print("Connecting to %s..." % URL)
sock.connect(URL) sock.connect(URL)
print("Connected.") print("Connected.")
def send(msg): def send(msg):
while True: while True:
try: try:
sock.sendmsg(msg) sock.sendmsg(msg)
break break
except WebSocketWantReadError: except WebSocketWantReadError:
msg = '' msg = ""
ins, outs, excepts = select.select([sock], [], []) ins, outs, excepts = select.select([sock], [], [])
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
except WebSocketWantWriteError: except WebSocketWantWriteError:
msg = '' msg = ""
ins, outs, excepts = select.select([], [sock], []) ins, outs, excepts = select.select([], [sock], [])
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
def read(): def read():
while True: while True:
@ -42,10 +49,13 @@ def read():
return sock.recvmsg() return sock.recvmsg()
except WebSocketWantReadError: except WebSocketWantReadError:
ins, outs, excepts = select.select([sock], [], []) ins, outs, excepts = select.select([sock], [], [])
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
except WebSocketWantWriteError: except WebSocketWantWriteError:
ins, outs, excepts = select.select([], [sock], []) ins, outs, excepts = select.select([], [sock], [])
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
counter = 1 counter = 1
while True: while True:
@ -56,7 +66,8 @@ while True:
while True: while True:
ins, outs, excepts = select.select([sock], [], [], 1.0) ins, outs, excepts = select.select([sock], [], [], 1.0)
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
if ins == []: if ins == []:
break break

View File

@ -1,32 +1,32 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
WebSocket server-side load test program. Sends and receives traffic WebSocket server-side load test program. Sends and receives traffic
that has a random payload (length and content) that is checksummed and that has a random payload (length and content) that is checksummed and
given a sequence number. Any errors are reported and counted. given a sequence number. Any errors are reported and counted.
''' """
import sys, os, select, random, time, optparse, logging import sys, os, select, random, time, optparse, logging
sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler
class WebSocketLoadServer(WebSockifyServer):
class WebSocketLoadServer(WebSockifyServer):
recv_cnt = 0 recv_cnt = 0
send_cnt = 0 send_cnt = 0
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.delay = kwargs.pop('delay') self.delay = kwargs.pop("delay")
WebSockifyServer.__init__(self, *args, **kwargs) WebSockifyServer.__init__(self, *args, **kwargs)
class WebSocketLoad(WebSockifyRequestHandler): class WebSocketLoad(WebSockifyRequestHandler):
max_packet_size = 10000 max_packet_size = 10000
def new_websocket_client(self): def new_websocket_client(self):
print "Prepopulating random array" print("Prepopulating random array")
self.rand_array = [] self.rand_array = []
for i in range(0, self.max_packet_size): for i in range(0, self.max_packet_size):
self.rand_array.append(random.randint(0, 9)) self.rand_array.append(random.randint(0, 9))
@ -37,7 +37,7 @@ class WebSocketLoad(WebSockifyRequestHandler):
self.responder(self.request) self.responder(self.request)
print "accumulated errors:", self.errors print("accumulated errors:", self.errors)
self.errors = 0 self.errors = 0
def responder(self, client): def responder(self, client):
@ -49,7 +49,8 @@ class WebSocketLoad(WebSockifyRequestHandler):
while True: while True:
ins, outs, excepts = select.select(socks, socks, socks, 1) ins, outs, excepts = select.select(socks, socks, socks, 1)
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
if client in ins: if client in ins:
frames, closed = self.recv_frames() frames, closed = self.recv_frames()
@ -57,7 +58,7 @@ class WebSocketLoad(WebSockifyRequestHandler):
err = self.check(frames) err = self.check(frames)
if err: if err:
self.errors = self.errors + 1 self.errors = self.errors + 1
print err print(err)
if closed: if closed:
break break
@ -73,24 +74,22 @@ class WebSocketLoad(WebSockifyRequestHandler):
def generate(self): def generate(self):
length = random.randint(10, self.max_packet_size) length = random.randint(10, self.max_packet_size)
numlist = self.rand_array[self.max_packet_size-length:] numlist = self.rand_array[self.max_packet_size - length :]
# Error in length # Error in length
#numlist.append(5) # numlist.append(5)
chksum = sum(numlist) chksum = sum(numlist)
# Error in checksum # Error in checksum
#numlist[0] = 5 # numlist[0] = 5
nums = "".join( [str(n) for n in numlist] ) nums = "".join([str(n) for n in numlist])
data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums) data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums)
self.send_cnt += 1 self.send_cnt += 1
return data return data
def check(self, frames): def check(self, frames):
err = "" err = ""
for data in frames: for data in frames:
if data.count('$') > 1: if data.count("$") > 1:
raise Exception("Multiple parts within single packet") raise Exception("Multiple parts within single packet")
if len(data) == 0: if len(data) == 0:
self.traffic("_") self.traffic("_")
@ -101,12 +100,12 @@ class WebSocketLoad(WebSockifyRequestHandler):
continue continue
try: try:
cnt, length, chksum, nums = data[1:-1].split(':') cnt, length, chksum, nums = data[1:-1].split(":")
cnt = int(cnt) cnt = int(cnt)
length = int(length) length = int(length)
chksum = int(chksum) chksum = int(chksum)
except ValueError: except ValueError:
print "\n<BOF>" + repr(data) + "<EOF>" print("\n<BOF>" + repr(data) + "<EOF>")
err += "Invalid data format\n" err += "Invalid data format\n"
continue continue
@ -131,27 +130,37 @@ class WebSocketLoad(WebSockifyRequestHandler):
real_chksum += int(num) real_chksum += int(num)
if real_chksum != chksum: if real_chksum != chksum:
err += "Expected checksum %d but real chksum is %d\n" % (chksum, real_chksum) err += "Expected checksum %d but real chksum is %d\n" % (
chksum,
real_chksum,
)
return err return err
if __name__ == '__main__': if __name__ == "__main__":
parser = optparse.OptionParser(usage="%prog [options] listen_port") parser = optparse.OptionParser(usage="%prog [options] listen_port")
parser.add_option("--verbose", "-v", action="store_true", parser.add_option(
help="verbose messages and per frame traffic") "--verbose",
parser.add_option("--cert", default="self.pem", "-v",
help="SSL certificate file") action="store_true",
parser.add_option("--key", default=None, help="verbose messages and per frame traffic",
help="SSL key file (if separate from cert)") )
parser.add_option("--ssl-only", action="store_true", parser.add_option("--cert", default="self.pem", help="SSL certificate file")
help="disallow non-encrypted connections") parser.add_option(
"--key", default=None, help="SSL key file (if separate from cert)"
)
parser.add_option(
"--ssl-only", action="store_true", help="disallow non-encrypted connections"
)
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
try: try:
if len(args) != 1: raise ValueError if len(args) != 1:
raise ValueError
opts.listen_port = int(args[0]) opts.listen_port = int(args[0])
if len(args) not in [1,2]: raise ValueError if len(args) not in [1, 2]:
raise ValueError
opts.listen_port = int(args[0]) opts.listen_port = int(args[0])
if len(args) == 2: if len(args) == 2:
opts.delay = int(args[1]) opts.delay = int(args[1])
@ -165,4 +174,3 @@ if __name__ == '__main__':
opts.web = "." opts.web = "."
server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__) server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__)
server.start_server() server.start_server()

View File

@ -1,28 +1,33 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4 # vim: tabstop=4 shiftwidth=4 softtabstop=4
""" Unit tests for Authentication plugins""" """Unit tests for Authentication plugins"""
from websockify.auth_plugins import BasicHTTPAuth, AuthenticationError from websockify.auth_plugins import BasicHTTPAuth, AuthenticationError
import unittest import unittest
class BasicHTTPAuthTestCase(unittest.TestCase): class BasicHTTPAuthTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.plugin = BasicHTTPAuth('Aladdin:open sesame') self.plugin = BasicHTTPAuth("Aladdin:open sesame")
def test_no_auth(self): def test_no_auth(self):
headers = {} headers = {}
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') self.assertRaises(
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
)
def test_invalid_password(self): def test_invalid_password(self):
headers = {'Authorization': 'Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0'} headers = {"Authorization": "Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0"}
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') self.assertRaises(
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
)
def test_valid_password(self): def test_valid_password(self):
headers = {'Authorization': 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='} headers = {"Authorization": "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="}
self.plugin.authenticate(headers, 'localhost', '1234') self.plugin.authenticate(headers, "localhost", "1234")
def test_garbage_auth(self): def test_garbage_auth(self):
headers = {'Authorization': 'Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx'} headers = {"Authorization": "Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx"}
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') self.assertRaises(
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
)

View File

@ -1,80 +1,93 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4 # vim: tabstop=4 shiftwidth=4 softtabstop=4
""" Unit tests for Token plugins""" """Unit tests for Token plugins"""
import sys import sys
import unittest import unittest
from unittest.mock import patch, mock_open, MagicMock from unittest.mock import patch, mock_open, MagicMock
from jwcrypto import jwt, jwk from jwcrypto import jwt, jwk
from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis from websockify.token_plugins import (
parse_source_args,
ReadOnlyTokenFile,
JWTTokenApi,
TokenRedis,
)
class ParseSourceArgumentsTestCase(unittest.TestCase): class ParseSourceArgumentsTestCase(unittest.TestCase):
def test_parameterized(self): def test_parameterized(self):
params = [ params = [
('', ['']), ("", [""]),
(':', ['', '']), (":", ["", ""]),
('::', ['', '', '']), ("::", ["", "", ""]),
('"', ['"']), ('"', ['"']),
('""', ['""']), ('""', ['""']),
('"""', ['"""']), ('"""', ['"""']),
('"localhost"', ['localhost']), ('"localhost"', ["localhost"]),
('"localhost":', ['localhost', '']), ('"localhost":', ["localhost", ""]),
('"localhost"::', ['localhost', '', '']), ('"localhost"::', ["localhost", "", ""]),
('"local:host"', ['local:host']), ('"local:host"', ["local:host"]),
('"local:host:"pass"', ['"local', 'host', "pass"]), ('"local:host:"pass"', ['"local', "host", "pass"]),
('"local":"host"', ['local', 'host']), ('"local":"host"', ["local", "host"]),
('"local":host"', ['local', 'host"']), ('"local":host"', ["local", 'host"']),
('localhost:6379:1:pass"word:"my-app-namespace:dev"', (
['localhost', '6379', '1', 'pass"word', 'my-app-namespace:dev']), 'localhost:6379:1:pass"word:"my-app-namespace:dev"',
["localhost", "6379", "1", 'pass"word', "my-app-namespace:dev"],
),
] ]
for src, args in params: for src, args in params:
self.assertEqual(args, parse_source_args(src)) self.assertEqual(args, parse_source_args(src))
class ReadOnlyTokenFileTestCase(unittest.TestCase): class ReadOnlyTokenFileTestCase(unittest.TestCase):
patch('os.path.isdir', MagicMock(return_value=False)) patch("os.path.isdir", MagicMock(return_value=False))
def test_empty(self): def test_empty(self):
plugin = ReadOnlyTokenFile('configfile') plugin = ReadOnlyTokenFile("configfile")
config = "" config = ""
pyopen = mock_open(read_data=config) pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True): with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost') result = plugin.lookup("testhost")
pyopen.assert_called_once_with('configfile') pyopen.assert_called_once_with("configfile")
self.assertIsNone(result) self.assertIsNone(result)
patch('os.path.isdir', MagicMock(return_value=False)) patch("os.path.isdir", MagicMock(return_value=False))
def test_simple(self): def test_simple(self):
plugin = ReadOnlyTokenFile('configfile') plugin = ReadOnlyTokenFile("configfile")
config = "testhost: remote_host:remote_port" config = "testhost: remote_host:remote_port"
pyopen = mock_open(read_data=config) pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True): with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost') result = plugin.lookup("testhost")
pyopen.assert_called_once_with('configfile') pyopen.assert_called_once_with("configfile")
self.assertIsNotNone(result) self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host") self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port") self.assertEqual(result[1], "remote_port")
patch('os.path.isdir', MagicMock(return_value=False)) patch("os.path.isdir", MagicMock(return_value=False))
def test_tabs(self): def test_tabs(self):
plugin = ReadOnlyTokenFile('configfile') plugin = ReadOnlyTokenFile("configfile")
config = "testhost:\tremote_host:remote_port" config = "testhost:\tremote_host:remote_port"
pyopen = mock_open(read_data=config) pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True): with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost') result = plugin.lookup("testhost")
pyopen.assert_called_once_with('configfile') pyopen.assert_called_once_with("configfile")
self.assertIsNotNone(result) self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host") self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port") self.assertEqual(result[1], "remote_port")
class JWSTokenTestCase(unittest.TestCase): class JWSTokenTestCase(unittest.TestCase):
def test_asymmetric_jws_token_plugin(self): def test_asymmetric_jws_token_plugin(self):
plugin = JWTTokenApi("./tests/fixtures/public.pem") plugin = JWTTokenApi("./tests/fixtures/public.pem")
@ -82,7 +95,9 @@ class JWSTokenTestCase(unittest.TestCase):
key = jwk.JWK() key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read() private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key) key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"}) jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize()) result = plugin.lookup(jwt_token.serialize())
@ -97,21 +112,26 @@ class JWSTokenTestCase(unittest.TestCase):
key = jwk.JWK() key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read() private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key) key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"}) jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize()) result = plugin.lookup(jwt_token.serialize())
self.assertIsNone(result) self.assertIsNone(result)
@patch('time.time') @patch("time.time")
def test_jwt_valid_time(self, mock_time): def test_jwt_valid_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem") plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK() key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read() private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key) key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 }) jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
mock_time.return_value = 150 mock_time.return_value = 150
@ -121,14 +141,17 @@ class JWSTokenTestCase(unittest.TestCase):
self.assertEqual(result[0], "remote_host") self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port") self.assertEqual(result[1], "remote_port")
@patch('time.time') @patch("time.time")
def test_jwt_early_time(self, mock_time): def test_jwt_early_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem") plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK() key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read() private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key) key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 }) jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
mock_time.return_value = 50 mock_time.return_value = 50
@ -136,14 +159,17 @@ class JWSTokenTestCase(unittest.TestCase):
self.assertIsNone(result) self.assertIsNone(result)
@patch('time.time') @patch("time.time")
def test_jwt_late_time(self, mock_time): def test_jwt_late_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem") plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK() key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read() private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key) key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 }) jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
mock_time.return_value = 250 mock_time.return_value = 250
@ -156,8 +182,10 @@ class JWSTokenTestCase(unittest.TestCase):
secret = open("./tests/fixtures/symmetric.key").read() secret = open("./tests/fixtures/symmetric.key").read()
key = jwk.JWK() key = jwk.JWK()
key.import_key(kty="oct",k=secret) key.import_key(kty="oct", k=secret)
jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"}) jwt_token = jwt.JWT(
{"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize()) result = plugin.lookup(jwt_token.serialize())
@ -171,8 +199,10 @@ class JWSTokenTestCase(unittest.TestCase):
secret = open("./tests/fixtures/symmetric.key").read() secret = open("./tests/fixtures/symmetric.key").read()
key = jwk.JWK() key = jwk.JWK()
key.import_key(kty="oct",k=secret) key.import_key(kty="oct", k=secret)
jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"}) jwt_token = jwt.JWT(
{"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize()) result = plugin.lookup(jwt_token.serialize())
@ -188,10 +218,14 @@ class JWSTokenTestCase(unittest.TestCase):
public_key_data = open("./tests/fixtures/public.pem", "rb").read() public_key_data = open("./tests/fixtures/public.pem", "rb").read()
private_key.import_from_pem(private_key_data) private_key.import_from_pem(private_key_data)
public_key.import_from_pem(public_key_data) public_key.import_from_pem(public_key_data)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"}) jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(private_key) jwt_token.make_signed_token(private_key)
jwe_token = jwt.JWT(header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"}, jwe_token = jwt.JWT(
claims=jwt_token.serialize()) header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"},
claims=jwt_token.serialize(),
)
jwe_token.make_encrypted_token(public_key) jwe_token.make_encrypted_token(public_key)
result = plugin.lookup(jwt_token.serialize()) result = plugin.lookup(jwt_token.serialize())
@ -200,103 +234,104 @@ class JWSTokenTestCase(unittest.TestCase):
self.assertEqual(result[0], "remote_host") self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port") self.assertEqual(result[1], "remote_port")
class TokenRedisTestCase(unittest.TestCase): class TokenRedisTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
try: try:
import redis import redis
except ImportError: except ImportError:
patcher = patch.dict(sys.modules, {'redis': MagicMock()}) patcher = patch.dict(sys.modules, {"redis": MagicMock()})
patcher.start() patcher.start()
self.addCleanup(patcher.stop) self.addCleanup(patcher.stop)
@patch('redis.Redis') @patch("redis.Redis")
def test_empty(self, mock_redis): def test_empty(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234') plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value instance = mock_redis.return_value
instance.get.return_value = None instance.get.return_value = None
result = plugin.lookup('testhost') result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost') instance.get.assert_called_once_with("testhost")
self.assertIsNone(result) self.assertIsNone(result)
@patch('redis.Redis') @patch("redis.Redis")
def test_simple(self, mock_redis): def test_simple(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234') plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value instance = mock_redis.return_value
instance.get.return_value = b'{"host": "remote_host:remote_port"}' instance.get.return_value = b'{"host": "remote_host:remote_port"}'
result = plugin.lookup('testhost') result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost') instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result) self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host') self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], 'remote_port') self.assertEqual(result[1], "remote_port")
@patch('redis.Redis') @patch("redis.Redis")
def test_json_token_with_spaces(self, mock_redis): def test_json_token_with_spaces(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234') plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value instance = mock_redis.return_value
instance.get.return_value = b' {"host": "remote_host:remote_port"} ' instance.get.return_value = b' {"host": "remote_host:remote_port"} '
result = plugin.lookup('testhost') result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost') instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result) self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host') self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], 'remote_port') self.assertEqual(result[1], "remote_port")
@patch('redis.Redis') @patch("redis.Redis")
def test_text_token(self, mock_redis): def test_text_token(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234') plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value instance = mock_redis.return_value
instance.get.return_value = b'remote_host:remote_port' instance.get.return_value = b"remote_host:remote_port"
result = plugin.lookup('testhost') result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost') instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result) self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host') self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], 'remote_port') self.assertEqual(result[1], "remote_port")
@patch('redis.Redis') @patch("redis.Redis")
def test_text_token_with_spaces(self, mock_redis): def test_text_token_with_spaces(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234') plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value instance = mock_redis.return_value
instance.get.return_value = b' remote_host:remote_port ' instance.get.return_value = b" remote_host:remote_port "
result = plugin.lookup('testhost') result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost') instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result) self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host') self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], 'remote_port') self.assertEqual(result[1], "remote_port")
@patch('redis.Redis') @patch("redis.Redis")
def test_invalid_token(self, mock_redis): def test_invalid_token(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234') plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value instance = mock_redis.return_value
instance.get.return_value = b'{"host": "remote_host:remote_port" ' instance.get.return_value = b'{"host": "remote_host:remote_port" '
result = plugin.lookup('testhost') result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost') instance.get.assert_called_once_with("testhost")
self.assertIsNone(result) self.assertIsNone(result)
@patch('redis.Redis') @patch("redis.Redis")
def test_token_without_namespace(self, mock_redis): def test_token_without_namespace(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234') plugin = TokenRedis("127.0.0.1:1234")
token = 'testhost' token = "testhost"
def mock_redis_get(key): def mock_redis_get(key):
self.assertEqual(key, token) self.assertEqual(key, token)
return b'remote_host:remote_port' return b"remote_host:remote_port"
instance = mock_redis.return_value instance = mock_redis.return_value
instance.get = mock_redis_get instance.get = mock_redis_get
@ -304,17 +339,17 @@ class TokenRedisTestCase(unittest.TestCase):
result = plugin.lookup(token) result = plugin.lookup(token)
self.assertIsNotNone(result) self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host') self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], 'remote_port') self.assertEqual(result[1], "remote_port")
@patch('redis.Redis') @patch("redis.Redis")
def test_token_with_namespace(self, mock_redis): def test_token_with_namespace(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234:::namespace') plugin = TokenRedis("127.0.0.1:1234:::namespace")
token = 'testhost' token = "testhost"
def mock_redis_get(key): def mock_redis_get(key):
self.assertEqual(key, "namespace:" + token) self.assertEqual(key, "namespace:" + token)
return b'remote_host:remote_port' return b"remote_host:remote_port"
instance = mock_redis.return_value instance = mock_redis.return_value
instance.get = mock_redis_get instance.get = mock_redis_get
@ -322,103 +357,103 @@ class TokenRedisTestCase(unittest.TestCase):
result = plugin.lookup(token) result = plugin.lookup(token)
self.assertIsNotNone(result) self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host') self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], 'remote_port') self.assertEqual(result[1], "remote_port")
def test_src_only_host(self): def test_src_only_host(self):
plugin = TokenRedis('127.0.0.1') plugin = TokenRedis("127.0.0.1")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0) self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None) self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_port(self): def test_src_with_host_port(self):
plugin = TokenRedis('127.0.0.1:1234') plugin = TokenRedis("127.0.0.1:1234")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0) self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None) self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db(self): def test_src_with_host_port_db(self):
plugin = TokenRedis('127.0.0.1:1234:2') plugin = TokenRedis("127.0.0.1:1234:2")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2) self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None) self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db_pass(self): def test_src_with_host_port_db_pass(self):
plugin = TokenRedis('127.0.0.1:1234:2:verysecret') plugin = TokenRedis("127.0.0.1:1234:2:verysecret")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2) self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret') self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db_pass_namespace(self): def test_src_with_host_port_db_pass_namespace(self):
plugin = TokenRedis('127.0.0.1:1234:2:verysecret:namespace') plugin = TokenRedis("127.0.0.1:1234:2:verysecret:namespace")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2) self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret') self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "namespace:") self.assertEqual(plugin._namespace, "namespace:")
def test_src_with_host_empty_port_empty_db_pass_no_namespace(self): def test_src_with_host_empty_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:::verysecret') plugin = TokenRedis("127.0.0.1:::verysecret")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0) self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, 'verysecret') self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_empty_namespace(self): def test_src_with_host_empty_port_empty_db_empty_pass_empty_namespace(self):
plugin = TokenRedis('127.0.0.1::::') plugin = TokenRedis("127.0.0.1::::")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0) self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None) self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_no_namespace(self): def test_src_with_host_empty_port_empty_db_empty_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:::') plugin = TokenRedis("127.0.0.1:::")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0) self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None) self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_no_pass_no_namespace(self): def test_src_with_host_empty_port_empty_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::') plugin = TokenRedis("127.0.0.1::")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0) self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None) self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_no_db_no_pass_no_namespace(self): def test_src_with_host_empty_port_no_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:') plugin = TokenRedis("127.0.0.1:")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0) self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None) self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_namespace(self): def test_src_with_host_empty_port_empty_db_empty_pass_namespace(self):
plugin = TokenRedis('127.0.0.1::::namespace') plugin = TokenRedis("127.0.0.1::::namespace")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0) self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None) self.assertEqual(plugin._password, None)
@ -427,43 +462,43 @@ class TokenRedisTestCase(unittest.TestCase):
def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self): def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self):
plugin = TokenRedis('127.0.0.1::::"ns1:ns2"') plugin = TokenRedis('127.0.0.1::::"ns1:ns2"')
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0) self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None) self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "ns1:ns2:") self.assertEqual(plugin._namespace, "ns1:ns2:")
def test_src_with_host_empty_port_db_no_pass_no_namespace(self): def test_src_with_host_empty_port_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2') plugin = TokenRedis("127.0.0.1::2")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2) self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None) self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_empty_db_pass_no_namespace(self): def test_src_with_host_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:1234::verysecret') plugin = TokenRedis("127.0.0.1:1234::verysecret")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0) self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, 'verysecret') self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_db_pass_no_namespace(self): def test_src_with_host_empty_port_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2:verysecret') plugin = TokenRedis("127.0.0.1::2:verysecret")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2) self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret') self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "") self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_db_empty_pass_no_namespace(self): def test_src_with_host_empty_port_db_empty_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2:') plugin = TokenRedis("127.0.0.1::2:")
self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2) self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None) self.assertEqual(plugin._password, None)

View File

@ -14,199 +14,268 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
""" Unit tests for websocket """ """Unit tests for websocket"""
import unittest import unittest
from websockify import websocket from websockify import websocket
class FakeSocket: class FakeSocket:
def __init__(self): def __init__(self):
self.data = b'' self.data = b""
def send(self, buf): def send(self, buf):
self.data += buf self.data += buf
return len(buf) return len(buf)
class AcceptTestCase(unittest.TestCase): class AcceptTestCase(unittest.TestCase):
def test_success(self): def test_success(self):
ws = websocket.WebSocket() ws = websocket.WebSocket()
sock = FakeSocket() sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket', ws.accept(
'Sec-WebSocket-Version': '13', sock,
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) {
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ') "upgrade": "websocket",
self.assertTrue(b'\r\nUpgrade: websocket\r\n' in sock.data) "Sec-WebSocket-Version": "13",
self.assertTrue(b'\r\nConnection: Upgrade\r\n' in sock.data) "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
self.assertTrue(b'\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n' in sock.data) },
)
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
self.assertTrue(b"\r\nUpgrade: websocket\r\n" in sock.data)
self.assertTrue(b"\r\nConnection: Upgrade\r\n" in sock.data)
self.assertTrue(
b"\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n" in sock.data
)
def test_bad_version(self): def test_bad_version(self):
ws = websocket.WebSocket() ws = websocket.WebSocket()
sock = FakeSocket() sock = FakeSocket()
self.assertRaises(Exception, ws.accept, self.assertRaises(
sock, {'upgrade': 'websocket', Exception,
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) ws.accept,
self.assertRaises(Exception, ws.accept, sock,
sock, {'upgrade': 'websocket', {"upgrade": "websocket", "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q=="},
'Sec-WebSocket-Version': '5', )
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) self.assertRaises(
self.assertRaises(Exception, ws.accept, Exception,
sock, {'upgrade': 'websocket', ws.accept,
'Sec-WebSocket-Version': '20', sock,
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) {
"upgrade": "websocket",
"Sec-WebSocket-Version": "5",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "20",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
def test_bad_upgrade(self): def test_bad_upgrade(self):
ws = websocket.WebSocket() ws = websocket.WebSocket()
sock = FakeSocket() sock = FakeSocket()
self.assertRaises(Exception, ws.accept, self.assertRaises(
sock, {'Sec-WebSocket-Version': '13', Exception,
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) ws.accept,
self.assertRaises(Exception, ws.accept, sock,
sock, {'upgrade': 'websocket2', {
'Sec-WebSocket-Version': '13', "Sec-WebSocket-Version": "13",
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket2",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
def test_missing_key(self): def test_missing_key(self):
ws = websocket.WebSocket() ws = websocket.WebSocket()
sock = FakeSocket() sock = FakeSocket()
self.assertRaises(Exception, ws.accept, self.assertRaises(
sock, {'upgrade': 'websocket', Exception,
'Sec-WebSocket-Version': '13'}) ws.accept,
sock,
{"upgrade": "websocket", "Sec-WebSocket-Version": "13"},
)
def test_protocol(self): def test_protocol(self):
class ProtoSocket(websocket.WebSocket): class ProtoSocket(websocket.WebSocket):
def select_subprotocol(self, protocol): def select_subprotocol(self, protocol):
return 'gazonk' return "gazonk"
ws = ProtoSocket() ws = ProtoSocket()
sock = FakeSocket() sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket', ws.accept(
'Sec-WebSocket-Version': '13', sock,
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', {
'Sec-WebSocket-Protocol': 'foobar gazonk'}) "upgrade": "websocket",
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ') "Sec-WebSocket-Version": "13",
self.assertTrue(b'\r\nSec-WebSocket-Protocol: gazonk\r\n' in sock.data) "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
"Sec-WebSocket-Protocol": "foobar gazonk",
},
)
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
self.assertTrue(b"\r\nSec-WebSocket-Protocol: gazonk\r\n" in sock.data)
def test_no_protocol(self): def test_no_protocol(self):
ws = websocket.WebSocket() ws = websocket.WebSocket()
sock = FakeSocket() sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket', ws.accept(
'Sec-WebSocket-Version': '13', sock,
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) {
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ') "upgrade": "websocket",
self.assertFalse(b'\r\nSec-WebSocket-Protocol:' in sock.data) "Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
self.assertFalse(b"\r\nSec-WebSocket-Protocol:" in sock.data)
def test_missing_protocol(self): def test_missing_protocol(self):
ws = websocket.WebSocket() ws = websocket.WebSocket()
sock = FakeSocket() sock = FakeSocket()
self.assertRaises(Exception, ws.accept, self.assertRaises(
sock, {'upgrade': 'websocket', Exception,
'Sec-WebSocket-Version': '13', ws.accept,
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', sock,
'Sec-WebSocket-Protocol': 'foobar gazonk'}) {
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
"Sec-WebSocket-Protocol": "foobar gazonk",
},
)
def test_protocol(self): def test_protocol(self):
class ProtoSocket(websocket.WebSocket): class ProtoSocket(websocket.WebSocket):
def select_subprotocol(self, protocol): def select_subprotocol(self, protocol):
return 'oddball' return "oddball"
ws = ProtoSocket() ws = ProtoSocket()
sock = FakeSocket() sock = FakeSocket()
self.assertRaises(Exception, ws.accept, self.assertRaises(
sock, {'upgrade': 'websocket', Exception,
'Sec-WebSocket-Version': '13', ws.accept,
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', sock,
'Sec-WebSocket-Protocol': 'foobar gazonk'}) {
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
"Sec-WebSocket-Protocol": "foobar gazonk",
},
)
class PingPongTest(unittest.TestCase): class PingPongTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.ws = websocket.WebSocket() self.ws = websocket.WebSocket()
self.sock = FakeSocket() self.sock = FakeSocket()
self.ws.accept(self.sock, {'upgrade': 'websocket', self.ws.accept(
'Sec-WebSocket-Version': '13', self.sock,
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) {
self.assertEqual(self.sock.data[:13], b'HTTP/1.1 101 ') "upgrade": "websocket",
self.sock.data = b'' "Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertEqual(self.sock.data[:13], b"HTTP/1.1 101 ")
self.sock.data = b""
def test_ping(self): def test_ping(self):
self.ws.ping() self.ws.ping()
self.assertEqual(self.sock.data, b'\x89\x00') self.assertEqual(self.sock.data, b"\x89\x00")
def test_pong(self): def test_pong(self):
self.ws.pong() self.ws.pong()
self.assertEqual(self.sock.data, b'\x8a\x00') self.assertEqual(self.sock.data, b"\x8a\x00")
def test_ping_data(self): def test_ping_data(self):
self.ws.ping(b'foo') self.ws.ping(b"foo")
self.assertEqual(self.sock.data, b'\x89\x03foo') self.assertEqual(self.sock.data, b"\x89\x03foo")
def test_pong_data(self): def test_pong_data(self):
self.ws.pong(b'foo') self.ws.pong(b"foo")
self.assertEqual(self.sock.data, b'\x8a\x03foo') self.assertEqual(self.sock.data, b"\x8a\x03foo")
class HyBiEncodeDecodeTestCase(unittest.TestCase): class HyBiEncodeDecodeTestCase(unittest.TestCase):
def test_decode_hybi_text(self): def test_decode_hybi_text(self):
buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58' buf = b"\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58"
ws = websocket.WebSocket() ws = websocket.WebSocket()
res = ws._decode_hybi(buf) res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1) self.assertEqual(res["fin"], 1)
self.assertEqual(res['opcode'], 0x1) self.assertEqual(res["opcode"], 0x1)
self.assertEqual(res['masked'], True) self.assertEqual(res["masked"], True)
self.assertEqual(res['length'], len(buf)) self.assertEqual(res["length"], len(buf))
self.assertEqual(res['payload'], b'Hello') self.assertEqual(res["payload"], b"Hello")
def test_decode_hybi_binary(self): def test_decode_hybi_binary(self):
buf = b'\x82\x04\x01\x02\x03\x04' buf = b"\x82\x04\x01\x02\x03\x04"
ws = websocket.WebSocket() ws = websocket.WebSocket()
res = ws._decode_hybi(buf) res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1) self.assertEqual(res["fin"], 1)
self.assertEqual(res['opcode'], 0x2) self.assertEqual(res["opcode"], 0x2)
self.assertEqual(res['length'], len(buf)) self.assertEqual(res["length"], len(buf))
self.assertEqual(res['payload'], b'\x01\x02\x03\x04') self.assertEqual(res["payload"], b"\x01\x02\x03\x04")
def test_decode_hybi_extended_16bit_binary(self): def test_decode_hybi_extended_16bit_binary(self):
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260 data = b"\x01\x02\x03\x04" * 65 # len > 126 -- len == 260
buf = b'\x82\x7e\x01\x04' + data buf = b"\x82\x7e\x01\x04" + data
ws = websocket.WebSocket() ws = websocket.WebSocket()
res = ws._decode_hybi(buf) res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1) self.assertEqual(res["fin"], 1)
self.assertEqual(res['opcode'], 0x2) self.assertEqual(res["opcode"], 0x2)
self.assertEqual(res['length'], len(buf)) self.assertEqual(res["length"], len(buf))
self.assertEqual(res['payload'], data) self.assertEqual(res["payload"], data)
def test_decode_hybi_extended_64bit_binary(self): def test_decode_hybi_extended_64bit_binary(self):
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260 data = b"\x01\x02\x03\x04" * 65 # len > 126 -- len == 260
buf = b'\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04' + data buf = b"\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04" + data
ws = websocket.WebSocket() ws = websocket.WebSocket()
res = ws._decode_hybi(buf) res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1) self.assertEqual(res["fin"], 1)
self.assertEqual(res['opcode'], 0x2) self.assertEqual(res["opcode"], 0x2)
self.assertEqual(res['length'], len(buf)) self.assertEqual(res["length"], len(buf))
self.assertEqual(res['payload'], data) self.assertEqual(res["payload"], data)
def test_decode_hybi_multi(self): def test_decode_hybi_multi(self):
buf1 = b'\x01\x03\x48\x65\x6c' buf1 = b"\x01\x03\x48\x65\x6c"
buf2 = b'\x80\x02\x6c\x6f' buf2 = b"\x80\x02\x6c\x6f"
ws = websocket.WebSocket() ws = websocket.WebSocket()
res1 = ws._decode_hybi(buf1) res1 = ws._decode_hybi(buf1)
self.assertEqual(res1['fin'], 0) self.assertEqual(res1["fin"], 0)
self.assertEqual(res1['opcode'], 0x1) self.assertEqual(res1["opcode"], 0x1)
self.assertEqual(res1['length'], len(buf1)) self.assertEqual(res1["length"], len(buf1))
self.assertEqual(res1['payload'], b'Hel') self.assertEqual(res1["payload"], b"Hel")
res2 = ws._decode_hybi(buf2) res2 = ws._decode_hybi(buf2)
self.assertEqual(res2['fin'], 1) self.assertEqual(res2["fin"], 1)
self.assertEqual(res2['opcode'], 0x0) self.assertEqual(res2["opcode"], 0x0)
self.assertEqual(res2['length'], len(buf2)) self.assertEqual(res2["length"], len(buf2))
self.assertEqual(res2['payload'], b'lo') self.assertEqual(res2["payload"], b"lo")
def test_encode_hybi_basic(self): def test_encode_hybi_basic(self):
ws = websocket.WebSocket() ws = websocket.WebSocket()
res = ws._encode_hybi(0x1, b'Hello') res = ws._encode_hybi(0x1, b"Hello")
expected = b'\x81\x05\x48\x65\x6c\x6c\x6f' expected = b"\x81\x05\x48\x65\x6c\x6c\x6f"
self.assertEqual(res, expected) self.assertEqual(res, expected)

View File

@ -14,7 +14,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
""" Unit tests for websocketproxy """ """Unit tests for websocketproxy"""
import sys import sys
import unittest import unittest
@ -30,7 +30,7 @@ from websockify import auth_plugins
class FakeSocket: class FakeSocket:
def __init__(self, data=b''): def __init__(self, data=b""):
self._data = data self._data = data
def recv(self, amt, flags=None): def recv(self, amt, flags=None):
@ -40,11 +40,11 @@ class FakeSocket:
return res return res
def makefile(self, mode='r', buffsize=None): def makefile(self, mode="r", buffsize=None):
if 'b' in mode: if "b" in mode:
return BytesIO(self._data) return BytesIO(self._data)
else: else:
return StringIO(self._data.decode('latin_1')) return StringIO(self._data.decode("latin_1"))
class FakeServer: class FakeServer:
@ -58,14 +58,16 @@ class FakeServer:
self.ssl_target = None self.ssl_target = None
self.unix_target = None self.unix_target = None
class ProxyRequestHandlerTestCase(unittest.TestCase): class ProxyRequestHandlerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.handler = websocketproxy.ProxyRequestHandler( self.handler = websocketproxy.ProxyRequestHandler(
FakeSocket(), "127.0.0.1", FakeServer()) FakeSocket(), "127.0.0.1", FakeServer()
)
self.handler.path = "https://localhost:6080/websockify?token=blah" self.handler.path = "https://localhost:6080/websockify?token=blah"
self.handler.headers = {} self.handler.headers = {}
patch('websockify.websockifyserver.WebSockifyServer.socket').start() patch("websockify.websockifyserver.WebSockifyServer.socket").start()
def tearDown(self): def tearDown(self):
patch.stopall() patch.stopall()
@ -76,8 +78,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
def lookup(self, token): def lookup(self, token):
return ("some host", "some port") return ("some host", "some port")
host, port = self.handler.get_target( host, port = self.handler.get_target(TestPlugin(None))
TestPlugin(None))
self.assertEqual(host, "some host") self.assertEqual(host, "some host")
self.assertEqual(port, "some port") self.assertEqual(port, "some port")
@ -87,8 +88,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
def lookup(self, token): def lookup(self, token):
return ("unix_socket", "/tmp/socket") return ("unix_socket", "/tmp/socket")
_, socket = self.handler.get_target( _, socket = self.handler.get_target(TestPlugin(None))
TestPlugin(None))
self.assertEqual(socket, "/tmp/socket") self.assertEqual(socket, "/tmp/socket")
@ -100,11 +100,11 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
with self.assertRaises(FakeServer.EClose): with self.assertRaises(FakeServer.EClose):
self.handler.get_target(TestPlugin(None)) self.handler.get_target(TestPlugin(None))
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock()) @patch("websockify.websocketproxy.ProxyRequestHandler.send_auth_error", MagicMock())
def test_token_plugin(self): def test_token_plugin(self):
class TestPlugin(token_plugins.BasePlugin): class TestPlugin(token_plugins.BasePlugin):
def lookup(self, token): def lookup(self, token):
return (self.source + token).split(',') return (self.source + token).split(",")
self.handler.server.token_plugin = TestPlugin("somehost,") self.handler.server.token_plugin = TestPlugin("somehost,")
self.handler.validate_connection() self.handler.validate_connection()
@ -112,7 +112,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
self.assertEqual(self.handler.server.target_host, "somehost") self.assertEqual(self.handler.server.target_host, "somehost")
self.assertEqual(self.handler.server.target_port, "blah") self.assertEqual(self.handler.server.target_port, "blah")
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock()) @patch("websockify.websocketproxy.ProxyRequestHandler.send_auth_error", MagicMock())
def test_auth_plugin(self): def test_auth_plugin(self):
class TestPlugin(auth_plugins.BasePlugin): class TestPlugin(auth_plugins.BasePlugin):
def authenticate(self, headers, target_host, target_port): def authenticate(self, headers, target_host, target_port):
@ -128,4 +128,3 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
self.handler.server.target_host = "someotherhost" self.handler.server.target_host = "someotherhost"
self.handler.auth_connection() self.handler.auth_connection()

View File

@ -1,5 +1,5 @@
"""Unit tests for websocketserver"""
""" Unit tests for websocketserver """
import unittest import unittest
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
@ -66,4 +66,3 @@ class HttpWebSocketTest(unittest.TestCase):
# Then # Then
req_obj.end_headers.assert_called_once_with() req_obj.end_headers.assert_called_once_with()

View File

@ -14,7 +14,8 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
""" Unit tests for websockifyserver """ """Unit tests for websockifyserver"""
import errno import errno
import os import os
import logging import logging
@ -36,11 +37,11 @@ from websockify import websockifyserver
def raise_oserror(*args, **kwargs): def raise_oserror(*args, **kwargs):
raise OSError('fake error') raise OSError("fake error")
class FakeSocket: class FakeSocket:
def __init__(self, data=b''): def __init__(self, data=b""):
self._data = data self._data = data
def recv(self, amt, flags=None): def recv(self, amt, flags=None):
@ -50,19 +51,19 @@ class FakeSocket:
return res return res
def makefile(self, mode='r', buffsize=None): def makefile(self, mode="r", buffsize=None):
if 'b' in mode: if "b" in mode:
return BytesIO(self._data) return BytesIO(self._data)
else: else:
return StringIO(self._data.decode('latin_1')) return StringIO(self._data.decode("latin_1"))
class WebSockifyRequestHandlerTestCase(unittest.TestCase): class WebSockifyRequestHandlerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.tmpdir = tempfile.mkdtemp('-websockify-tests') self.tmpdir = tempfile.mkdtemp("-websockify-tests")
# Mock this out cause it screws tests up # Mock this out cause it screws tests up
patch('os.chdir').start() patch("os.chdir").start()
def tearDown(self): def tearDown(self):
"""Called automatically after each test.""" """Called automatically after each test."""
@ -70,31 +71,41 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase):
os.rmdir(self.tmpdir) os.rmdir(self.tmpdir)
super().tearDown() super().tearDown()
def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler, def _get_server(
**kwargs): self, handler_class=websockifyserver.WebSockifyRequestHandler, **kwargs
web = kwargs.pop('web', self.tmpdir) ):
web = kwargs.pop("web", self.tmpdir)
return websockifyserver.WebSockifyServer( return websockifyserver.WebSockifyServer(
handler_class, listen_host='localhost', handler_class,
listen_port=80, key=self.tmpdir, web=web, listen_host="localhost",
record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1, listen_port=80,
**kwargs) key=self.tmpdir,
web=web,
record=self.tmpdir,
daemon=False,
ssl_only=0,
idle_timeout=1,
**kwargs,
)
@patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error') @patch("websockify.websockifyserver.WebSockifyRequestHandler.send_error")
def test_normal_get_with_only_upgrade_returns_error(self, send_error): def test_normal_get_with_only_upgrade_returns_error(self, send_error):
server = self._get_server(web=None) server = self._get_server(web=None)
handler = websockifyserver.WebSockifyRequestHandler( handler = websockifyserver.WebSockifyRequestHandler(
FakeSocket(b'GET /tmp.txt HTTP/1.1'), '127.0.0.1', server) FakeSocket(b"GET /tmp.txt HTTP/1.1"), "127.0.0.1", server
)
handler.do_GET() handler.do_GET()
send_error.assert_called_with(405) send_error.assert_called_with(405)
@patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error') @patch("websockify.websockifyserver.WebSockifyRequestHandler.send_error")
def test_list_dir_with_file_only_returns_error(self, send_error): def test_list_dir_with_file_only_returns_error(self, send_error):
server = self._get_server(file_only=True) server = self._get_server(file_only=True)
handler = websockifyserver.WebSockifyRequestHandler( handler = websockifyserver.WebSockifyRequestHandler(
FakeSocket(b'GET / HTTP/1.1'), '127.0.0.1', server) FakeSocket(b"GET / HTTP/1.1"), "127.0.0.1", server
)
handler.path = '/' handler.path = "/"
handler.do_GET() handler.do_GET()
send_error.assert_called_with(404) send_error.assert_called_with(404)
@ -102,9 +113,9 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase):
class WebSockifyServerTestCase(unittest.TestCase): class WebSockifyServerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.tmpdir = tempfile.mkdtemp('-websockify-tests') self.tmpdir = tempfile.mkdtemp("-websockify-tests")
# Mock this out cause it screws tests up # Mock this out cause it screws tests up
patch('os.chdir').start() patch("os.chdir").start()
def tearDown(self): def tearDown(self):
"""Called automatically after each test.""" """Called automatically after each test."""
@ -112,32 +123,38 @@ class WebSockifyServerTestCase(unittest.TestCase):
os.rmdir(self.tmpdir) os.rmdir(self.tmpdir)
super().tearDown() super().tearDown()
def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler, def _get_server(
**kwargs): self, handler_class=websockifyserver.WebSockifyRequestHandler, **kwargs
):
return websockifyserver.WebSockifyServer( return websockifyserver.WebSockifyServer(
handler_class, listen_host='localhost', handler_class,
listen_port=80, key=self.tmpdir, web=self.tmpdir, listen_host="localhost",
record=self.tmpdir, **kwargs) listen_port=80,
key=self.tmpdir,
web=self.tmpdir,
record=self.tmpdir,
**kwargs,
)
def test_daemonize_raises_error_while_closing_fds(self): def test_daemonize_raises_error_while_closing_fds(self):
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
patch('os.fork').start().return_value = 0 patch("os.fork").start().return_value = 0
patch('signal.signal').start() patch("signal.signal").start()
patch('os.setsid').start() patch("os.setsid").start()
patch('os.close').start().side_effect = raise_oserror patch("os.close").start().side_effect = raise_oserror
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') self.assertRaises(OSError, server.daemonize, keepfd=None, chdir="./")
def test_daemonize_ignores_ebadf_error_while_closing_fds(self): def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
def raise_oserror_ebadf(fd): def raise_oserror_ebadf(fd):
raise OSError(errno.EBADF, 'fake error') raise OSError(errno.EBADF, "fake error")
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
patch('os.fork').start().return_value = 0 patch("os.fork").start().return_value = 0
patch('signal.signal').start() patch("signal.signal").start()
patch('os.setsid').start() patch("os.setsid").start()
patch('os.close').start().side_effect = raise_oserror_ebadf patch("os.close").start().side_effect = raise_oserror_ebadf
patch('os.open').start().side_effect = raise_oserror patch("os.open").start().side_effect = raise_oserror
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') self.assertRaises(OSError, server.daemonize, keepfd=None, chdir="./")
def test_handshake_fails_on_not_ready(self): def test_handshake_fails_on_not_ready(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
@ -145,23 +162,29 @@ class WebSockifyServerTestCase(unittest.TestCase):
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
return ([], [], []) return ([], [], [])
patch('select.select').start().side_effect = fake_select patch("select.select").start().side_effect = fake_select
self.assertRaises( self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake, websockifyserver.WebSockifyServer.EClose,
FakeSocket(), '127.0.0.1') server.do_handshake,
FakeSocket(),
"127.0.0.1",
)
def test_empty_handshake_fails(self): def test_empty_handshake_fails(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
sock = FakeSocket('') sock = FakeSocket("")
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], []) return ([sock], [], [])
patch('select.select').start().side_effect = fake_select patch("select.select").start().side_effect = fake_select
self.assertRaises( self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake, websockifyserver.WebSockifyServer.EClose,
sock, '127.0.0.1') server.do_handshake,
sock,
"127.0.0.1",
)
def test_handshake_policy_request(self): def test_handshake_policy_request(self):
# TODO(directxman12): implement # TODO(directxman12): implement
@ -170,35 +193,39 @@ class WebSockifyServerTestCase(unittest.TestCase):
def test_handshake_ssl_only_without_ssl_raises_error(self): def test_handshake_ssl_only_without_ssl_raises_error(self):
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
sock = FakeSocket(b'some initial data') sock = FakeSocket(b"some initial data")
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], []) return ([sock], [], [])
patch('select.select').start().side_effect = fake_select patch("select.select").start().side_effect = fake_select
self.assertRaises( self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake, websockifyserver.WebSockifyServer.EClose,
sock, '127.0.0.1') server.do_handshake,
sock,
"127.0.0.1",
)
def test_do_handshake_no_ssl(self): def test_do_handshake_no_ssl(self):
class FakeHandler: class FakeHandler:
CALLED = False CALLED = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
type(self).CALLED = True type(self).CALLED = True
FakeHandler.CALLED = False FakeHandler.CALLED = False
server = self._get_server( server = self._get_server(
handler_class=FakeHandler, daemon=True, handler_class=FakeHandler, daemon=True, ssl_only=0, idle_timeout=1
ssl_only=0, idle_timeout=1) )
sock = FakeSocket(b'some initial data') sock = FakeSocket(b"some initial data")
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], []) return ([sock], [], [])
patch('select.select').start().side_effect = fake_select patch("select.select").start().side_effect = fake_select
self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock) self.assertEqual(server.do_handshake(sock, "127.0.0.1"), sock)
self.assertTrue(FakeHandler.CALLED, True) self.assertTrue(FakeHandler.CALLED, True)
def test_do_handshake_ssl(self): def test_do_handshake_ssl(self):
@ -210,18 +237,22 @@ class WebSockifyServerTestCase(unittest.TestCase):
pass pass
def test_do_handshake_ssl_without_cert_raises_error(self): def test_do_handshake_ssl_without_cert_raises_error(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1, server = self._get_server(
cert='afdsfasdafdsafdsafdsafdas') daemon=True, ssl_only=0, idle_timeout=1, cert="afdsfasdafdsafdsafdsafdas"
)
sock = FakeSocket(b"\x16some ssl data") sock = FakeSocket(b"\x16some ssl data")
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], []) return ([sock], [], [])
patch('select.select').start().side_effect = fake_select patch("select.select").start().side_effect = fake_select
self.assertRaises( self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake, websockifyserver.WebSockifyServer.EClose,
sock, '127.0.0.1') server.do_handshake,
sock,
"127.0.0.1",
)
def test_do_handshake_ssl_error_eof_raises_close_error(self): def test_do_handshake_ssl_error_eof_raises_close_error(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
@ -234,58 +265,79 @@ class WebSockifyServerTestCase(unittest.TestCase):
def fake_wrap_socket(*args, **kwargs): def fake_wrap_socket(*args, **kwargs):
raise ssl.SSLError(ssl.SSL_ERROR_EOF) raise ssl.SSLError(ssl.SSL_ERROR_EOF)
class fake_create_default_context(): class fake_create_default_context:
def __init__(self, purpose): def __init__(self, purpose):
self.verify_mode = None self.verify_mode = None
self.options = 0 self.options = 0
def load_cert_chain(self, certfile, keyfile, password): def load_cert_chain(self, certfile, keyfile, password):
pass pass
def set_default_verify_paths(self): def set_default_verify_paths(self):
pass pass
def load_verify_locations(self, cafile): def load_verify_locations(self, cafile):
pass pass
def wrap_socket(self, *args, **kwargs): def wrap_socket(self, *args, **kwargs):
raise ssl.SSLError(ssl.SSL_ERROR_EOF) raise ssl.SSLError(ssl.SSL_ERROR_EOF)
patch('select.select').start().side_effect = fake_select patch("select.select").start().side_effect = fake_select
patch('ssl.create_default_context').start().side_effect = fake_create_default_context patch(
"ssl.create_default_context"
).start().side_effect = fake_create_default_context
self.assertRaises( self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake, websockifyserver.WebSockifyServer.EClose,
sock, '127.0.0.1') server.do_handshake,
sock,
"127.0.0.1",
)
def test_do_handshake_ssl_sets_ciphers(self): def test_do_handshake_ssl_sets_ciphers(self):
test_ciphers = 'TEST-CIPHERS-1:TEST-CIPHER-2' test_ciphers = "TEST-CIPHERS-1:TEST-CIPHER-2"
class FakeHandler: class FakeHandler:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
server = self._get_server(handler_class=FakeHandler, daemon=True, server = self._get_server(
idle_timeout=1, ssl_ciphers=test_ciphers) handler_class=FakeHandler,
daemon=True,
idle_timeout=1,
ssl_ciphers=test_ciphers,
)
sock = FakeSocket(b"\x16some ssl data") sock = FakeSocket(b"\x16some ssl data")
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], []) return ([sock], [], [])
class fake_create_default_context(): class fake_create_default_context:
CIPHERS = '' CIPHERS = ""
def __init__(self, purpose): def __init__(self, purpose):
self.verify_mode = None self.verify_mode = None
self.options = 0 self.options = 0
def load_cert_chain(self, certfile, keyfile, password): def load_cert_chain(self, certfile, keyfile, password):
pass pass
def set_default_verify_paths(self): def set_default_verify_paths(self):
pass pass
def load_verify_locations(self, cafile): def load_verify_locations(self, cafile):
pass pass
def wrap_socket(self, *args, **kwargs): def wrap_socket(self, *args, **kwargs):
pass pass
def set_ciphers(self, ciphers_to_set): def set_ciphers(self, ciphers_to_set):
fake_create_default_context.CIPHERS = ciphers_to_set fake_create_default_context.CIPHERS = ciphers_to_set
patch('select.select').start().side_effect = fake_select patch("select.select").start().side_effect = fake_select
patch('ssl.create_default_context').start().side_effect = fake_create_default_context patch(
server.do_handshake(sock, '127.0.0.1') "ssl.create_default_context"
).start().side_effect = fake_create_default_context
server.do_handshake(sock, "127.0.0.1")
self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers) self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers)
def test_do_handshake_ssl_sets_opions(self): def test_do_handshake_ssl_sets_opions(self):
@ -295,8 +347,12 @@ class WebSockifyServerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
server = self._get_server(handler_class=FakeHandler, daemon=True, server = self._get_server(
idle_timeout=1, ssl_options=test_options) handler_class=FakeHandler,
daemon=True,
idle_timeout=1,
ssl_options=test_options,
)
sock = FakeSocket(b"\x16some ssl data") sock = FakeSocket(b"\x16some ssl data")
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
@ -304,26 +360,36 @@ class WebSockifyServerTestCase(unittest.TestCase):
class fake_create_default_context: class fake_create_default_context:
OPTIONS = 0 OPTIONS = 0
def __init__(self, purpose): def __init__(self, purpose):
self.verify_mode = None self.verify_mode = None
self._options = 0 self._options = 0
def load_cert_chain(self, certfile, keyfile, password): def load_cert_chain(self, certfile, keyfile, password):
pass pass
def set_default_verify_paths(self): def set_default_verify_paths(self):
pass pass
def load_verify_locations(self, cafile): def load_verify_locations(self, cafile):
pass pass
def wrap_socket(self, *args, **kwargs): def wrap_socket(self, *args, **kwargs):
pass pass
def get_options(self): def get_options(self):
return self._options return self._options
def set_options(self, val): def set_options(self, val):
fake_create_default_context.OPTIONS = val fake_create_default_context.OPTIONS = val
options = property(get_options, set_options) options = property(get_options, set_options)
patch('select.select').start().side_effect = fake_select patch("select.select").start().side_effect = fake_select
patch('ssl.create_default_context').start().side_effect = fake_create_default_context patch(
server.do_handshake(sock, '127.0.0.1') "ssl.create_default_context"
).start().side_effect = fake_create_default_context
server.do_handshake(sock, "127.0.0.1")
self.assertEqual(fake_create_default_context.OPTIONS, test_options) self.assertEqual(fake_create_default_context.OPTIONS, test_options)
def test_fallback_sigchld_handler(self): def test_fallback_sigchld_handler(self):
@ -332,38 +398,38 @@ class WebSockifyServerTestCase(unittest.TestCase):
def test_start_server_error(self): def test_start_server_error(self):
server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1) server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
sock = server.socket('localhost') sock = server.socket("localhost")
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
raise Exception("fake error") raise Exception("fake error")
patch('websockify.websockifyserver.WebSockifyServer.socket').start() patch("websockify.websockifyserver.WebSockifyServer.socket").start()
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start() patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
patch('select.select').start().side_effect = fake_select patch("select.select").start().side_effect = fake_select
server.start_server() server.start_server()
def test_start_server_keyboardinterrupt(self): def test_start_server_keyboardinterrupt(self):
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
sock = server.socket('localhost') sock = server.socket("localhost")
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
raise KeyboardInterrupt raise KeyboardInterrupt
patch('websockify.websockifyserver.WebSockifyServer.socket').start() patch("websockify.websockifyserver.WebSockifyServer.socket").start()
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start() patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
patch('select.select').start().side_effect = fake_select patch("select.select").start().side_effect = fake_select
server.start_server() server.start_server()
def test_start_server_systemexit(self): def test_start_server_systemexit(self):
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
sock = server.socket('localhost') sock = server.socket("localhost")
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
sys.exit() sys.exit()
patch('websockify.websockifyserver.WebSockifyServer.socket').start() patch("websockify.websockifyserver.WebSockifyServer.socket").start()
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start() patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
patch('select.select').start().side_effect = fake_select patch("select.select").start().side_effect = fake_select
server.start_server() server.start_server()
def test_socket_set_keepalive_options(self): def test_socket_set_keepalive_options(self):
@ -372,29 +438,37 @@ class WebSockifyServerTestCase(unittest.TestCase):
keepintvl = 56 keepintvl = 56
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
sock = server.socket('localhost', sock = server.socket(
"localhost",
tcp_keepcnt=keepcnt, tcp_keepcnt=keepcnt,
tcp_keepidle=keepidle, tcp_keepidle=keepidle,
tcp_keepintvl=keepintvl) tcp_keepintvl=keepintvl,
)
if hasattr(socket, 'TCP_KEEPCNT'): if hasattr(socket, "TCP_KEEPCNT"):
self.assertEqual(sock.getsockopt(socket.SOL_TCP, self.assertEqual(
socket.TCP_KEEPCNT), keepcnt) sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt
self.assertEqual(sock.getsockopt(socket.SOL_TCP, )
socket.TCP_KEEPIDLE), keepidle) self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE), keepidle)
self.assertEqual(sock.getsockopt(socket.SOL_TCP, self.assertEqual(
socket.TCP_KEEPINTVL), keepintvl) sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl
)
sock = server.socket('localhost', sock = server.socket(
"localhost",
tcp_keepalive=False, tcp_keepalive=False,
tcp_keepcnt=keepcnt, tcp_keepcnt=keepcnt,
tcp_keepidle=keepidle, tcp_keepidle=keepidle,
tcp_keepintvl=keepintvl) tcp_keepintvl=keepintvl,
)
if hasattr(socket, 'TCP_KEEPCNT'): if hasattr(socket, "TCP_KEEPCNT"):
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, self.assertNotEqual(
socket.TCP_KEEPCNT), keepcnt) sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, )
socket.TCP_KEEPIDLE), keepidle) self.assertNotEqual(
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE), keepidle
socket.TCP_KEEPINTVL), keepintvl) )
self.assertNotEqual(
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl
)

View File

@ -1,4 +1,4 @@
import websockify import websockify
if __name__ == '__main__': if __name__ == "__main__":
websockify.websocketproxy.websockify_init() websockify.websocketproxy.websockify_init()

View File

@ -1,4 +1,4 @@
class BasePlugin(): class BasePlugin:
def __init__(self, src=None): def __init__(self, src=None):
self.source = src self.source = src
@ -7,7 +7,9 @@ class BasePlugin():
class AuthenticationError(Exception): class AuthenticationError(Exception):
def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None): def __init__(
self, log_msg=None, response_code=403, response_headers={}, response_msg=None
):
self.code = response_code self.code = response_code
self.headers = response_headers self.headers = response_headers
self.msg = response_msg self.msg = response_msg
@ -15,7 +17,7 @@ class AuthenticationError(Exception):
if log_msg is None: if log_msg is None:
log_msg = response_msg log_msg = response_msg
super().__init__('%s %s' % (self.code, log_msg)) super().__init__("%s %s" % (self.code, log_msg))
class InvalidOriginError(AuthenticationError): class InvalidOriginError(AuthenticationError):
@ -24,12 +26,13 @@ class InvalidOriginError(AuthenticationError):
self.actual_origin = actual self.actual_origin = actual
super().__init__( super().__init__(
response_msg='Invalid Origin', response_msg="Invalid Origin",
log_msg="Invalid Origin Header: Expected one of " log_msg="Invalid Origin Header: Expected one of "
"%s, got '%s'" % (expected, actual)) "%s, got '%s'" % (expected, actual),
)
class BasicHTTPAuth(): class BasicHTTPAuth:
"""Verifies Basic Auth headers. Specify src as username:password""" """Verifies Basic Auth headers. Specify src as username:password"""
def __init__(self, src=None): def __init__(self, src=None):
@ -37,9 +40,10 @@ class BasicHTTPAuth():
def authenticate(self, headers, target_host, target_port): def authenticate(self, headers, target_host, target_port):
import base64 import base64
auth_header = headers.get('Authorization')
auth_header = headers.get("Authorization")
if auth_header: if auth_header:
if not auth_header.startswith('Basic '): if not auth_header.startswith("Basic "):
self.auth_error() self.auth_error()
try: try:
@ -49,11 +53,11 @@ class BasicHTTPAuth():
try: try:
# http://stackoverflow.com/questions/7242316/what-encoding-should-i-use-for-http-basic-authentication # http://stackoverflow.com/questions/7242316/what-encoding-should-i-use-for-http-basic-authentication
user_pass_as_text = user_pass_raw.decode('ISO-8859-1') user_pass_as_text = user_pass_raw.decode("ISO-8859-1")
except UnicodeDecodeError: except UnicodeDecodeError:
self.auth_error() self.auth_error()
user_pass = user_pass_as_text.split(':', 1) user_pass = user_pass_as_text.split(":", 1)
if len(user_pass) != 2: if len(user_pass) != 2:
self.auth_error() self.auth_error()
@ -64,7 +68,7 @@ class BasicHTTPAuth():
self.demand_auth() self.demand_auth()
def validate_creds(self, username, password): def validate_creds(self, username, password):
if '%s:%s' % (username, password) == self.src: if "%s:%s" % (username, password) == self.src:
return True return True
else: else:
return False return False
@ -73,10 +77,13 @@ class BasicHTTPAuth():
raise AuthenticationError(response_code=403) raise AuthenticationError(response_code=403)
def demand_auth(self): def demand_auth(self):
raise AuthenticationError(response_code=401, raise AuthenticationError(
response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'}) response_code=401,
response_headers={"WWW-Authenticate": 'Basic realm="Websockify"'},
)
class ExpectOrigin():
class ExpectOrigin:
def __init__(self, src=None): def __init__(self, src=None):
if src is None: if src is None:
self.source = [] self.source = []
@ -84,11 +91,12 @@ class ExpectOrigin():
self.source = src.split() self.source = src.split()
def authenticate(self, headers, target_host, target_port): def authenticate(self, headers, target_host, target_port):
origin = headers.get('Origin', None) origin = headers.get("Origin", None)
if origin is None or origin not in self.source: if origin is None or origin not in self.source:
raise InvalidOriginError(expected=self.source, actual=origin) raise InvalidOriginError(expected=self.source, actual=origin)
class ClientCertCNAuth():
class ClientCertCNAuth:
"""Verifies client by SSL certificate. Specify src as whitespace separated list of common names.""" """Verifies client by SSL certificate. Specify src as whitespace separated list of common names."""
def __init__(self, src=None): def __init__(self, src=None):
@ -98,5 +106,5 @@ class ClientCertCNAuth():
self.source = src.split() self.source = src.split()
def authenticate(self, headers, target_host, target_port): def authenticate(self, headers, target_host, target_port):
if headers.get('SSL_CLIENT_S_DN_CN', None) not in self.source: if headers.get("SSL_CLIENT_S_DN_CN", None) not in self.source:
raise AuthenticationError(response_code=403) raise AuthenticationError(response_code=403)

View File

@ -7,23 +7,26 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
as defined by RFC 5424. as defined by RFC 5424.
""" """
_legacy_head_fmt = '<{pri}>{ident}[{pid}]: ' _legacy_head_fmt = "<{pri}>{ident}[{pid}]: "
_rfc5424_head_fmt = '<{pri}>1 {timestamp} {hostname} {ident} {pid} - - ' _rfc5424_head_fmt = "<{pri}>1 {timestamp} {hostname} {ident} {pid} - - "
_head_fmt = _rfc5424_head_fmt _head_fmt = _rfc5424_head_fmt
_legacy = False _legacy = False
_timestamp_fmt = '%Y-%m-%dT%H:%M:%SZ' _timestamp_fmt = "%Y-%m-%dT%H:%M:%SZ"
_max_hostname = 255 _max_hostname = 255
_max_ident = 24 #safer for old daemons _max_ident = 24 # safer for old daemons
_send_length = False _send_length = False
_tail = '\n' _tail = "\n"
ident = None ident = None
def __init__(
def __init__(self, address=('localhost', handlers.SYSLOG_UDP_PORT), self,
address=("localhost", handlers.SYSLOG_UDP_PORT),
facility=handlers.SysLogHandler.LOG_USER, facility=handlers.SysLogHandler.LOG_USER,
socktype=None, ident=None, legacy=False): socktype=None,
ident=None,
legacy=False,
):
""" """
Initialize a handler. Initialize a handler.
@ -46,7 +49,6 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
super().__init__(address, facility, socktype) super().__init__(address, facility, socktype)
def emit(self, record): def emit(self, record):
""" """
Emit a record. Emit a record.
@ -57,46 +59,44 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
try: try:
# Gather info. # Gather info.
text = self.format(record).replace(self._tail, ' ') text = self.format(record).replace(self._tail, " ")
if not text: # nothing to log if not text: # nothing to log
return return
pri = self.encodePriority(self.facility, pri = self.encodePriority(self.facility, self.mapPriority(record.levelname))
self.mapPriority(record.levelname))
timestamp = time.strftime(self._timestamp_fmt, time.gmtime()); timestamp = time.strftime(self._timestamp_fmt, time.gmtime())
hostname = socket.gethostname()[: self._max_hostname]
hostname = socket.gethostname()[:self._max_hostname]
if self.ident: if self.ident:
ident = self.ident[:self._max_ident] ident = self.ident[: self._max_ident]
else: else:
ident = '' ident = ""
pid = os.getpid() # shouldn't need truncation pid = os.getpid() # shouldn't need truncation
# Format the header. # Format the header.
head = { head = {
'pri': pri, "pri": pri,
'timestamp': timestamp, "timestamp": timestamp,
'hostname': hostname, "hostname": hostname,
'ident': ident, "ident": ident,
'pid': pid, "pid": pid,
} }
msg = self._head_fmt.format(**head).encode('ascii', 'ignore') msg = self._head_fmt.format(**head).encode("ascii", "ignore")
# Encode text as plain ASCII if possible, else use UTF-8 with BOM. # Encode text as plain ASCII if possible, else use UTF-8 with BOM.
try: try:
msg += text.encode('ascii') msg += text.encode("ascii")
except UnicodeEncodeError: except UnicodeEncodeError:
msg += text.encode('utf-8-sig') msg += text.encode("utf-8-sig")
# Add length or tail character, if necessary. # Add length or tail character, if necessary.
if self.socktype != socket.SOCK_DGRAM: if self.socktype != socket.SOCK_DGRAM:
if self._send_length: if self._send_length:
msg = ('%d ' % len(msg)).encode('ascii') + msg msg = ("%d " % len(msg)).encode("ascii") + msg
else: else:
msg += self._tail.encode('ascii') msg += self._tail.encode("ascii")
# Send the message. # Send the message.
if self.unixsocket: if self.unixsocket:

View File

@ -10,8 +10,8 @@ logger = logging.getLogger(__name__)
_SOURCE_SPLIT_REGEX = re.compile( _SOURCE_SPLIT_REGEX = re.compile(
r'(?<=^)"([^"]+)"(?=:|$)' r'(?<=^)"([^"]+)"(?=:|$)'
r'|(?<=:)"([^"]+)"(?=:|$)' r'|(?<=:)"([^"]+)"(?=:|$)'
r'|(?<=^)([^:]*)(?=:|$)' r"|(?<=^)([^:]*)(?=:|$)"
r'|(?<=:)([^:]*)(?=:|$)', r"|(?<=:)([^:]*)(?=:|$)",
) )
@ -26,7 +26,7 @@ def parse_source_args(src):
return [m[0] or m[1] or m[2] or m[3] for m in matches] return [m[0] or m[1] or m[2] or m[3] for m in matches]
class BasePlugin(): class BasePlugin:
def __init__(self, src): def __init__(self, src):
self.source = src self.source = src
@ -44,8 +44,7 @@ class ReadOnlyTokenFile(BasePlugin):
def _load_targets(self): def _load_targets(self):
if os.path.isdir(self.source): if os.path.isdir(self.source):
cfg_files = [os.path.join(self.source, f) for cfg_files = [os.path.join(self.source, f) for f in os.listdir(self.source)]
f in os.listdir(self.source)]
else: else:
cfg_files = [self.source] cfg_files = [self.source]
@ -53,12 +52,14 @@ class ReadOnlyTokenFile(BasePlugin):
index = 1 index = 1
for f in cfg_files: for f in cfg_files:
for line in [l.strip() for l in open(f).readlines()]: for line in [l.strip() for l in open(f).readlines()]:
if line and not line.startswith('#'): if line and not line.startswith("#"):
try: try:
tok, target = re.split(r':\s', line) tok, target = re.split(r":\s", line)
self._targets[tok] = target.strip().rsplit(':', 1) self._targets[tok] = target.strip().rsplit(":", 1)
except ValueError: except ValueError:
logger.error("Syntax error in %s on line %d" % (self.source, index)) logger.error(
"Syntax error in %s on line %d" % (self.source, index)
)
index += 1 index += 1
def lookup(self, token): def lookup(self, token):
@ -83,6 +84,7 @@ class TokenFile(ReadOnlyTokenFile):
return super().lookup(token) return super().lookup(token)
class TokenFileName(BasePlugin): class TokenFileName(BasePlugin):
# source is a directory # source is a directory
# token is filename # token is filename
@ -96,7 +98,7 @@ class TokenFileName(BasePlugin):
token = os.path.basename(token) token = os.path.basename(token)
path = os.path.join(self.source, token) path = os.path.join(self.source, token)
if os.path.exists(path): if os.path.exists(path):
return open(path).read().strip().split(':') return open(path).read().strip().split(":")
else: else:
return None return None
@ -109,9 +111,9 @@ class BaseTokenAPI(BasePlugin):
# in this file can be used w/o unnecessary dependencies # in this file can be used w/o unnecessary dependencies
def process_result(self, resp): def process_result(self, resp):
host, port = resp.text.split(':') host, port = resp.text.split(":")
port = port.encode('ascii','ignore') port = port.encode("ascii", "ignore")
return [ host, port ] return [host, port]
def lookup(self, token): def lookup(self, token):
import requests import requests
@ -130,7 +132,7 @@ class JSONTokenApi(BaseTokenAPI):
def process_result(self, resp): def process_result(self, resp):
resp_json = resp.json() resp_json = resp.json()
return (resp_json['host'], resp_json['port']) return (resp_json["host"], resp_json["port"])
class JWTTokenApi(BasePlugin): class JWTTokenApi(BasePlugin):
@ -145,7 +147,7 @@ class JWTTokenApi(BasePlugin):
key = jwk.JWK() key = jwk.JWK()
try: try:
with open(self.source, 'rb') as key_file: with open(self.source, "rb") as key_file:
key_data = key_file.read() key_data = key_file.read()
except Exception as e: except Exception as e:
logger.error("Error loading key file: %s" % str(e)) logger.error("Error loading key file: %s" % str(e))
@ -155,39 +157,41 @@ class JWTTokenApi(BasePlugin):
key.import_from_pem(key_data) key.import_from_pem(key_data)
except: except:
try: try:
key.import_key(k=key_data.decode('utf-8'),kty='oct') key.import_key(k=key_data.decode("utf-8"), kty="oct")
except: except:
logger.error('Failed to correctly parse key data!') logger.error("Failed to correctly parse key data!")
return None return None
try: try:
token = jwt.JWT(key=key, jwt=token) token = jwt.JWT(key=key, jwt=token)
parsed_header = json.loads(token.header) parsed_header = json.loads(token.header)
if 'enc' in parsed_header: if "enc" in parsed_header:
# Token is encrypted, so we need to decrypt by passing the claims to a new instance # Token is encrypted, so we need to decrypt by passing the claims to a new instance
token = jwt.JWT(key=key, jwt=token.claims) token = jwt.JWT(key=key, jwt=token.claims)
parsed = json.loads(token.claims) parsed = json.loads(token.claims)
if 'nbf' in parsed: if "nbf" in parsed:
# Not Before is present, so we need to check it # Not Before is present, so we need to check it
if time.time() < parsed['nbf']: if time.time() < parsed["nbf"]:
logger.warning('Token can not be used yet!') logger.warning("Token can not be used yet!")
return None return None
if 'exp' in parsed: if "exp" in parsed:
# Expiration time is present, so we need to check it # Expiration time is present, so we need to check it
if time.time() > parsed['exp']: if time.time() > parsed["exp"]:
logger.warning('Token has expired!') logger.warning("Token has expired!")
return None return None
return (parsed['host'], parsed['port']) return (parsed["host"], parsed["port"])
except Exception as e: except Exception as e:
logger.error("Failed to parse token: %s" % str(e)) logger.error("Failed to parse token: %s" % str(e))
return None return None
except ImportError: except ImportError:
logger.error("package jwcrypto not found, are you sure you've installed it correctly?") logger.error(
"package jwcrypto not found, are you sure you've installed it correctly?"
)
return None return None
@ -251,6 +255,7 @@ class TokenRedis(BasePlugin):
pip install redis pip install redis
""" """
def __init__(self, src): def __init__(self, src):
try: try:
import redis import redis
@ -285,7 +290,9 @@ class TokenRedis(BasePlugin):
if not self._password: if not self._password:
self._password = None self._password = None
elif len(fields) == 5: elif len(fields) == 5:
self._server, self._port, self._db, self._password, self._namespace = fields self._server, self._port, self._db, self._password, self._namespace = (
fields
)
if not self._port: if not self._port:
self._port = 6379 self._port = 6379
if not self._db: if not self._db:
@ -301,24 +308,30 @@ class TokenRedis(BasePlugin):
if self._namespace: if self._namespace:
self._namespace += ":" self._namespace += ":"
logger.info("TokenRedis backend initialized (%s:%s)" % logger.info(
(self._server, self._port)) "TokenRedis backend initialized (%s:%s)" % (self._server, self._port)
)
except ValueError: except ValueError:
logger.error("The provided --token-source='%s' is not in the " logger.error(
"expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]" % "The provided --token-source='%s' is not in the "
src) "expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]"
% src
)
sys.exit() sys.exit()
def lookup(self, token): def lookup(self, token):
try: try:
import redis import redis
except ImportError: except ImportError:
logger.error("package redis not found, are you sure you've installed them correctly?") logger.error(
"package redis not found, are you sure you've installed them correctly?"
)
sys.exit() sys.exit()
logger.info("resolving token '%s'" % token) logger.info("resolving token '%s'" % token)
client = redis.Redis(host=self._server, port=self._port, client = redis.Redis(
db=self._db, password=self._password) host=self._server, port=self._port, db=self._db, password=self._password
)
stuff = client.get(self._namespace + token) stuff = client.get(self._namespace + token)
if stuff is None: if stuff is None:
return None return None
@ -330,14 +343,14 @@ class TokenRedis(BasePlugin):
combo = json.loads(responseStr) combo = json.loads(responseStr)
host, port = combo["host"].split(":") host, port = combo["host"].split(":")
except ValueError: except ValueError:
logger.error("Unable to decode JSON token: %s" % logger.error("Unable to decode JSON token: %s" % responseStr)
responseStr)
return None return None
except KeyError: except KeyError:
logger.error("Unable to find 'host' key in JSON token: %s" % logger.error(
responseStr) "Unable to find 'host' key in JSON token: %s" % responseStr
)
return None return None
elif re.match(r'\S+:\S+', responseStr): elif re.match(r"\S+:\S+", responseStr):
host, port = responseStr.split(":") host, port = responseStr.split(":")
else: else:
logger.error("Unable to parse token: %s" % responseStr) logger.error("Unable to parse token: %s" % responseStr)
@ -368,7 +381,7 @@ class UnixDomainSocketDirectory(BasePlugin):
if not stat.S_ISSOCK(os.stat(uds_path).st_mode): if not stat.S_ISSOCK(os.stat(uds_path).st_mode):
return None return None
return [ 'unix_socket', uds_path ] return ["unix_socket", uds_path]
except Exception as e: except Exception as e:
logger.error("Error finding unix domain socket: %s" % str(e)) logger.error("Error finding unix domain socket: %s" % str(e))
return None return None

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
Python WebSocket library Python WebSocket library
Copyright 2011 Joel Martin Copyright 2011 Joel Martin
Copyright 2016 Pierre Ossman Copyright 2016 Pierre Ossman
@ -10,7 +10,7 @@ Supports following protocol versions:
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-07 - http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-07
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 - http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10
- http://tools.ietf.org/html/rfc6455 - http://tools.ietf.org/html/rfc6455
''' """
import sys import sys
import array import array
@ -28,14 +28,19 @@ try:
import numpy import numpy
except ImportError: except ImportError:
import warnings import warnings
warnings.warn("no 'numpy' module, HyBi protocol will be slower") warnings.warn("no 'numpy' module, HyBi protocol will be slower")
numpy = None numpy = None
class WebSocketWantReadError(ssl.SSLWantReadError): class WebSocketWantReadError(ssl.SSLWantReadError):
pass pass
class WebSocketWantWriteError(ssl.SSLWantWriteError): class WebSocketWantWriteError(ssl.SSLWantWriteError):
pass pass
class WebSocket: class WebSocket:
"""WebSocket protocol socket like class. """WebSocket protocol socket like class.
@ -73,11 +78,11 @@ class WebSocket:
self._state = "new" self._state = "new"
self._partial_msg = b'' self._partial_msg = b""
self._recv_buffer = b'' self._recv_buffer = b""
self._recv_queue = [] self._recv_queue = []
self._send_buffer = b'' self._send_buffer = b""
self._previous_sendmsg = None self._previous_sendmsg = None
@ -91,16 +96,22 @@ class WebSocket:
def __getattr__(self, name): def __getattr__(self, name):
# These methods are just redirected to the underlying socket # These methods are just redirected to the underlying socket
if name in ["fileno", if name in [
"getpeername", "getsockname", "fileno",
"getsockopt", "setsockopt", "getpeername",
"gettimeout", "settimeout", "getsockname",
"setblocking"]: "getsockopt",
"setsockopt",
"gettimeout",
"settimeout",
"setblocking",
]:
assert self.socket is not None assert self.socket is not None
return getattr(self.socket, name) return getattr(self.socket, name)
else: else:
raise AttributeError("%s instance has no attribute '%s'" % raise AttributeError(
(self.__class__.__name__, name)) "%s instance has no attribute '%s'" % (self.__class__.__name__, name)
)
def connect(self, uri, origin=None, protocols=[]): def connect(self, uri, origin=None, protocols=[]):
"""Establishes a new connection to a WebSocket server. """Establishes a new connection to a WebSocket server.
@ -118,8 +129,7 @@ class WebSocket:
connect() must retain the same arguments. connect() must retain the same arguments.
""" """
self.client = True; self.client = True
uri = urlparse(uri) uri = urlparse(uri)
port = uri.port port = uri.port
@ -140,8 +150,9 @@ class WebSocket:
if uri.scheme in ("wss", "https"): if uri.scheme in ("wss", "https"):
context = ssl.create_default_context() context = ssl.create_default_context()
self.socket = context.wrap_socket(self.socket, self.socket = context.wrap_socket(
server_hostname=uri.hostname) self.socket, server_hostname=uri.hostname
)
self._state = "ssl_handshake" self._state = "ssl_handshake"
else: else:
self._state = "headers" self._state = "headers"
@ -151,7 +162,7 @@ class WebSocket:
self._state = "headers" self._state = "headers"
if self._state == "headers": if self._state == "headers":
self._key = '' self._key = ""
for i in range(16): for i in range(16):
self._key += chr(random.randrange(256)) self._key += chr(random.randrange(256))
self._key = b64encode(self._key.encode("latin-1")).decode("ascii") self._key = b64encode(self._key.encode("latin-1")).decode("ascii")
@ -184,10 +195,10 @@ class WebSocket:
if not self._recv(): if not self._recv():
raise Exception("Socket closed unexpectedly") raise Exception("Socket closed unexpectedly")
if self._recv_buffer.find(b'\r\n\r\n') == -1: if self._recv_buffer.find(b"\r\n\r\n") == -1:
raise WebSocketWantReadError raise WebSocketWantReadError
(request, self._recv_buffer) = self._recv_buffer.split(b'\r\n', 1) (request, self._recv_buffer) = self._recv_buffer.split(b"\r\n", 1)
request = request.decode("latin-1") request = request.decode("latin-1")
words = request.split() words = request.split()
@ -196,17 +207,17 @@ class WebSocket:
if words[1] != "101": if words[1] != "101":
raise Exception("WebSocket request denied: %s" % " ".join(words[1:])) raise Exception("WebSocket request denied: %s" % " ".join(words[1:]))
(headers, self._recv_buffer) = self._recv_buffer.split(b'\r\n\r\n', 1) (headers, self._recv_buffer) = self._recv_buffer.split(b"\r\n\r\n", 1)
headers = headers.decode('latin-1') + '\r\n' headers = headers.decode("latin-1") + "\r\n"
headers = email.message_from_string(headers) headers = email.message_from_string(headers)
if headers.get("Upgrade", "").lower() != "websocket": if headers.get("Upgrade", "").lower() != "websocket":
print(type(headers)) print(type(headers))
raise Exception("Missing or incorrect upgrade header") raise Exception("Missing or incorrect upgrade header")
accept = headers.get('Sec-WebSocket-Accept') accept = headers.get("Sec-WebSocket-Accept")
if accept is None: if accept is None:
raise Exception("Missing Sec-WebSocket-Accept header"); raise Exception("Missing Sec-WebSocket-Accept header")
expected = sha1((self._key + self.GUID).encode("ascii")).digest() expected = sha1((self._key + self.GUID).encode("ascii")).digest()
expected = b64encode(expected).decode("ascii") expected = b64encode(expected).decode("ascii")
@ -214,9 +225,9 @@ class WebSocket:
del self._key del self._key
if accept != expected: if accept != expected:
raise Exception("Invalid Sec-WebSocket-Accept header"); raise Exception("Invalid Sec-WebSocket-Accept header")
self.protocol = headers.get('Sec-WebSocket-Protocol') self.protocol = headers.get("Sec-WebSocket-Protocol")
if len(protocols) == 0: if len(protocols) == 0:
if self.protocol is not None: if self.protocol is not None:
raise Exception("Unexpected Sec-WebSocket-Protocol header") raise Exception("Unexpected Sec-WebSocket-Protocol header")
@ -256,34 +267,34 @@ class WebSocket:
if headers.get("upgrade", "").lower() != "websocket": if headers.get("upgrade", "").lower() != "websocket":
raise Exception("Missing or incorrect upgrade header") raise Exception("Missing or incorrect upgrade header")
ver = headers.get('Sec-WebSocket-Version') ver = headers.get("Sec-WebSocket-Version")
if ver is None: if ver is None:
raise Exception("Missing Sec-WebSocket-Version header"); raise Exception("Missing Sec-WebSocket-Version header")
# HyBi-07 report version 7 # HyBi-07 report version 7
# HyBi-08 - HyBi-12 report version 8 # HyBi-08 - HyBi-12 report version 8
# HyBi-13 reports version 13 # HyBi-13 reports version 13
if ver in ['7', '8', '13']: if ver in ["7", "8", "13"]:
self.version = "hybi-%02d" % int(ver) self.version = "hybi-%02d" % int(ver)
else: else:
raise Exception("Unsupported protocol version %s" % ver) raise Exception("Unsupported protocol version %s" % ver)
key = headers.get('Sec-WebSocket-Key') key = headers.get("Sec-WebSocket-Key")
if key is None: if key is None:
raise Exception("Missing Sec-WebSocket-Key header"); raise Exception("Missing Sec-WebSocket-Key header")
# Generate the hash value for the accept header # Generate the hash value for the accept header
accept = sha1((key + self.GUID).encode("ascii")).digest() accept = sha1((key + self.GUID).encode("ascii")).digest()
accept = b64encode(accept).decode("ascii") accept = b64encode(accept).decode("ascii")
self.protocol = '' self.protocol = ""
protocols = headers.get('Sec-WebSocket-Protocol', '').split(',') protocols = headers.get("Sec-WebSocket-Protocol", "").split(",")
if protocols: if protocols:
self.protocol = self.select_subprotocol(protocols) self.protocol = self.select_subprotocol(protocols)
# We are required to choose one of the protocols # We are required to choose one of the protocols
# presented by the client # presented by the client
if self.protocol not in protocols: if self.protocol not in protocols:
raise Exception('Invalid protocol selected') raise Exception("Invalid protocol selected")
self.send_response(101, "Switching Protocols") self.send_response(101, "Switching Protocols")
self.send_header("Upgrade", "websocket") self.send_header("Upgrade", "websocket")
@ -461,7 +472,7 @@ class WebSocket:
def send_request(self, type, path): def send_request(self, type, path):
self._queue_str("%s %s HTTP/1.1\r\n" % (type.upper(), path)) self._queue_str("%s %s HTTP/1.1\r\n" % (type.upper(), path))
def ping(self, data=b''): def ping(self, data=b""):
"""Write a ping message to the WebSocket """Write a ping message to the WebSocket
WebSocketWantWriteError can be raised if there is insufficient WebSocketWantWriteError can be raised if there is insufficient
@ -486,7 +497,7 @@ class WebSocket:
self._previous_sendmsg = data self._previous_sendmsg = data
raise raise
def pong(self, data=b''): def pong(self, data=b""):
"""Write a pong message to the WebSocket """Write a pong message to the WebSocket
WebSocketWantWriteError can be raised if there is insufficient WebSocketWantWriteError can be raised if there is insufficient
@ -540,7 +551,7 @@ class WebSocket:
self._sent_close = True self._sent_close = True
msg = b'' msg = b""
if code is not None: if code is not None:
msg += struct.pack(">H", code) msg += struct.pack(">H", code)
if reason is not None: if reason is not None:
@ -602,7 +613,7 @@ class WebSocket:
frame = self._decode_hybi(self._recv_buffer) frame = self._decode_hybi(self._recv_buffer)
if frame is None: if frame is None:
break break
self._recv_buffer = self._recv_buffer[frame['length']:] self._recv_buffer = self._recv_buffer[frame["length"] :]
self._recv_queue.append(frame) self._recv_queue.append(frame)
return True return True
@ -612,29 +623,39 @@ class WebSocket:
while self._recv_queue: while self._recv_queue:
frame = self._recv_queue.pop(0) frame = self._recv_queue.pop(0)
if not self.client and not frame['masked']: if not self.client and not frame["masked"]:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked") self.shutdown(
socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked"
)
continue continue
if self.client and frame['masked']: if self.client and frame["masked"]:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame masked") self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame masked")
continue continue
if frame["opcode"] == 0x0: if frame["opcode"] == 0x0:
if not self._partial_msg: if not self._partial_msg:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected continuation frame") self.shutdown(
socket.SHUT_RDWR,
1002,
"Procotol error: Unexpected continuation frame",
)
continue continue
self._partial_msg += frame["payload"] self._partial_msg += frame["payload"]
if frame["fin"]: if frame["fin"]:
msg = self._partial_msg msg = self._partial_msg
self._partial_msg = b'' self._partial_msg = b""
return msg return msg
elif frame["opcode"] == 0x1: elif frame["opcode"] == 0x1:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported") self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported"
)
elif frame["opcode"] == 0x2: elif frame["opcode"] == 0x2:
if self._partial_msg: if self._partial_msg:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame") self.shutdown(
socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame"
)
continue continue
if frame["fin"]: if frame["fin"]:
@ -652,7 +673,9 @@ class WebSocket:
return None return None
if not frame["fin"]: if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close") self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close"
)
continue continue
code = None code = None
@ -664,7 +687,11 @@ class WebSocket:
try: try:
reason = reason.decode("UTF-8") reason = reason.decode("UTF-8")
except UnicodeDecodeError: except UnicodeDecodeError:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Invalid UTF-8 in close") self.shutdown(
socket.SHUT_RDWR,
1002,
"Procotol error: Invalid UTF-8 in close",
)
continue continue
if code is None: if code is None:
@ -679,18 +706,26 @@ class WebSocket:
return None return None
elif frame["opcode"] == 0x9: elif frame["opcode"] == 0x9:
if not frame["fin"]: if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping") self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping"
)
continue continue
self.handle_ping(frame["payload"]) self.handle_ping(frame["payload"])
elif frame["opcode"] == 0xA: elif frame["opcode"] == 0xA:
if not frame["fin"]: if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong") self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong"
)
continue continue
self.handle_pong(frame["payload"]) self.handle_pong(frame["payload"])
else: else:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Unknown opcode 0x%02x" % frame["opcode"]) self.shutdown(
socket.SHUT_RDWR,
1003,
"Unsupported: Unknown opcode 0x%02x" % frame["opcode"],
)
raise WebSocketWantReadError raise WebSocketWantReadError
@ -731,7 +766,7 @@ class WebSocket:
def _sendmsg(self, opcode, msg): def _sendmsg(self, opcode, msg):
# Sends a standard data message # Sends a standard data message
if self.client: if self.client:
mask = b'' mask = b""
for i in range(4): for i in range(4):
mask += random.randrange(256).to_bytes() mask += random.randrange(256).to_bytes()
frame = self._encode_hybi(opcode, msg, mask) frame = self._encode_hybi(opcode, msg, mask)
@ -755,35 +790,36 @@ class WebSocket:
plen = len(buf) plen = len(buf)
pstart = 0 pstart = 0
pend = plen pend = plen
b = c = b'' b = c = b""
if plen >= 4: if plen >= 4:
dtype=numpy.dtype('<u4') dtype = numpy.dtype("<u4")
if sys.byteorder == 'big': if sys.byteorder == "big":
dtype = dtype.newbyteorder('>') dtype = dtype.newbyteorder(">")
mask = numpy.frombuffer(mask, dtype, count=1) mask = numpy.frombuffer(mask, dtype, count=1)
data = numpy.frombuffer(buf, dtype, count=int(plen / 4)) data = numpy.frombuffer(buf, dtype, count=int(plen / 4))
#b = numpy.bitwise_xor(data, mask).data # b = numpy.bitwise_xor(data, mask).data
b = numpy.bitwise_xor(data, mask).tobytes() b = numpy.bitwise_xor(data, mask).tobytes()
if plen % 4: if plen % 4:
dtype=numpy.dtype('B') dtype = numpy.dtype("B")
if sys.byteorder == 'big': if sys.byteorder == "big":
dtype = dtype.newbyteorder('>') dtype = dtype.newbyteorder(">")
mask = numpy.frombuffer(mask, dtype, count=(plen % 4)) mask = numpy.frombuffer(mask, dtype, count=(plen % 4))
data = numpy.frombuffer(buf, dtype, data = numpy.frombuffer(
offset=plen - (plen % 4), count=(plen % 4)) buf, dtype, offset=plen - (plen % 4), count=(plen % 4)
)
c = numpy.bitwise_xor(data, mask).tobytes() c = numpy.bitwise_xor(data, mask).tobytes()
return b + c return b + c
else: else:
# Slower fallback # Slower fallback
data = array.array('B') data = array.array("B")
data.frombytes(buf) data.frombytes(buf)
for i in range(len(data)): for i in range(len(data)):
data[i] ^= mask[i % 4] data[i] ^= mask[i % 4]
return data.tobytes() return data.tobytes()
def _encode_hybi(self, opcode, buf, mask_key=None, fin=True): def _encode_hybi(self, opcode, buf, mask_key=None, fin=True):
""" Encode a HyBi style WebSocket frame. """Encode a HyBi style WebSocket frame.
Optional opcode: Optional opcode:
0x0 - continuation 0x0 - continuation
0x1 - text frame 0x1 - text frame
@ -793,7 +829,7 @@ class WebSocket:
0xA - pong 0xA - pong
""" """
b1 = opcode & 0x0f b1 = opcode & 0x0F
if fin: if fin:
b1 |= 0x80 b1 |= 0x80
@ -804,11 +840,11 @@ class WebSocket:
payload_len = len(buf) payload_len = len(buf)
if payload_len <= 125: if payload_len <= 125:
header = struct.pack('>BB', b1, payload_len | mask_bit) header = struct.pack(">BB", b1, payload_len | mask_bit)
elif payload_len > 125 and payload_len < 65536: elif payload_len > 125 and payload_len < 65536:
header = struct.pack('>BBH', b1, 126 | mask_bit, payload_len) header = struct.pack(">BBH", b1, 126 | mask_bit, payload_len)
elif payload_len >= 65536: elif payload_len >= 65536:
header = struct.pack('>BBQ', b1, 127 | mask_bit, payload_len) header = struct.pack(">BBQ", b1, 127 | mask_bit, payload_len)
if mask_key is not None: if mask_key is not None:
return header + mask_key + buf return header + mask_key + buf
@ -816,7 +852,7 @@ class WebSocket:
return header + buf return header + buf
def _decode_hybi(self, buf): def _decode_hybi(self, buf):
""" Decode HyBi style WebSocket packets. """Decode HyBi style WebSocket packets.
Returns: Returns:
{'fin' : boolean, {'fin' : boolean,
'opcode' : number, 'opcode' : number,
@ -825,11 +861,7 @@ class WebSocket:
'payload' : decoded_buffer} 'payload' : decoded_buffer}
""" """
f = {'fin' : 0, f = {"fin": 0, "opcode": 0, "masked": False, "length": 0, "payload": None}
'opcode' : 0,
'masked' : False,
'length' : 0,
'payload' : None}
blen = len(buf) blen = len(buf)
hlen = 2 hlen = 2
@ -838,39 +870,38 @@ class WebSocket:
return None return None
b1, b2 = struct.unpack(">BB", buf[:2]) b1, b2 = struct.unpack(">BB", buf[:2])
f['opcode'] = b1 & 0x0f f["opcode"] = b1 & 0x0F
f['fin'] = not not (b1 & 0x80) f["fin"] = not not (b1 & 0x80)
f['masked'] = not not (b2 & 0x80) f["masked"] = not not (b2 & 0x80)
if f['masked']: if f["masked"]:
hlen += 4 hlen += 4
if blen < hlen: if blen < hlen:
return None return None
length = b2 & 0x7f length = b2 & 0x7F
if length == 126: if length == 126:
hlen += 2 hlen += 2
if blen < hlen: if blen < hlen:
return None return None
length, = struct.unpack('>H', buf[2:4]) (length,) = struct.unpack(">H", buf[2:4])
elif length == 127: elif length == 127:
hlen += 8 hlen += 8
if blen < hlen: if blen < hlen:
return None return None
length, = struct.unpack('>Q', buf[2:10]) (length,) = struct.unpack(">Q", buf[2:10])
f['length'] = hlen + length f["length"] = hlen + length
if blen < f['length']: if blen < f["length"]:
return None return None
if f['masked']: if f["masked"]:
# unmask payload # unmask payload
mask_key = buf[hlen-4:hlen] mask_key = buf[hlen - 4 : hlen]
f['payload'] = self._unmask(buf[hlen:(hlen+length)], mask_key) f["payload"] = self._unmask(buf[hlen : (hlen + length)], mask_key)
else: else:
f['payload'] = buf[hlen:(hlen+length)] f["payload"] = buf[hlen : (hlen + length)]
return f return f

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
A WebSocket to TCP socket proxy with support for "wss://" encryption. A WebSocket to TCP socket proxy with support for "wss://" encryption.
Copyright 2011 Joel Martin Copyright 2011 Joel Martin
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
@ -9,7 +9,7 @@ You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' """
import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl, stat import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl, stat
from socketserver import ThreadingMixIn from socketserver import ThreadingMixIn
@ -20,8 +20,8 @@ from websockify import websockifyserver
from websockify import auth_plugins as auth from websockify import auth_plugins as auth
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler):
class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler):
buffer_size = 65536 buffer_size = 65536
traffic_legend = """ traffic_legend = """
@ -38,7 +38,7 @@ Traffic Legend:
def send_auth_error(self, ex): def send_auth_error(self, ex):
self.send_response(ex.code, ex.msg) self.send_response(ex.code, ex.msg)
self.send_header('Content-Type', 'text/html') self.send_header("Content-Type", "text/html")
for name, val in ex.headers.items(): for name, val in ex.headers.items():
self.send_header(name, val) self.send_header(name, val)
@ -49,7 +49,7 @@ Traffic Legend:
return return
host, port = self.get_target(self.server.token_plugin) host, port = self.get_target(self.server.token_plugin)
if host == 'unix_socket': if host == "unix_socket":
self.server.unix_target = port self.server.unix_target = port
else: else:
@ -62,7 +62,7 @@ Traffic Legend:
# clear out any existing SSL_ headers that the client might # clear out any existing SSL_ headers that the client might
# have maliciously set # have maliciously set
ssl_headers = [ h for h in self.headers if h.startswith('SSL_') ] ssl_headers = [h for h in self.headers if h.startswith("SSL_")]
for h in ssl_headers: for h in ssl_headers:
del self.headers[h] del self.headers[h]
@ -70,19 +70,21 @@ Traffic Legend:
# get client certificate data # get client certificate data
client_cert_data = self.request.getpeercert() client_cert_data = self.request.getpeercert()
# extract subject information # extract subject information
client_cert_subject = client_cert_data['subject'] client_cert_subject = client_cert_data["subject"]
# flatten data structure # flatten data structure
client_cert_subject = dict([x[0] for x in client_cert_subject]) client_cert_subject = dict([x[0] for x in client_cert_subject])
# add common name to headers (apache +StdEnvVars style) # add common name to headers (apache +StdEnvVars style)
self.headers['SSL_CLIENT_S_DN_CN'] = client_cert_subject['commonName'] self.headers["SSL_CLIENT_S_DN_CN"] = client_cert_subject["commonName"]
except (TypeError, AttributeError, KeyError): except (TypeError, AttributeError, KeyError):
# not a SSL connection or client presented no certificate with valid data # not a SSL connection or client presented no certificate with valid data
pass pass
try: try:
self.server.auth_plugin.authenticate( self.server.auth_plugin.authenticate(
headers=self.headers, target_host=self.server.target_host, headers=self.headers,
target_port=self.server.target_port) target_host=self.server.target_host,
target_port=self.server.target_port,
)
except auth.AuthenticationError: except auth.AuthenticationError:
ex = sys.exc_info()[1] ex = sys.exc_info()[1]
self.send_auth_error(ex) self.send_auth_error(ex)
@ -96,26 +98,37 @@ Traffic Legend:
# Connect to the target # Connect to the target
if self.server.wrap_cmd: if self.server.wrap_cmd:
msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port) msg = "connecting to command: '%s' (port %s)" % (
" ".join(self.server.wrap_cmd),
self.server.target_port,
)
elif self.server.unix_target: elif self.server.unix_target:
msg = "connecting to unix socket: %s" % self.server.unix_target msg = "connecting to unix socket: %s" % self.server.unix_target
else: else:
msg = "connecting to: %s:%s" % ( msg = "connecting to: %s:%s" % (
self.server.target_host, self.server.target_port) self.server.target_host,
self.server.target_port,
)
if self.server.ssl_target: if self.server.ssl_target:
msg += " (using SSL)" msg += " (using SSL)"
self.log_message(msg) self.log_message(msg)
try: try:
tsock = websockifyserver.WebSockifyServer.socket(self.server.target_host, tsock = websockifyserver.WebSockifyServer.socket(
self.server.target_host,
self.server.target_port, self.server.target_port,
connect=True, connect=True,
use_ssl=self.server.ssl_target, use_ssl=self.server.ssl_target,
unix_socket=self.server.unix_target) unix_socket=self.server.unix_target,
)
except Exception as e: except Exception as e:
self.log_message("Failed to connect to %s:%s: %s", self.log_message(
self.server.target_host, self.server.target_port, e) "Failed to connect to %s:%s: %s",
self.server.target_host,
self.server.target_port,
e,
)
raise self.CClose(1011, "Failed to connect to downstream server") raise self.CClose(1011, "Failed to connect to downstream server")
# Option unavailable when listening to unix socket # Option unavailable when listening to unix socket
@ -134,8 +147,11 @@ Traffic Legend:
tsock.shutdown(socket.SHUT_RDWR) tsock.shutdown(socket.SHUT_RDWR)
tsock.close() tsock.close()
if self.verbose: if self.verbose:
self.log_message("%s:%s: Closed target", self.log_message(
self.server.target_host, self.server.target_port) "%s:%s: Closed target",
self.server.target_host,
self.server.target_port,
)
def get_target(self, target_plugin): def get_target(self, target_plugin):
""" """
@ -149,20 +165,20 @@ Traffic Legend:
if self.host_token: if self.host_token:
# Use hostname as token # Use hostname as token
token = self.headers.get('Host') token = self.headers.get("Host")
# Remove port from hostname, as it'll always be the one where # Remove port from hostname, as it'll always be the one where
# websockify listens (unless something between the client and # websockify listens (unless something between the client and
# websockify is redirecting traffic, but that's beside the point) # websockify is redirecting traffic, but that's beside the point)
if token: if token:
token = token.partition(':')[0] token = token.partition(":")[0]
else: else:
# Extract the token parameter from url # Extract the token parameter from url
args = parse_qs(urlparse(self.path)[4]) # 4 is the query from url args = parse_qs(urlparse(self.path)[4]) # 4 is the query from url
if 'token' in args and len(args['token']): if "token" in args and len(args["token"]):
token = args['token'][0].rstrip('\n') token = args["token"][0].rstrip("\n")
else: else:
token = None token = None
@ -200,13 +216,15 @@ Traffic Legend:
self.heartbeat = now + self.server.heartbeat self.heartbeat = now + self.server.heartbeat
self.send_ping() self.send_ping()
if tqueue: wlist.append(target) if tqueue:
if cqueue or c_pend: wlist.append(self.request) wlist.append(target)
if cqueue or c_pend:
wlist.append(self.request)
try: try:
ins, outs, excepts = select.select(rlist, wlist, [], 1) ins, outs, excepts = select.select(rlist, wlist, [], 1)
except OSError: except OSError:
exc = sys.exc_info()[1] exc = sys.exc_info()[1]
if hasattr(exc, 'errno'): if hasattr(exc, "errno"):
err = exc.errno err = exc.errno
else: else:
err = exc[0] err = exc[0]
@ -216,7 +234,8 @@ Traffic Legend:
else: else:
continue continue
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
if self.request in outs: if self.request in outs:
# Send queued target data to the client # Send queued target data to the client
@ -230,8 +249,7 @@ Traffic Legend:
tqueue.extend(bufs) tqueue.extend(bufs)
if closed: if closed:
while len(tqueue) != 0:
while (len(tqueue) != 0):
# Send queued client data to the target # Send queued client data to the target
dat = tqueue.pop(0) dat = tqueue.pop(0)
sent = target.send(dat) sent = target.send(dat)
@ -244,10 +262,12 @@ Traffic Legend:
# TODO: What about blocking on client socket? # TODO: What about blocking on client socket?
if self.verbose: if self.verbose:
self.log_message("%s:%s: Client closed connection", self.log_message(
self.server.target_host, self.server.target_port) "%s:%s: Client closed connection",
raise self.CClose(closed['code'], closed['reason']) self.server.target_host,
self.server.target_port,
)
raise self.CClose(closed["code"], closed["reason"])
if target in outs: if target in outs:
# Send queued client data to the target # Send queued client data to the target
@ -260,29 +280,31 @@ Traffic Legend:
tqueue.insert(0, dat[sent:]) tqueue.insert(0, dat[sent:])
self.print_traffic(".>") self.print_traffic(".>")
if target in ins: if target in ins:
# Receive target data, encode it and queue for client # Receive target data, encode it and queue for client
buf = target.recv(self.buffer_size) buf = target.recv(self.buffer_size)
if len(buf) == 0: if len(buf) == 0:
# Target socket closed, flushing queues and closing client-side websocket # Target socket closed, flushing queues and closing client-side websocket
# Send queued target data to the client # Send queued target data to the client
if len(cqueue) != 0: if len(cqueue) != 0:
c_pend = True c_pend = True
while(c_pend): while c_pend:
c_pend = self.send_frames(cqueue) c_pend = self.send_frames(cqueue)
cqueue = [] cqueue = []
if self.verbose: if self.verbose:
self.log_message("%s:%s: Target closed connection", self.log_message(
self.server.target_host, self.server.target_port) "%s:%s: Target closed connection",
self.server.target_host,
self.server.target_port,
)
raise self.CClose(1000, "Target closed") raise self.CClose(1000, "Target closed")
cqueue.append(buf) cqueue.append(buf)
self.print_traffic("{") self.print_traffic("{")
class WebSocketProxy(websockifyserver.WebSockifyServer): class WebSocketProxy(websockifyserver.WebSockifyServer):
""" """
Proxy traffic to and from a WebSockets client to a normal TCP Proxy traffic to and from a WebSockets client to a normal TCP
@ -293,27 +315,29 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs): def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs):
# Save off proxy specific options # Save off proxy specific options
self.target_host = kwargs.pop('target_host', None) self.target_host = kwargs.pop("target_host", None)
self.target_port = kwargs.pop('target_port', None) self.target_port = kwargs.pop("target_port", None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None) self.wrap_cmd = kwargs.pop("wrap_cmd", None)
self.wrap_mode = kwargs.pop('wrap_mode', None) self.wrap_mode = kwargs.pop("wrap_mode", None)
self.unix_target = kwargs.pop('unix_target', None) self.unix_target = kwargs.pop("unix_target", None)
self.ssl_target = kwargs.pop('ssl_target', None) self.ssl_target = kwargs.pop("ssl_target", None)
self.heartbeat = kwargs.pop('heartbeat', None) self.heartbeat = kwargs.pop("heartbeat", None)
self.token_plugin = kwargs.pop('token_plugin', None) self.token_plugin = kwargs.pop("token_plugin", None)
self.host_token = kwargs.pop('host_token', None) self.host_token = kwargs.pop("host_token", None)
self.auth_plugin = kwargs.pop('auth_plugin', None) self.auth_plugin = kwargs.pop("auth_plugin", None)
# Last 3 timestamps command was run # Last 3 timestamps command was run
self.wrap_times = [0, 0, 0] self.wrap_times = [0, 0, 0]
if self.wrap_cmd: if self.wrap_cmd:
wsdir = os.path.dirname(sys.argv[0]) wsdir = os.path.dirname(sys.argv[0])
rebinder_path = [os.path.join(wsdir, "..", "lib"), rebinder_path = [
os.path.join(wsdir, "..", "lib"),
os.path.join(wsdir, "..", "lib", "websockify"), os.path.join(wsdir, "..", "lib", "websockify"),
os.path.join(wsdir, ".."), os.path.join(wsdir, ".."),
wsdir] wsdir,
]
self.rebinder = None self.rebinder = None
for rdir in rebinder_path: for rdir in rebinder_path:
@ -329,17 +353,22 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
self.target_host = "127.0.0.1" # Loopback self.target_host = "127.0.0.1" # Loopback
# Find a free high port # Find a free high port
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(('', 0)) sock.bind(("", 0))
self.target_port = sock.getsockname()[1] self.target_port = sock.getsockname()[1]
sock.close() sock.close()
# Insert rebinder at the head of the (possibly empty) LD_PRELOAD pathlist # Insert rebinder at the head of the (possibly empty) LD_PRELOAD pathlist
ld_preloads = filter(None, [ self.rebinder, os.environ.get("LD_PRELOAD", None) ]) ld_preloads = filter(
None, [self.rebinder, os.environ.get("LD_PRELOAD", None)]
)
os.environ.update({ os.environ.update(
{
"LD_PRELOAD": os.pathsep.join(ld_preloads), "LD_PRELOAD": os.pathsep.join(ld_preloads),
"REBIND_OLD_PORT": str(kwargs['listen_port']), "REBIND_OLD_PORT": str(kwargs["listen_port"]),
"REBIND_NEW_PORT": str(self.target_port)}) "REBIND_NEW_PORT": str(self.target_port),
}
)
super().__init__(RequestHandlerClass, *args, **kwargs) super().__init__(RequestHandlerClass, *args, **kwargs)
@ -348,7 +377,8 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
self.wrap_times.append(time.time()) self.wrap_times.append(time.time())
self.wrap_times.pop(0) self.wrap_times.pop(0)
self.cmd = subprocess.Popen( self.cmd = subprocess.Popen(
self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup) self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup
)
self.spawn_message = True self.spawn_message = True
def started(self): def started(self):
@ -371,10 +401,11 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
if self.token_plugin: if self.token_plugin:
msg = " - proxying from %s to targets generated by %s" % ( msg = " - proxying from %s to targets generated by %s" % (
src_string, type(self.token_plugin).__name__) src_string,
type(self.token_plugin).__name__,
)
else: else:
msg = " - proxying from %s to %s" % ( msg = " - proxying from %s to %s" % (src_string, dst_string)
src_string, dst_string)
if self.ssl_target: if self.ssl_target:
msg += " (using SSL)" msg += " (using SSL)"
@ -401,7 +432,7 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
sys.exit(ret) sys.exit(ret)
elif self.wrap_mode == "respawn": elif self.wrap_mode == "respawn":
now = time.time() now = time.time()
avg = sum(self.wrap_times)/len(self.wrap_times) avg = sum(self.wrap_times) / len(self.wrap_times)
if (now - avg) < 10: if (now - avg) < 10:
# 3 times in the last 10 seconds # 3 times in the last 10 seconds
if self.spawn_message: if self.spawn_message:
@ -418,15 +449,25 @@ def _subprocess_setup():
SSL_OPTIONS = { SSL_OPTIONS = {
'default': ssl.OP_ALL, "default": ssl.OP_ALL,
'tlsv1_1': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | "tlsv1_1": ssl.PROTOCOL_SSLv23
ssl.OP_NO_TLSv1, | ssl.OP_NO_SSLv2
'tlsv1_2': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | | ssl.OP_NO_SSLv3
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1, | ssl.OP_NO_TLSv1,
'tlsv1_3': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | "tlsv1_2": ssl.PROTOCOL_SSLv23
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2, | ssl.OP_NO_SSLv2
| ssl.OP_NO_SSLv3
| ssl.OP_NO_TLSv1
| ssl.OP_NO_TLSv1_1,
"tlsv1_3": ssl.PROTOCOL_SSLv23
| ssl.OP_NO_SSLv2
| ssl.OP_NO_SSLv3
| ssl.OP_NO_TLSv1
| ssl.OP_NO_TLSv1_1
| ssl.OP_NO_TLSv1_2,
} }
def select_ssl_version(version): def select_ssl_version(version):
"""Returns SSL options for the most secure TSL version available on this """Returns SSL options for the most secure TSL version available on this
Python version""" Python version"""
@ -439,11 +480,11 @@ def select_ssl_version(version):
keys.sort() keys.sort()
fallback = keys[-1] fallback = keys[-1]
logger = logging.getLogger(WebSocketProxy.log_prefix) logger = logging.getLogger(WebSocketProxy.log_prefix)
logger.warn("TLS version %s unsupported. Falling back to %s", logger.warn("TLS version %s unsupported. Falling back to %s", version, fallback)
version, fallback)
return SSL_OPTIONS[fallback] return SSL_OPTIONS[fallback]
def websockify_init(): def websockify_init():
# Setup basic logging to stderr. # Setup basic logging to stderr.
stderr_handler = logging.StreamHandler() stderr_handler = logging.StreamHandler()
@ -464,105 +505,198 @@ def websockify_init():
usage += "\n %prog [options]" usage += "\n %prog [options]"
usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE" usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE"
parser = optparse.OptionParser(usage=usage) parser = optparse.OptionParser(usage=usage)
parser.add_option("--verbose", "-v", action="store_true", parser.add_option("--verbose", "-v", action="store_true", help="verbose messages")
help="verbose messages") parser.add_option("--traffic", action="store_true", help="per frame traffic")
parser.add_option("--traffic", action="store_true", parser.add_option(
help="per frame traffic") "--record", help="record sessions to FILE.[session_number]", metavar="FILE"
parser.add_option("--record", )
help="record sessions to FILE.[session_number]", metavar="FILE") parser.add_option(
parser.add_option("--daemon", "-D", "--daemon",
dest="daemon", action="store_true", "-D",
help="become a daemon (background process)") dest="daemon",
parser.add_option("--run-once", action="store_true", action="store_true",
help="handle a single WebSocket connection and exit") help="become a daemon (background process)",
parser.add_option("--timeout", type=int, default=0, )
help="after TIMEOUT seconds exit when not connected") parser.add_option(
parser.add_option("--idle-timeout", type=int, default=0, "--run-once",
help="server exits after TIMEOUT seconds if there are no " action="store_true",
"active connections") help="handle a single WebSocket connection and exit",
parser.add_option("--cert", default="self.pem", )
help="SSL certificate file") parser.add_option(
parser.add_option("--key", default=None, "--timeout",
help="SSL key file (if separate from cert)") type=int,
parser.add_option("--key-password", default=None, default=0,
help="SSL key password") help="after TIMEOUT seconds exit when not connected",
parser.add_option("--ssl-only", action="store_true", )
help="disallow non-encrypted client connections") parser.add_option(
parser.add_option("--ssl-target", action="store_true", "--idle-timeout",
help="connect to SSL target as SSL client") type=int,
parser.add_option("--verify-client", action="store_true", default=0,
help="server exits after TIMEOUT seconds if there are no active connections",
)
parser.add_option("--cert", default="self.pem", help="SSL certificate file")
parser.add_option(
"--key", default=None, help="SSL key file (if separate from cert)"
)
parser.add_option("--key-password", default=None, help="SSL key password")
parser.add_option(
"--ssl-only",
action="store_true",
help="disallow non-encrypted client connections",
)
parser.add_option(
"--ssl-target", action="store_true", help="connect to SSL target as SSL client"
)
parser.add_option(
"--verify-client",
action="store_true",
help="require encrypted client to present a valid certificate " help="require encrypted client to present a valid certificate "
"(needs Python 2.7.9 or newer or Python 3.4 or newer)") "(needs Python 2.7.9 or newer or Python 3.4 or newer)",
parser.add_option("--cafile", metavar="FILE", )
parser.add_option(
"--cafile",
metavar="FILE",
help="file of concatenated certificates of authorities trusted " help="file of concatenated certificates of authorities trusted "
"for validating clients (only effective with --verify-client). " "for validating clients (only effective with --verify-client). "
"If omitted, system default list of CAs is used.") "If omitted, system default list of CAs is used.",
parser.add_option("--ssl-version", type="choice", default="default", )
choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"], action="store", parser.add_option(
help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)") "--ssl-version",
parser.add_option("--ssl-ciphers", action="store", type="choice",
default="default",
choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"],
action="store",
help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)",
)
parser.add_option(
"--ssl-ciphers",
action="store",
help="list of ciphers allowed for connection. For a list of " help="list of ciphers allowed for connection. For a list of "
"supported ciphers run `openssl ciphers`") "supported ciphers run `openssl ciphers`",
parser.add_option("--unix-listen", )
help="listen to unix socket", metavar="FILE", default=None) parser.add_option(
parser.add_option("--unix-listen-mode", default=None, "--unix-listen", help="listen to unix socket", metavar="FILE", default=None
help="specify mode for unix socket (defaults to 0600)") )
parser.add_option("--unix-target", parser.add_option(
help="connect to unix socket target", metavar="FILE") "--unix-listen-mode",
parser.add_option("--inetd", default=None,
help="inetd mode, receive listening socket from stdin", action="store_true") help="specify mode for unix socket (defaults to 0600)",
parser.add_option("--web", default=None, metavar="DIR", )
help="run webserver on same port. Serve files from DIR.") parser.add_option(
parser.add_option("--web-auth", action="store_true", "--unix-target", help="connect to unix socket target", metavar="FILE"
help="require authentication to access webserver.") )
parser.add_option("--wrap-mode", default="exit", metavar="MODE", parser.add_option(
"--inetd",
help="inetd mode, receive listening socket from stdin",
action="store_true",
)
parser.add_option(
"--web",
default=None,
metavar="DIR",
help="run webserver on same port. Serve files from DIR.",
)
parser.add_option(
"--web-auth",
action="store_true",
help="require authentication to access webserver.",
)
parser.add_option(
"--wrap-mode",
default="exit",
metavar="MODE",
choices=["exit", "ignore", "respawn"], choices=["exit", "ignore", "respawn"],
help="action to take when the wrapped program exits " help="action to take when the wrapped program exits "
"or daemonizes: exit (default), ignore, respawn") "or daemonizes: exit (default), ignore, respawn",
parser.add_option("--prefer-ipv6", "-6", )
action="store_true", dest="source_is_ipv6", parser.add_option(
help="prefer IPv6 when resolving source_addr") "--prefer-ipv6",
parser.add_option("--libserver", action="store_true", "-6",
help="use Python library SocketServer engine") action="store_true",
parser.add_option("--target-config", metavar="FILE", dest="source_is_ipv6",
help="prefer IPv6 when resolving source_addr",
)
parser.add_option(
"--libserver",
action="store_true",
help="use Python library SocketServer engine",
)
parser.add_option(
"--target-config",
metavar="FILE",
dest="target_cfg", dest="target_cfg",
help="Configuration file containing valid targets " help="Configuration file containing valid targets "
"in the form 'token: host:port' or, alternatively, a " "in the form 'token: host:port' or, alternatively, a "
"directory containing configuration files of this form " "directory containing configuration files of this form "
"(DEPRECATED: use `--token-plugin TokenFile --token-source " "(DEPRECATED: use `--token-plugin TokenFile --token-source "
" path/to/token/file` instead)") " path/to/token/file` instead)",
parser.add_option("--token-plugin", default=None, metavar="CLASS", )
parser.add_option(
"--token-plugin",
default=None,
metavar="CLASS",
help="use a Python class, usually one from websockify.token_plugins, " help="use a Python class, usually one from websockify.token_plugins, "
"such as TokenFile, to process tokens into host:port pairs") "such as TokenFile, to process tokens into host:port pairs",
parser.add_option("--token-source", default=None, metavar="ARG", )
help="an argument to be passed to the token plugin " parser.add_option(
"on instantiation") "--token-source",
parser.add_option("--host-token", action="store_true", default=None,
metavar="ARG",
help="an argument to be passed to the token plugin on instantiation",
)
parser.add_option(
"--host-token",
action="store_true",
help="use the host HTTP header as token instead of the " help="use the host HTTP header as token instead of the "
"token URL query parameter") "token URL query parameter",
parser.add_option("--auth-plugin", default=None, metavar="CLASS", )
parser.add_option(
"--auth-plugin",
default=None,
metavar="CLASS",
help="use a Python class, usually one from websockify.auth_plugins, " help="use a Python class, usually one from websockify.auth_plugins, "
"such as BasicHTTPAuth, to determine if a connection is allowed") "such as BasicHTTPAuth, to determine if a connection is allowed",
parser.add_option("--auth-source", default=None, metavar="ARG", )
help="an argument to be passed to the auth plugin " parser.add_option(
"on instantiation") "--auth-source",
parser.add_option("--heartbeat", type=int, default=0, metavar="INTERVAL", default=None,
help="send a ping to the client every INTERVAL seconds") metavar="ARG",
parser.add_option("--log-file", metavar="FILE", help="an argument to be passed to the auth plugin on instantiation",
)
parser.add_option(
"--heartbeat",
type=int,
default=0,
metavar="INTERVAL",
help="send a ping to the client every INTERVAL seconds",
)
parser.add_option(
"--log-file",
metavar="FILE",
dest="log_file", dest="log_file",
help="File where logs will be saved") help="File where logs will be saved",
parser.add_option("--syslog", default=None, metavar="SERVER", )
parser.add_option(
"--syslog",
default=None,
metavar="SERVER",
help="Log to syslog server. SERVER can be local socket, " help="Log to syslog server. SERVER can be local socket, "
"such as /dev/log, or a UDP host:port pair.") "such as /dev/log, or a UDP host:port pair.",
parser.add_option("--legacy-syslog", action="store_true", )
parser.add_option(
"--legacy-syslog",
action="store_true",
help="Use the old syslog protocol instead of RFC 5424. " help="Use the old syslog protocol instead of RFC 5424. "
"Use this if the messages produced by websockify seem abnormal.") "Use this if the messages produced by websockify seem abnormal.",
parser.add_option("--file-only", action="store_true", )
help="use this to disable directory listings in web server.") parser.add_option(
"--file-only",
action="store_true",
help="use this to disable directory listings in web server.",
)
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
# Validate options. # Validate options.
if opts.token_source and not opts.token_plugin: if opts.token_source and not opts.token_plugin:
@ -583,11 +717,9 @@ def websockify_init():
if opts.legacy_syslog and not opts.syslog: if opts.legacy_syslog and not opts.syslog:
parser.error("You must use --syslog to use --legacy-syslog") parser.error("You must use --syslog to use --legacy-syslog")
opts.ssl_options = select_ssl_version(opts.ssl_version) opts.ssl_options = select_ssl_version(opts.ssl_version)
del opts.ssl_version del opts.ssl_version
if opts.log_file: if opts.log_file:
# Setup logging to user-specified file. # Setup logging to user-specified file.
opts.log_file = os.path.abspath(opts.log_file) opts.log_file = os.path.abspath(opts.log_file)
@ -601,9 +733,9 @@ def websockify_init():
if opts.syslog: if opts.syslog:
# Determine how to connect to syslog... # Determine how to connect to syslog...
if opts.syslog.count(':'): if opts.syslog.count(":"):
# User supplied a host:port pair. # User supplied a host:port pair.
syslog_host, syslog_port = opts.syslog.rsplit(':', 1) syslog_host, syslog_port = opts.syslog.rsplit(":", 1)
try: try:
syslog_port = int(syslog_port) syslog_port = int(syslog_port)
except ValueError: except ValueError:
@ -622,10 +754,12 @@ def websockify_init():
syslog_facility = WebsockifySysLogHandler.LOG_USER syslog_facility = WebsockifySysLogHandler.LOG_USER
# Start logging to syslog. # Start logging to syslog.
syslog_handler = WebsockifySysLogHandler(address=syslog_dest, syslog_handler = WebsockifySysLogHandler(
address=syslog_dest,
facility=syslog_facility, facility=syslog_facility,
ident='websockify', ident="websockify",
legacy=opts.legacy_syslog) legacy=opts.legacy_syslog,
)
syslog_handler.setLevel(logging.DEBUG) syslog_handler.setLevel(logging.DEBUG)
syslog_handler.setFormatter(log_formatter) syslog_handler.setFormatter(log_formatter)
root = logging.getLogger() root = logging.getLogger()
@ -638,24 +772,23 @@ def websockify_init():
root = logging.getLogger() root = logging.getLogger()
root.setLevel(logging.DEBUG) root.setLevel(logging.DEBUG)
# Transform to absolute path as daemon may chdir # Transform to absolute path as daemon may chdir
if opts.target_cfg: if opts.target_cfg:
opts.target_cfg = os.path.abspath(opts.target_cfg) opts.target_cfg = os.path.abspath(opts.target_cfg)
if opts.target_cfg: if opts.target_cfg:
opts.token_plugin = 'TokenFile' opts.token_plugin = "TokenFile"
opts.token_source = opts.target_cfg opts.token_source = opts.target_cfg
del opts.target_cfg del opts.target_cfg
if sys.argv.count('--'): if sys.argv.count("--"):
opts.wrap_cmd = args[1:] opts.wrap_cmd = args[1:]
else: else:
opts.wrap_cmd = None opts.wrap_cmd = None
if not websockifyserver.ssl and opts.ssl_target: if not websockifyserver.ssl and opts.ssl_target:
parser.error("SSL target requested and Python SSL module not loaded."); parser.error("SSL target requested and Python SSL module not loaded.")
if opts.ssl_only and not os.path.exists(opts.cert): if opts.ssl_only and not os.path.exists(opts.cert):
parser.error("SSL only and %s not found" % opts.cert) parser.error("SSL only and %s not found" % opts.cert)
@ -677,11 +810,11 @@ def websockify_init():
parser.error("Too few arguments") parser.error("Too few arguments")
arg = args.pop(0) arg = args.pop(0)
# Parse host:port and convert ports to numbers # Parse host:port and convert ports to numbers
if arg.count(':') > 0: if arg.count(":") > 0:
opts.listen_host, opts.listen_port = arg.rsplit(':', 1) opts.listen_host, opts.listen_port = arg.rsplit(":", 1)
opts.listen_host = opts.listen_host.strip('[]') opts.listen_host = opts.listen_host.strip("[]")
else: else:
opts.listen_host, opts.listen_port = '', arg opts.listen_host, opts.listen_port = "", arg
try: try:
opts.listen_port = int(opts.listen_port) opts.listen_port = int(opts.listen_port)
@ -697,9 +830,9 @@ def websockify_init():
if len(args) < 1: if len(args) < 1:
parser.error("Too few arguments") parser.error("Too few arguments")
arg = args.pop(0) arg = args.pop(0)
if arg.count(':') > 0: if arg.count(":") > 0:
opts.target_host, opts.target_port = arg.rsplit(':', 1) opts.target_host, opts.target_port = arg.rsplit(":", 1)
opts.target_host = opts.target_host.strip('[]') opts.target_host = opts.target_host.strip("[]")
else: else:
parser.error("Error parsing target") parser.error("Error parsing target")
@ -712,11 +845,10 @@ def websockify_init():
parser.error("Too many arguments") parser.error("Too many arguments")
if opts.token_plugin is not None: if opts.token_plugin is not None:
if '.' not in opts.token_plugin: if "." not in opts.token_plugin:
opts.token_plugin = ( opts.token_plugin = "websockify.token_plugins.%s" % opts.token_plugin
'websockify.token_plugins.%s' % opts.token_plugin)
token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit('.', 1) token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit(".", 1)
__import__(token_plugin_module) __import__(token_plugin_module)
token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls) token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls)
@ -726,10 +858,10 @@ def websockify_init():
del opts.token_source del opts.token_source
if opts.auth_plugin is not None: if opts.auth_plugin is not None:
if '.' not in opts.auth_plugin: if "." not in opts.auth_plugin:
opts.auth_plugin = 'websockify.auth_plugins.%s' % opts.auth_plugin opts.auth_plugin = "websockify.auth_plugins.%s" % opts.auth_plugin
auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit('.', 1) auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit(".", 1)
__import__(auth_plugin_module) __import__(auth_plugin_module)
auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls) auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls)
@ -759,32 +891,32 @@ class LibProxyServer(ThreadingMixIn, HTTPServer):
def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs): def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs):
# Save off proxy specific options # Save off proxy specific options
self.target_host = kwargs.pop('target_host', None) self.target_host = kwargs.pop("target_host", None)
self.target_port = kwargs.pop('target_port', None) self.target_port = kwargs.pop("target_port", None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None) self.wrap_cmd = kwargs.pop("wrap_cmd", None)
self.wrap_mode = kwargs.pop('wrap_mode', None) self.wrap_mode = kwargs.pop("wrap_mode", None)
self.unix_target = kwargs.pop('unix_target', None) self.unix_target = kwargs.pop("unix_target", None)
self.ssl_target = kwargs.pop('ssl_target', None) self.ssl_target = kwargs.pop("ssl_target", None)
self.token_plugin = kwargs.pop('token_plugin', None) self.token_plugin = kwargs.pop("token_plugin", None)
self.auth_plugin = kwargs.pop('auth_plugin', None) self.auth_plugin = kwargs.pop("auth_plugin", None)
self.heartbeat = kwargs.pop('heartbeat', None) self.heartbeat = kwargs.pop("heartbeat", None)
self.token_plugin = None self.token_plugin = None
self.auth_plugin = None self.auth_plugin = None
self.daemon = False self.daemon = False
# Server configuration # Server configuration
listen_host = kwargs.pop('listen_host', '') listen_host = kwargs.pop("listen_host", "")
listen_port = kwargs.pop('listen_port', None) listen_port = kwargs.pop("listen_port", None)
web = kwargs.pop('web', '') web = kwargs.pop("web", "")
# Configuration affecting base request handler # Configuration affecting base request handler
self.only_upgrade = not web self.only_upgrade = not web
self.verbose = kwargs.pop('verbose', False) self.verbose = kwargs.pop("verbose", False)
record = kwargs.pop('record', '') record = kwargs.pop("record", "")
if record: if record:
self.record = os.path.abspath(record) self.record = os.path.abspath(record)
self.run_once = kwargs.pop('run_once', False) self.run_once = kwargs.pop("run_once", False)
self.handler_id = 0 self.handler_id = 0
for arg in kwargs.keys(): for arg in kwargs.keys():
@ -795,12 +927,11 @@ class LibProxyServer(ThreadingMixIn, HTTPServer):
super().__init__((listen_host, listen_port), RequestHandlerClass) super().__init__((listen_host, listen_port), RequestHandlerClass)
def process_request(self, request, client_address): def process_request(self, request, client_address):
"""Override process_request to implement a counter""" """Override process_request to implement a counter"""
self.handler_id += 1 self.handler_id += 1
super().process_request(request, client_address) super().process_request(request, client_address)
if __name__ == '__main__': if __name__ == "__main__":
websockify_init() websockify_init()

View File

@ -1,19 +1,25 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
Python WebSocket server base Python WebSocket server base
Copyright 2011 Joel Martin Copyright 2011 Joel Martin
Copyright 2016-2018 Pierre Ossman Copyright 2016-2018 Pierre Ossman
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
''' """
import sys import sys
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
from websockify.websocket import WebSocket, WebSocketWantReadError, WebSocketWantWriteError from websockify.websocket import (
WebSocket,
WebSocketWantReadError,
WebSocketWantWriteError,
)
class HttpWebSocket(WebSocket): class HttpWebSocket(WebSocket):
"""Class to glue websocket and http request functionality together""" """Class to glue websocket and http request functionality together"""
def __init__(self, request_handler): def __init__(self, request_handler):
super().__init__() super().__init__()
@ -62,8 +68,10 @@ class WebSocketRequestHandlerMixIn:
# Checks if it is a websocket request and redirects # Checks if it is a websocket request and redirects
self.do_GET = self._real_do_GET self.do_GET = self._real_do_GET
if (self.headers.get('upgrade') and if (
self.headers.get('upgrade').lower() == 'websocket'): self.headers.get("upgrade")
and self.headers.get("upgrade").lower() == "websocket"
):
self.handle_upgrade() self.handle_upgrade()
else: else:
self.do_GET() self.do_GET()
@ -100,11 +108,13 @@ class WebSocketRequestHandlerMixIn:
""" """
pass pass
# Convenient ready made classes # Convenient ready made classes
class WebSocketRequestHandler(WebSocketRequestHandlerMixIn,
BaseHTTPRequestHandler): class WebSocketRequestHandler(WebSocketRequestHandlerMixIn, BaseHTTPRequestHandler):
pass pass
class WebSocketServer(HTTPServer): class WebSocketServer(HTTPServer):
pass pass

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
Python WebSocket server base with support for "wss://" encryption. Python WebSocket server base with support for "wss://" encryption.
Copyright 2011 Joel Martin Copyright 2011 Joel Martin
Copyright 2016 Pierre Ossman Copyright 2016 Pierre Ossman
@ -10,35 +10,39 @@ You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' """
import os, sys, time, errno, signal, socket, select, logging import os, sys, time, errno, signal, socket, select, logging
import multiprocessing import multiprocessing
from http.server import SimpleHTTPRequestHandler from http.server import SimpleHTTPRequestHandler
# Degraded functionality if these imports are missing # Degraded functionality if these imports are missing
for mod, msg in [('ssl', 'TLS/SSL/wss is disabled'), for mod, msg in [
('resource', 'daemonizing is disabled')]: ("ssl", "TLS/SSL/wss is disabled"),
("resource", "daemonizing is disabled"),
]:
try: try:
globals()[mod] = __import__(mod) globals()[mod] = __import__(mod)
except ImportError: except ImportError:
globals()[mod] = None globals()[mod] = None
print("WARNING: no '%s' module, %s" % (mod, msg)) print("WARNING: no '%s' module, %s" % (mod, msg))
if sys.platform == 'win32': if sys.platform == "win32":
# make sockets pickle-able/inheritable # make sockets pickle-able/inheritable
import multiprocessing.reduction import multiprocessing.reduction
from websockify.websocket import WebSocketWantReadError, WebSocketWantWriteError from websockify.websocket import WebSocketWantReadError, WebSocketWantWriteError
from websockify.websocketserver import WebSocketRequestHandlerMixIn from websockify.websocketserver import WebSocketRequestHandlerMixIn
class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass): class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass):
def select_subprotocol(self, protocols): def select_subprotocol(self, protocols):
# Handle old websockify clients that still specify a sub-protocol # Handle old websockify clients that still specify a sub-protocol
if 'binary' in protocols: if "binary" in protocols:
return 'binary' return "binary"
else: else:
return '' return ""
# HTTP handler with WebSocket upgrade support # HTTP handler with WebSocket upgrade support
class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHandler): class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHandler):
@ -56,6 +60,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
* run_once: Handle a single request * run_once: Handle a single request
* handler_id: A sequence number for this connection, appended to record filename * handler_id: A sequence number for this connection, appended to record filename
""" """
server_version = "WebSockify" server_version = "WebSockify"
protocol_version = "HTTP/1.1" protocol_version = "HTTP/1.1"
@ -87,30 +92,33 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
super().__init__(req, addr, server) super().__init__(req, addr, server)
def log_message(self, format, *args): def log_message(self, format, *args):
self.logger.info("%s - - [%s] %s" % (self.client_address[0], self.log_date_time_string(), format % args)) self.logger.info(
"%s - - [%s] %s"
% (self.client_address[0], self.log_date_time_string(), format % args)
)
# #
# WebSocketRequestHandler logging/output functions # WebSocketRequestHandler logging/output functions
# #
def print_traffic(self, token="."): def print_traffic(self, token="."):
""" Show traffic flow mode. """ """Show traffic flow mode."""
if self.traffic: if self.traffic:
sys.stdout.write(token) sys.stdout.write(token)
sys.stdout.flush() sys.stdout.flush()
def msg(self, msg, *args, **kwargs): def msg(self, msg, *args, **kwargs):
""" Output message with handler_id prefix. """ """Output message with handler_id prefix."""
prefix = "% 3d: " % self.handler_id prefix = "% 3d: " % self.handler_id
self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs) self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs)
def vmsg(self, msg, *args, **kwargs): def vmsg(self, msg, *args, **kwargs):
""" Same as msg() but as debug. """ """Same as msg() but as debug."""
prefix = "% 3d: " % self.handler_id prefix = "% 3d: " % self.handler_id
self.logger.log(logging.DEBUG, "%s%s" % (prefix, msg), *args, **kwargs) self.logger.log(logging.DEBUG, "%s%s" % (prefix, msg), *args, **kwargs)
def warn(self, msg, *args, **kwargs): def warn(self, msg, *args, **kwargs):
""" Same as msg() but as warning. """ """Same as msg() but as warning."""
prefix = "% 3d: " % self.handler_id prefix = "% 3d: " % self.handler_id
self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs) self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs)
@ -118,19 +126,24 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
# Main WebSocketRequestHandler methods # Main WebSocketRequestHandler methods
# #
def send_frames(self, bufs=None): def send_frames(self, bufs=None):
""" Encode and send WebSocket frames. Any frames already """Encode and send WebSocket frames. Any frames already
queued will be sent first. If buf is not set then only queued queued will be sent first. If buf is not set then only queued
frames will be sent. Returns True if any frames could not be frames will be sent. Returns True if any frames could not be
fully sent, in which case the caller should call again when fully sent, in which case the caller should call again when
the socket is ready. """ the socket is ready."""
tdelta = int(time.time()*1000) - self.start_time tdelta = int(time.time() * 1000) - self.start_time
if bufs: if bufs:
for buf in bufs: for buf in bufs:
if self.rec: if self.rec:
# Python 3 compatible conversion # Python 3 compatible conversion
bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'") bufstr = (
buf.decode("latin1")
.encode("unicode_escape")
.decode("ascii")
.replace("'", "\\'")
)
self.rec.write("'{{{0}{{{1}',\n".format(tdelta, bufstr)) self.rec.write("'{{{0}{{{1}',\n".format(tdelta, bufstr))
self.send_parts.append(buf) self.send_parts.append(buf)
@ -147,7 +160,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
return False return False
def recv_frames(self): def recv_frames(self):
""" Receive and decode WebSocket frames. """Receive and decode WebSocket frames.
Returns: Returns:
(bufs_list, closed_string) (bufs_list, closed_string)
@ -155,7 +168,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
closed = False closed = False
bufs = [] bufs = []
tdelta = int(time.time()*1000) - self.start_time tdelta = int(time.time() * 1000) - self.start_time
while True: while True:
try: try:
@ -165,15 +178,22 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
break break
if buf is None: if buf is None:
closed = {'code': self.request.close_code, closed = {
'reason': self.request.close_reason} "code": self.request.close_code,
"reason": self.request.close_reason,
}
return bufs, closed return bufs, closed
self.print_traffic("}") self.print_traffic("}")
if self.rec: if self.rec:
# Python 3 compatible conversion # Python 3 compatible conversion
bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'") bufstr = (
buf.decode("latin1")
.encode("unicode_escape")
.decode("ascii")
.replace("'", "\\'")
)
self.rec.write("'}}{0}}}{1}',\n".format(tdelta, bufstr)) self.rec.write("'}}{0}}}{1}',\n".format(tdelta, bufstr))
bufs.append(buf) bufs.append(buf)
@ -183,16 +203,16 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
return bufs, closed return bufs, closed
def send_close(self, code=1000, reason=''): def send_close(self, code=1000, reason=""):
""" Send a WebSocket orderly close frame. """ """Send a WebSocket orderly close frame."""
self.request.shutdown(socket.SHUT_RDWR, code, reason) self.request.shutdown(socket.SHUT_RDWR, code, reason)
def send_pong(self, data=b''): def send_pong(self, data=b""):
""" Send a WebSocket pong frame. """ """Send a WebSocket pong frame."""
self.request.pong(data) self.request.pong(data)
def send_ping(self, data=b''): def send_ping(self, data=b""):
""" Send a WebSocket ping frame. """ """Send a WebSocket ping frame."""
self.request.ping(data) self.request.ping(data)
def handle_upgrade(self): def handle_upgrade(self):
@ -208,7 +228,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
# Initialize per client settings # Initialize per client settings
self.send_parts = [] self.send_parts = []
self.recv_part = None self.recv_part = None
self.start_time = int(time.time()*1000) self.start_time = int(time.time() * 1000)
# client_address is empty with, say, UNIX domain sockets # client_address is empty with, say, UNIX domain sockets
client_addr = "" client_addr = ""
@ -224,17 +244,15 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
else: else:
self.stype = "Plain non-SSL (ws://)" self.stype = "Plain non-SSL (ws://)"
self.log_message("%s: %s WebSocket connection", client_addr, self.log_message("%s: %s WebSocket connection", client_addr, self.stype)
self.stype) if self.path != "/":
if self.path != '/':
self.log_message("%s: Path: '%s'", client_addr, self.path) self.log_message("%s: Path: '%s'", client_addr, self.path)
if self.record: if self.record:
# Record raw frame data as JavaScript array # Record raw frame data as JavaScript array
fname = "%s.%s" % (self.record, fname = "%s.%s" % (self.record, self.handler_id)
self.handler_id)
self.log_message("opening record file: %s", fname) self.log_message("opening record file: %s", fname)
self.rec = open(fname, 'w+') self.rec = open(fname, "w+")
self.rec.write("var VNC_frame_data = [\n") self.rec.write("var VNC_frame_data = [\n")
try: try:
@ -261,15 +279,17 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
return super().list_directory(path) return super().list_directory(path)
def new_websocket_client(self): def new_websocket_client(self):
""" Do something with a WebSockets client connection. """ """Do something with a WebSockets client connection."""
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded") raise Exception(
"WebSocketRequestHandler.new_websocket_client() must be overloaded"
)
def validate_connection(self): def validate_connection(self):
""" Ensure that the connection has a valid token, and set the target. """ """Ensure that the connection has a valid token, and set the target."""
pass pass
def auth_connection(self): def auth_connection(self):
""" Ensure that the connection is authorized. """ """Ensure that the connection is authorized."""
pass pass
def do_HEAD(self): def do_HEAD(self):
@ -296,12 +316,12 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
else: else:
super().handle() super().handle()
def log_request(self, code='-', size='-'): def log_request(self, code="-", size="-"):
if self.verbose: if self.verbose:
super().log_request(code, size) super().log_request(code, size)
class WebSockifyServer(): class WebSockifyServer:
""" """
WebSockets server class. WebSockets server class.
As an alternative, the standard library SocketServer can be used As an alternative, the standard library SocketServer can be used
@ -317,17 +337,38 @@ class WebSockifyServer():
class Terminate(Exception): class Terminate(Exception):
pass pass
def __init__(self, RequestHandlerClass, listen_fd=None, def __init__(
listen_host='', listen_port=None, source_is_ipv6=False, self,
verbose=False, cert='', key='', key_password=None, ssl_only=None, RequestHandlerClass,
verify_client=False, cafile=None, listen_fd=None,
daemon=False, record='', web='', web_auth=False, listen_host="",
listen_port=None,
source_is_ipv6=False,
verbose=False,
cert="",
key="",
key_password=None,
ssl_only=None,
verify_client=False,
cafile=None,
daemon=False,
record="",
web="",
web_auth=False,
file_only=False, file_only=False,
run_once=False, timeout=0, idle_timeout=0, traffic=False, run_once=False,
tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, timeout=0,
tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0, idle_timeout=0,
unix_listen=None, unix_listen_mode=None): traffic=False,
tcp_keepalive=True,
tcp_keepcnt=None,
tcp_keepidle=None,
tcp_keepintvl=None,
ssl_ciphers=None,
ssl_options=0,
unix_listen=None,
unix_listen_mode=None,
):
# settings # settings
self.RequestHandlerClass = RequestHandlerClass self.RequestHandlerClass = RequestHandlerClass
self.verbose = verbose self.verbose = verbose
@ -366,7 +407,7 @@ class WebSockifyServer():
# Make paths settings absolute # Make paths settings absolute
self.cert = os.path.abspath(cert) self.cert = os.path.abspath(cert)
self.web = self.record = self.cafile = '' self.web = self.record = self.cafile = ""
if key: if key:
self.key = os.path.abspath(key) self.key = os.path.abspath(key)
if web: if web:
@ -393,11 +434,12 @@ class WebSockifyServer():
elif self.unix_listen != None: elif self.unix_listen != None:
self.msg(" - Listen on unix socket %s", self.unix_listen) self.msg(" - Listen on unix socket %s", self.unix_listen)
else: else:
self.msg(" - Listen on %s:%s", self.msg(" - Listen on %s:%s", self.listen_host, self.listen_port)
self.listen_host, self.listen_port)
if self.web: if self.web:
if self.file_only: if self.file_only:
self.msg(" - Web server (no directory listings). Web root: %s", self.web) self.msg(
" - Web server (no directory listings). Web root: %s", self.web
)
else: else:
self.msg(" - Web server. Web root: %s", self.web) self.msg(" - Web server. Web root: %s", self.web)
if ssl: if ssl:
@ -420,34 +462,45 @@ class WebSockifyServer():
@staticmethod @staticmethod
def get_logger(): def get_logger():
return logging.getLogger("%s.%s" % ( return logging.getLogger(
WebSockifyServer.log_prefix, "%s.%s" % (WebSockifyServer.log_prefix, WebSockifyServer.__class__.__name__)
WebSockifyServer.__class__.__name__)) )
@staticmethod @staticmethod
def socket(host, port=None, connect=False, prefer_ipv6=False, def socket(
unix_socket=None, unix_socket_mode=None, unix_socket_listen=False, host,
use_ssl=False, tcp_keepalive=True, tcp_keepcnt=None, port=None,
tcp_keepidle=None, tcp_keepintvl=None): connect=False,
""" Resolve a host (and optional port) to an IPv4 or IPv6 prefer_ipv6=False,
unix_socket=None,
unix_socket_mode=None,
unix_socket_listen=False,
use_ssl=False,
tcp_keepalive=True,
tcp_keepcnt=None,
tcp_keepidle=None,
tcp_keepintvl=None,
):
"""Resolve a host (and optional port) to an IPv4 or IPv6
address. Create a socket. Bind to it if listen is set, address. Create a socket. Bind to it if listen is set,
otherwise connect to it. Return the socket. otherwise connect to it. Return the socket.
""" """
flags = 0 flags = 0
if host == '': if host == "":
host = None host = None
if connect and not (port or unix_socket): if connect and not (port or unix_socket):
raise Exception("Connect mode requires a port") raise Exception("Connect mode requires a port")
if use_ssl and not ssl: if use_ssl and not ssl:
raise Exception("SSL socket requested but Python SSL module not loaded."); raise Exception("SSL socket requested but Python SSL module not loaded.")
if not connect and use_ssl: if not connect and use_ssl:
raise Exception("SSL only supported in connect mode (for now)") raise Exception("SSL only supported in connect mode (for now)")
if not connect: if not connect:
flags = flags | socket.AI_PASSIVE flags = flags | socket.AI_PASSIVE
if not unix_socket: if not unix_socket:
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, addrs = socket.getaddrinfo(
socket.IPPROTO_TCP, flags) host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, flags
)
if not addrs: if not addrs:
raise Exception("Could not resolve host '%s'" % host) raise Exception("Could not resolve host '%s'" % host)
addrs.sort(key=lambda x: x[0]) addrs.sort(key=lambda x: x[0])
@ -458,14 +511,11 @@ class WebSockifyServer():
if tcp_keepalive: if tcp_keepalive:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if tcp_keepcnt: if tcp_keepcnt:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, tcp_keepcnt)
tcp_keepcnt)
if tcp_keepidle: if tcp_keepidle:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, tcp_keepidle)
tcp_keepidle)
if tcp_keepintvl: if tcp_keepintvl:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, tcp_keepintvl)
tcp_keepintvl)
if connect: if connect:
sock.connect(addrs[0][4]) sock.connect(addrs[0][4])
@ -497,8 +547,7 @@ class WebSockifyServer():
return sock return sock
@staticmethod @staticmethod
def daemonize(keepfd=None, chdir='/'): def daemonize(keepfd=None, chdir="/"):
if keepfd is None: if keepfd is None:
keepfd = [] keepfd = []
@ -506,14 +555,16 @@ class WebSockifyServer():
if chdir: if chdir:
os.chdir(chdir) os.chdir(chdir)
else: else:
os.chdir('/') os.chdir("/")
os.setgid(os.getgid()) # relinquish elevations os.setgid(os.getgid()) # relinquish elevations
os.setuid(os.getuid()) # relinquish elevations os.setuid(os.getuid()) # relinquish elevations
# Double fork to daemonize # Double fork to daemonize
if os.fork() > 0: os._exit(0) # Parent exits if os.fork() > 0:
os._exit(0) # Parent exits
os.setsid() # Obtain new process group os.setsid() # Obtain new process group
if os.fork() > 0: os._exit(0) # Parent exits if os.fork() > 0:
os._exit(0) # Parent exits
# Signal handling # Signal handling
signal.signal(signal.SIGTERM, signal.SIG_IGN) signal.signal(signal.SIGTERM, signal.SIG_IGN)
@ -521,14 +572,16 @@ class WebSockifyServer():
# Close open files # Close open files
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
if maxfd == resource.RLIM_INFINITY: maxfd = 256 if maxfd == resource.RLIM_INFINITY:
maxfd = 256
for fd in reversed(range(maxfd)): for fd in reversed(range(maxfd)):
try: try:
if fd not in keepfd: if fd not in keepfd:
os.close(fd) os.close(fd)
except OSError: except OSError:
_, exc, _ = sys.exc_info() _, exc, _ = sys.exc_info()
if exc.errno != errno.EBADF: raise if exc.errno != errno.EBADF:
raise
# Redirect I/O to /dev/null # Redirect I/O to /dev/null
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
@ -557,7 +610,7 @@ class WebSockifyServer():
# Peek, but do not read the data so that we have a opportunity # Peek, but do not read the data so that we have a opportunity
# to SSL wrap the socket first # to SSL wrap the socket first
handshake = sock.recv(1024, socket.MSG_PEEK) handshake = sock.recv(1024, socket.MSG_PEEK)
#self.msg("Handshake [%s]" % handshake) # self.msg("Handshake [%s]" % handshake)
if not handshake: if not handshake:
raise self.EClose("") raise self.EClose("")
@ -567,8 +620,7 @@ class WebSockifyServer():
if not ssl: if not ssl:
raise self.EClose("SSL connection but no 'ssl' module") raise self.EClose("SSL connection but no 'ssl' module")
if not os.path.exists(self.cert): if not os.path.exists(self.cert):
raise self.EClose("SSL connection but '%s' not found" raise self.EClose("SSL connection but '%s' not found" % self.cert)
% self.cert)
retsock = None retsock = None
try: try:
# create new-style SSL wrapping for extended features # create new-style SSL wrapping for extended features
@ -576,16 +628,16 @@ class WebSockifyServer():
if self.ssl_ciphers is not None: if self.ssl_ciphers is not None:
context.set_ciphers(self.ssl_ciphers) context.set_ciphers(self.ssl_ciphers)
context.options = self.ssl_options context.options = self.ssl_options
context.load_cert_chain(certfile=self.cert, keyfile=self.key, password=self.key_password) context.load_cert_chain(
certfile=self.cert, keyfile=self.key, password=self.key_password
)
if self.verify_client: if self.verify_client:
context.verify_mode = ssl.CERT_REQUIRED context.verify_mode = ssl.CERT_REQUIRED
if self.cafile: if self.cafile:
context.load_verify_locations(cafile=self.cafile) context.load_verify_locations(cafile=self.cafile)
else: else:
context.set_default_verify_paths() context.set_default_verify_paths()
retsock = context.wrap_socket( retsock = context.wrap_socket(sock, server_side=True)
sock,
server_side=True)
except ssl.SSLError: except ssl.SSLError:
_, x, _ = sys.exc_info() _, x, _ = sys.exc_info()
if x.args[0] == ssl.SSL_ERROR_EOF: if x.args[0] == ssl.SSL_ERROR_EOF:
@ -618,28 +670,27 @@ class WebSockifyServer():
# #
def msg(self, *args, **kwargs): def msg(self, *args, **kwargs):
""" Output message as info """ """Output message as info"""
self.logger.log(logging.INFO, *args, **kwargs) self.logger.log(logging.INFO, *args, **kwargs)
def vmsg(self, *args, **kwargs): def vmsg(self, *args, **kwargs):
""" Same as msg() but as debug. """ """Same as msg() but as debug."""
self.logger.log(logging.DEBUG, *args, **kwargs) self.logger.log(logging.DEBUG, *args, **kwargs)
def warn(self, *args, **kwargs): def warn(self, *args, **kwargs):
""" Same as msg() but as warning. """ """Same as msg() but as warning."""
self.logger.log(logging.WARN, *args, **kwargs) self.logger.log(logging.WARN, *args, **kwargs)
# #
# Events that can/should be overridden in sub-classes # Events that can/should be overridden in sub-classes
# #
def started(self): def started(self):
""" Called after WebSockets startup """ """Called after WebSockets startup"""
self.vmsg("WebSockets server started") self.vmsg("WebSockets server started")
def poll(self): def poll(self):
""" Run periodically while waiting for connections. """ """Run periodically while waiting for connections."""
#self.vmsg("Running poll()") # self.vmsg("Running poll()")
pass pass
def terminate(self): def terminate(self):
@ -661,7 +712,7 @@ class WebSockifyServer():
while result[0]: while result[0]:
self.vmsg("Reaped child process %s" % result[0]) self.vmsg("Reaped child process %s" % result[0])
result = os.waitpid(-1, os.WNOHANG) result = os.waitpid(-1, os.WNOHANG)
except (OSError): except OSError:
pass pass
def do_SIGINT(self, sig, stack): def do_SIGINT(self, sig, stack):
@ -675,7 +726,7 @@ class WebSockifyServer():
self.terminate() self.terminate()
def top_new_client(self, startsock, address): def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """ """Do something with a WebSockets client connection."""
# handler process # handler process
client = None client = None
try: try:
@ -693,7 +744,6 @@ class WebSockifyServer():
self.msg("handler exception: %s" % str(exc)) self.msg("handler exception: %s" % str(exc))
self.vmsg("exception", exc_info=True) self.vmsg("exception", exc_info=True)
finally: finally:
if client and client != startsock: if client and client != startsock:
# Close the SSL wrapped socket # Close the SSL wrapped socket
# Original socket closed by caller # Original socket closed by caller
@ -721,19 +771,27 @@ class WebSockifyServer():
try: try:
if self.listen_fd != None: if self.listen_fd != None:
lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM) lsock = socket.fromfd(
self.listen_fd, socket.AF_INET, socket.SOCK_STREAM
)
elif self.unix_listen != None: elif self.unix_listen != None:
lsock = self.socket(host=None, lsock = self.socket(
host=None,
unix_socket=self.unix_listen, unix_socket=self.unix_listen,
unix_socket_mode=self.unix_listen_mode, unix_socket_mode=self.unix_listen_mode,
unix_socket_listen=True) unix_socket_listen=True,
)
else: else:
lsock = self.socket(self.listen_host, self.listen_port, False, lsock = self.socket(
self.listen_host,
self.listen_port,
False,
self.prefer_ipv6, self.prefer_ipv6,
tcp_keepalive=self.tcp_keepalive, tcp_keepalive=self.tcp_keepalive,
tcp_keepcnt=self.tcp_keepcnt, tcp_keepcnt=self.tcp_keepcnt,
tcp_keepidle=self.tcp_keepidle, tcp_keepidle=self.tcp_keepidle,
tcp_keepintvl=self.tcp_keepintvl) tcp_keepintvl=self.tcp_keepintvl,
)
except OSError as e: except OSError as e:
self.msg("Openening socket failed: %s", str(e)) self.msg("Openening socket failed: %s", str(e))
self.vmsg("exception", exc_info=True) self.vmsg("exception", exc_info=True)
@ -751,13 +809,13 @@ class WebSockifyServer():
signal.SIGINT: signal.getsignal(signal.SIGINT), signal.SIGINT: signal.getsignal(signal.SIGINT),
signal.SIGTERM: signal.getsignal(signal.SIGTERM), signal.SIGTERM: signal.getsignal(signal.SIGTERM),
} }
if getattr(signal, 'SIGCHLD', None) is not None: if getattr(signal, "SIGCHLD", None) is not None:
original_signals[signal.SIGCHLD] = signal.getsignal(signal.SIGCHLD) original_signals[signal.SIGCHLD] = signal.getsignal(signal.SIGCHLD)
signal.signal(signal.SIGINT, self.do_SIGINT) signal.signal(signal.SIGINT, self.do_SIGINT)
signal.signal(signal.SIGTERM, self.do_SIGTERM) signal.signal(signal.SIGTERM, self.do_SIGTERM)
# make sure that _cleanup is called when children die # make sure that _cleanup is called when children die
# by calling active_children on SIGCHLD # by calling active_children on SIGCHLD
if getattr(signal, 'SIGCHLD', None) is not None: if getattr(signal, "SIGCHLD", None) is not None:
signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD) signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD)
last_active_time = self.launch_time last_active_time = self.launch_time
@ -774,8 +832,7 @@ class WebSockifyServer():
time_elapsed = time.time() - self.launch_time time_elapsed = time.time() - self.launch_time
if self.timeout and time_elapsed > self.timeout: if self.timeout and time_elapsed > self.timeout:
self.msg('listener exit due to --timeout %s' self.msg("listener exit due to --timeout %s" % self.timeout)
% self.timeout)
break break
if self.idle_timeout: if self.idle_timeout:
@ -787,8 +844,10 @@ class WebSockifyServer():
last_active_time = time.time() last_active_time = time.time()
if idle_time > self.idle_timeout and child_count == 0: if idle_time > self.idle_timeout and child_count == 0:
self.msg('listener exit due to --idle-timeout %s' self.msg(
% self.idle_timeout) "listener exit due to --idle-timeout %s"
% self.idle_timeout
)
break break
try: try:
@ -799,16 +858,16 @@ class WebSockifyServer():
startsock, address = lsock.accept() startsock, address = lsock.accept()
# Unix Socket will not report address (empty string), but address[0] is logged a bunch # Unix Socket will not report address (empty string), but address[0] is logged a bunch
if self.unix_listen != None: if self.unix_listen != None:
address = [ self.unix_listen ] address = [self.unix_listen]
else: else:
continue continue
except self.Terminate: except self.Terminate:
raise raise
except Exception: except Exception:
_, exc, _ = sys.exc_info() _, exc, _ = sys.exc_info()
if hasattr(exc, 'errno'): if hasattr(exc, "errno"):
err = exc.errno err = exc.errno
elif hasattr(exc, 'args'): elif hasattr(exc, "args"):
err = exc.args[0] err = exc.args[0]
else: else:
err = exc[0] err = exc[0]
@ -821,15 +880,14 @@ class WebSockifyServer():
if self.run_once: if self.run_once:
# Run in same process if run_once # Run in same process if run_once
self.top_new_client(startsock, address) self.top_new_client(startsock, address)
if self.ws_connection : if self.ws_connection:
self.msg('%s: exiting due to --run-once' self.msg("%s: exiting due to --run-once" % address[0])
% address[0])
break break
else: else:
self.vmsg('%s: new handler Process' % address[0]) self.vmsg("%s: new handler Process" % address[0])
p = multiprocessing.Process( p = multiprocessing.Process(
target=self.top_new_client, target=self.top_new_client, args=(startsock, address)
args=(startsock, address)) )
p.start() p.start()
# child will not return # child will not return
@ -857,12 +915,11 @@ class WebSockifyServer():
startsock.close() startsock.close()
finally: finally:
# Close listen port # Close listen port
self.vmsg("Closing socket listening at %s:%s", self.vmsg(
self.listen_host, self.listen_port) "Closing socket listening at %s:%s", self.listen_host, self.listen_port
)
lsock.close() lsock.close()
# Restore signals # Restore signals
for sig, func in original_signals.items(): for sig, func in original_signals.items():
signal.signal(sig, func) signal.signal(sig, func)