run `ruff format .`

This commit is contained in:
Doctor 2025-03-12 20:55:57 +03:00
parent 4fc7e52352
commit 3eca8ad195
20 changed files with 1488 additions and 1027 deletions

View File

@ -1,11 +1,11 @@
from setuptools import setup, find_packages
version = '0.13.0'
name = 'websockify'
long_description = open("README.md").read() + "\n" + \
open("CHANGES.txt").read() + "\n"
version = "0.13.0"
name = "websockify"
long_description = open("README.md").read() + "\n" + open("CHANGES.txt").read() + "\n"
setup(name=name,
setup(
name=name,
version=version,
description="Websockify.",
long_description=long_description,
@ -22,23 +22,23 @@ setup(name=name,
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
keywords='noVNC websockify',
license='LGPLv3',
keywords="noVNC websockify",
license="LGPLv3",
url="https://github.com/novnc/websockify",
author="Joel Martin",
author_email="github@martintribe.org",
packages=['websockify'],
packages=["websockify"],
include_package_data=True,
install_requires=[
'numpy', 'requests',
'jwcrypto',
'redis',
"numpy",
"requests",
"jwcrypto",
"redis",
],
zip_safe=False,
entry_points={
'console_scripts': [
'websockify = websockify.websocketproxy:websockify_init',
"console_scripts": [
"websockify = websockify.websocketproxy:websockify_init",
]
},
)
)

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
'''
"""
A WebSocket server that echos back whatever it receives from the client.
Copyright 2010 Joel Martin
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
@ -8,16 +8,19 @@ Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
"""
import os, sys, select, optparse, logging
sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler
class WebSocketEcho(WebSockifyRequestHandler):
"""
WebSockets server that echos back whatever is received from the
client. """
client."""
buffer_size = 8096
def new_websocket_client(self):
@ -33,9 +36,11 @@ class WebSocketEcho(WebSockifyRequestHandler):
while True:
wlist = []
if cqueue or c_pend: wlist.append(self.request)
if cqueue or c_pend:
wlist.append(self.request)
ins, outs, excepts = select.select(rlist, wlist, [], 1)
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
if self.request in outs:
# Send queued target data to the client
@ -50,20 +55,27 @@ class WebSocketEcho(WebSockifyRequestHandler):
if closed:
break
if __name__ == '__main__':
if __name__ == "__main__":
parser = optparse.OptionParser(usage="%prog [options] listen_port")
parser.add_option("--verbose", "-v", action="store_true",
help="verbose messages and per frame traffic")
parser.add_option("--cert", default="self.pem",
help="SSL certificate file")
parser.add_option("--key", default=None,
help="SSL key file (if separate from cert)")
parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted connections")
parser.add_option(
"--verbose",
"-v",
action="store_true",
help="verbose messages and per frame traffic",
)
parser.add_option("--cert", default="self.pem", help="SSL certificate file")
parser.add_option(
"--key", default=None, help="SSL key file (if separate from cert)"
)
parser.add_option(
"--ssl-only", action="store_true", help="disallow non-encrypted connections"
)
(opts, args) = parser.parse_args()
try:
if len(args) != 1: raise ValueError
if len(args) != 1:
raise ValueError
opts.listen_port = int(args[0])
except ValueError:
parser.error("Invalid arguments")
@ -73,4 +85,3 @@ if __name__ == '__main__':
opts.web = "."
server = WebSockifyServer(WebSocketEcho, **opts.__dict__)
server.start_server()

View File

@ -5,9 +5,12 @@ import sys
import optparse
import select
sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
from websockify.websocket import WebSocket, \
WebSocketWantReadError, WebSocketWantWriteError
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websocket import (
WebSocket,
WebSocketWantReadError,
WebSocketWantWriteError,
)
parser = optparse.OptionParser(usage="%prog URL")
(opts, args) = parser.parse_args()
@ -22,19 +25,23 @@ print("Connecting to %s..." % URL)
sock.connect(URL)
print("Connected.")
def send(msg):
while True:
try:
sock.sendmsg(msg)
break
except WebSocketWantReadError:
msg = ''
msg = ""
ins, outs, excepts = select.select([sock], [], [])
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
except WebSocketWantWriteError:
msg = ''
msg = ""
ins, outs, excepts = select.select([], [sock], [])
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
def read():
while True:
@ -42,10 +49,13 @@ def read():
return sock.recvmsg()
except WebSocketWantReadError:
ins, outs, excepts = select.select([sock], [], [])
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
except WebSocketWantWriteError:
ins, outs, excepts = select.select([], [sock], [])
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
counter = 1
while True:
@ -56,7 +66,8 @@ while True:
while True:
ins, outs, excepts = select.select([sock], [], [], 1.0)
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
if ins == []:
break

View File

@ -1,32 +1,32 @@
#!/usr/bin/env python
'''
"""
WebSocket server-side load test program. Sends and receives traffic
that has a random payload (length and content) that is checksummed and
given a sequence number. Any errors are reported and counted.
'''
"""
import sys, os, select, random, time, optparse, logging
sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler
class WebSocketLoadServer(WebSockifyServer):
class WebSocketLoadServer(WebSockifyServer):
recv_cnt = 0
send_cnt = 0
def __init__(self, *args, **kwargs):
self.delay = kwargs.pop('delay')
self.delay = kwargs.pop("delay")
WebSockifyServer.__init__(self, *args, **kwargs)
class WebSocketLoad(WebSockifyRequestHandler):
max_packet_size = 10000
def new_websocket_client(self):
print "Prepopulating random array"
print("Prepopulating random array")
self.rand_array = []
for i in range(0, self.max_packet_size):
self.rand_array.append(random.randint(0, 9))
@ -37,7 +37,7 @@ class WebSocketLoad(WebSockifyRequestHandler):
self.responder(self.request)
print "accumulated errors:", self.errors
print("accumulated errors:", self.errors)
self.errors = 0
def responder(self, client):
@ -49,7 +49,8 @@ class WebSocketLoad(WebSockifyRequestHandler):
while True:
ins, outs, excepts = select.select(socks, socks, socks, 1)
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
if client in ins:
frames, closed = self.recv_frames()
@ -57,7 +58,7 @@ class WebSocketLoad(WebSockifyRequestHandler):
err = self.check(frames)
if err:
self.errors = self.errors + 1
print err
print(err)
if closed:
break
@ -73,24 +74,22 @@ class WebSocketLoad(WebSockifyRequestHandler):
def generate(self):
length = random.randint(10, self.max_packet_size)
numlist = self.rand_array[self.max_packet_size-length:]
numlist = self.rand_array[self.max_packet_size - length :]
# Error in length
#numlist.append(5)
# numlist.append(5)
chksum = sum(numlist)
# Error in checksum
#numlist[0] = 5
nums = "".join( [str(n) for n in numlist] )
# numlist[0] = 5
nums = "".join([str(n) for n in numlist])
data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums)
self.send_cnt += 1
return data
def check(self, frames):
err = ""
for data in frames:
if data.count('$') > 1:
if data.count("$") > 1:
raise Exception("Multiple parts within single packet")
if len(data) == 0:
self.traffic("_")
@ -101,12 +100,12 @@ class WebSocketLoad(WebSockifyRequestHandler):
continue
try:
cnt, length, chksum, nums = data[1:-1].split(':')
cnt, length, chksum, nums = data[1:-1].split(":")
cnt = int(cnt)
length = int(length)
chksum = int(chksum)
except ValueError:
print "\n<BOF>" + repr(data) + "<EOF>"
print("\n<BOF>" + repr(data) + "<EOF>")
err += "Invalid data format\n"
continue
@ -131,27 +130,37 @@ class WebSocketLoad(WebSockifyRequestHandler):
real_chksum += int(num)
if real_chksum != chksum:
err += "Expected checksum %d but real chksum is %d\n" % (chksum, real_chksum)
err += "Expected checksum %d but real chksum is %d\n" % (
chksum,
real_chksum,
)
return err
if __name__ == '__main__':
if __name__ == "__main__":
parser = optparse.OptionParser(usage="%prog [options] listen_port")
parser.add_option("--verbose", "-v", action="store_true",
help="verbose messages and per frame traffic")
parser.add_option("--cert", default="self.pem",
help="SSL certificate file")
parser.add_option("--key", default=None,
help="SSL key file (if separate from cert)")
parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted connections")
parser.add_option(
"--verbose",
"-v",
action="store_true",
help="verbose messages and per frame traffic",
)
parser.add_option("--cert", default="self.pem", help="SSL certificate file")
parser.add_option(
"--key", default=None, help="SSL key file (if separate from cert)"
)
parser.add_option(
"--ssl-only", action="store_true", help="disallow non-encrypted connections"
)
(opts, args) = parser.parse_args()
try:
if len(args) != 1: raise ValueError
if len(args) != 1:
raise ValueError
opts.listen_port = int(args[0])
if len(args) not in [1,2]: raise ValueError
if len(args) not in [1, 2]:
raise ValueError
opts.listen_port = int(args[0])
if len(args) == 2:
opts.delay = int(args[1])
@ -165,4 +174,3 @@ if __name__ == '__main__':
opts.web = "."
server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__)
server.start_server()

View File

@ -1,28 +1,33 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
""" Unit tests for Authentication plugins"""
"""Unit tests for Authentication plugins"""
from websockify.auth_plugins import BasicHTTPAuth, AuthenticationError
import unittest
class BasicHTTPAuthTestCase(unittest.TestCase):
def setUp(self):
self.plugin = BasicHTTPAuth('Aladdin:open sesame')
self.plugin = BasicHTTPAuth("Aladdin:open sesame")
def test_no_auth(self):
headers = {}
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234')
self.assertRaises(
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
)
def test_invalid_password(self):
headers = {'Authorization': 'Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0'}
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234')
headers = {"Authorization": "Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0"}
self.assertRaises(
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
)
def test_valid_password(self):
headers = {'Authorization': 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='}
self.plugin.authenticate(headers, 'localhost', '1234')
headers = {"Authorization": "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="}
self.plugin.authenticate(headers, "localhost", "1234")
def test_garbage_auth(self):
headers = {'Authorization': 'Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx'}
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234')
headers = {"Authorization": "Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx"}
self.assertRaises(
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
)

View File

@ -1,80 +1,93 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
""" Unit tests for Token plugins"""
"""Unit tests for Token plugins"""
import sys
import unittest
from unittest.mock import patch, mock_open, MagicMock
from jwcrypto import jwt, jwk
from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis
from websockify.token_plugins import (
parse_source_args,
ReadOnlyTokenFile,
JWTTokenApi,
TokenRedis,
)
class ParseSourceArgumentsTestCase(unittest.TestCase):
def test_parameterized(self):
params = [
('', ['']),
(':', ['', '']),
('::', ['', '', '']),
("", [""]),
(":", ["", ""]),
("::", ["", "", ""]),
('"', ['"']),
('""', ['""']),
('"""', ['"""']),
('"localhost"', ['localhost']),
('"localhost":', ['localhost', '']),
('"localhost"::', ['localhost', '', '']),
('"local:host"', ['local:host']),
('"local:host:"pass"', ['"local', 'host', "pass"]),
('"local":"host"', ['local', 'host']),
('"local":host"', ['local', 'host"']),
('localhost:6379:1:pass"word:"my-app-namespace:dev"',
['localhost', '6379', '1', 'pass"word', 'my-app-namespace:dev']),
('"localhost"', ["localhost"]),
('"localhost":', ["localhost", ""]),
('"localhost"::', ["localhost", "", ""]),
('"local:host"', ["local:host"]),
('"local:host:"pass"', ['"local', "host", "pass"]),
('"local":"host"', ["local", "host"]),
('"local":host"', ["local", 'host"']),
(
'localhost:6379:1:pass"word:"my-app-namespace:dev"',
["localhost", "6379", "1", 'pass"word', "my-app-namespace:dev"],
),
]
for src, args in params:
self.assertEqual(args, parse_source_args(src))
class ReadOnlyTokenFileTestCase(unittest.TestCase):
patch('os.path.isdir', MagicMock(return_value=False))
patch("os.path.isdir", MagicMock(return_value=False))
def test_empty(self):
plugin = ReadOnlyTokenFile('configfile')
plugin = ReadOnlyTokenFile("configfile")
config = ""
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
pyopen.assert_called_once_with('configfile')
pyopen.assert_called_once_with("configfile")
self.assertIsNone(result)
patch('os.path.isdir', MagicMock(return_value=False))
patch("os.path.isdir", MagicMock(return_value=False))
def test_simple(self):
plugin = ReadOnlyTokenFile('configfile')
plugin = ReadOnlyTokenFile("configfile")
config = "testhost: remote_host:remote_port"
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
pyopen.assert_called_once_with('configfile')
pyopen.assert_called_once_with("configfile")
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
patch('os.path.isdir', MagicMock(return_value=False))
patch("os.path.isdir", MagicMock(return_value=False))
def test_tabs(self):
plugin = ReadOnlyTokenFile('configfile')
plugin = ReadOnlyTokenFile("configfile")
config = "testhost:\tremote_host:remote_port"
pyopen = mock_open(read_data=config)
with patch("websockify.token_plugins.open", pyopen, create=True):
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
pyopen.assert_called_once_with('configfile')
pyopen.assert_called_once_with("configfile")
self.assertIsNotNone(result)
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
class JWSTokenTestCase(unittest.TestCase):
def test_asymmetric_jws_token_plugin(self):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
@ -82,7 +95,9 @@ class JWSTokenTestCase(unittest.TestCase):
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"})
jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
@ -97,21 +112,26 @@ class JWSTokenTestCase(unittest.TestCase):
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"})
jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
self.assertIsNone(result)
@patch('time.time')
@patch("time.time")
def test_jwt_valid_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 })
jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key)
mock_time.return_value = 150
@ -121,14 +141,17 @@ class JWSTokenTestCase(unittest.TestCase):
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('time.time')
@patch("time.time")
def test_jwt_early_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 })
jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key)
mock_time.return_value = 50
@ -136,14 +159,17 @@ class JWSTokenTestCase(unittest.TestCase):
self.assertIsNone(result)
@patch('time.time')
@patch("time.time")
def test_jwt_late_time(self, mock_time):
plugin = JWTTokenApi("./tests/fixtures/public.pem")
key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 })
jwt_token = jwt.JWT(
{"alg": "RS256"},
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
)
jwt_token.make_signed_token(key)
mock_time.return_value = 250
@ -156,8 +182,10 @@ class JWSTokenTestCase(unittest.TestCase):
secret = open("./tests/fixtures/symmetric.key").read()
key = jwk.JWK()
key.import_key(kty="oct",k=secret)
jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"})
key.import_key(kty="oct", k=secret)
jwt_token = jwt.JWT(
{"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
@ -171,8 +199,10 @@ class JWSTokenTestCase(unittest.TestCase):
secret = open("./tests/fixtures/symmetric.key").read()
key = jwk.JWK()
key.import_key(kty="oct",k=secret)
jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"})
key.import_key(kty="oct", k=secret)
jwt_token = jwt.JWT(
{"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(key)
result = plugin.lookup(jwt_token.serialize())
@ -188,10 +218,14 @@ class JWSTokenTestCase(unittest.TestCase):
public_key_data = open("./tests/fixtures/public.pem", "rb").read()
private_key.import_from_pem(private_key_data)
public_key.import_from_pem(public_key_data)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"})
jwt_token = jwt.JWT(
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
)
jwt_token.make_signed_token(private_key)
jwe_token = jwt.JWT(header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"},
claims=jwt_token.serialize())
jwe_token = jwt.JWT(
header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"},
claims=jwt_token.serialize(),
)
jwe_token.make_encrypted_token(public_key)
result = plugin.lookup(jwt_token.serialize())
@ -200,103 +234,104 @@ class JWSTokenTestCase(unittest.TestCase):
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
class TokenRedisTestCase(unittest.TestCase):
def setUp(self):
try:
import redis
except ImportError:
patcher = patch.dict(sys.modules, {'redis': MagicMock()})
patcher = patch.dict(sys.modules, {"redis": MagicMock()})
patcher.start()
self.addCleanup(patcher.stop)
@patch('redis.Redis')
@patch("redis.Redis")
def test_empty(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = None
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNone(result)
@patch('redis.Redis')
@patch("redis.Redis")
def test_simple(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b'{"host": "remote_host:remote_port"}'
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('redis.Redis')
@patch("redis.Redis")
def test_json_token_with_spaces(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b' {"host": "remote_host:remote_port"} '
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('redis.Redis')
@patch("redis.Redis")
def test_text_token(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b'remote_host:remote_port'
instance.get.return_value = b"remote_host:remote_port"
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('redis.Redis')
@patch("redis.Redis")
def test_text_token_with_spaces(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b' remote_host:remote_port '
instance.get.return_value = b" remote_host:remote_port "
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('redis.Redis')
@patch("redis.Redis")
def test_invalid_token(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
instance = mock_redis.return_value
instance.get.return_value = b'{"host": "remote_host:remote_port" '
result = plugin.lookup('testhost')
result = plugin.lookup("testhost")
instance.get.assert_called_once_with('testhost')
instance.get.assert_called_once_with("testhost")
self.assertIsNone(result)
@patch('redis.Redis')
@patch("redis.Redis")
def test_token_without_namespace(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
token = 'testhost'
plugin = TokenRedis("127.0.0.1:1234")
token = "testhost"
def mock_redis_get(key):
self.assertEqual(key, token)
return b'remote_host:remote_port'
return b"remote_host:remote_port"
instance = mock_redis.return_value
instance.get = mock_redis_get
@ -304,17 +339,17 @@ class TokenRedisTestCase(unittest.TestCase):
result = plugin.lookup(token)
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
@patch('redis.Redis')
@patch("redis.Redis")
def test_token_with_namespace(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234:::namespace')
token = 'testhost'
plugin = TokenRedis("127.0.0.1:1234:::namespace")
token = "testhost"
def mock_redis_get(key):
self.assertEqual(key, "namespace:" + token)
return b'remote_host:remote_port'
return b"remote_host:remote_port"
instance = mock_redis.return_value
instance.get = mock_redis_get
@ -322,103 +357,103 @@ class TokenRedisTestCase(unittest.TestCase):
result = plugin.lookup(token)
self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')
self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port")
def test_src_only_host(self):
plugin = TokenRedis('127.0.0.1')
plugin = TokenRedis("127.0.0.1")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port(self):
plugin = TokenRedis('127.0.0.1:1234')
plugin = TokenRedis("127.0.0.1:1234")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db(self):
plugin = TokenRedis('127.0.0.1:1234:2')
plugin = TokenRedis("127.0.0.1:1234:2")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db_pass(self):
plugin = TokenRedis('127.0.0.1:1234:2:verysecret')
plugin = TokenRedis("127.0.0.1:1234:2:verysecret")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_db_pass_namespace(self):
plugin = TokenRedis('127.0.0.1:1234:2:verysecret:namespace')
plugin = TokenRedis("127.0.0.1:1234:2:verysecret:namespace")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "namespace:")
def test_src_with_host_empty_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:::verysecret')
plugin = TokenRedis("127.0.0.1:::verysecret")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_empty_namespace(self):
plugin = TokenRedis('127.0.0.1::::')
plugin = TokenRedis("127.0.0.1::::")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:::')
plugin = TokenRedis("127.0.0.1:::")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::')
plugin = TokenRedis("127.0.0.1::")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_no_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:')
plugin = TokenRedis("127.0.0.1:")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_empty_db_empty_pass_namespace(self):
plugin = TokenRedis('127.0.0.1::::namespace')
plugin = TokenRedis("127.0.0.1::::namespace")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
@ -427,43 +462,43 @@ class TokenRedisTestCase(unittest.TestCase):
def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self):
plugin = TokenRedis('127.0.0.1::::"ns1:ns2"')
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "ns1:ns2:")
def test_src_with_host_empty_port_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2')
plugin = TokenRedis("127.0.0.1::2")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
def test_src_with_host_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:1234::verysecret')
plugin = TokenRedis("127.0.0.1:1234::verysecret")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2:verysecret')
plugin = TokenRedis("127.0.0.1::2:verysecret")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._password, "verysecret")
self.assertEqual(plugin._namespace, "")
def test_src_with_host_empty_port_db_empty_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2:')
plugin = TokenRedis("127.0.0.1::2:")
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._server, "127.0.0.1")
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)

View File

@ -14,199 +14,268 @@
# License for the specific language governing permissions and limitations
# under the License.
""" Unit tests for websocket """
"""Unit tests for websocket"""
import unittest
from websockify import websocket
class FakeSocket:
def __init__(self):
self.data = b''
self.data = b""
def send(self, buf):
self.data += buf
return len(buf)
class AcceptTestCase(unittest.TestCase):
def test_success(self):
ws = websocket.WebSocket()
sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
self.assertTrue(b'\r\nUpgrade: websocket\r\n' in sock.data)
self.assertTrue(b'\r\nConnection: Upgrade\r\n' in sock.data)
self.assertTrue(b'\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n' in sock.data)
ws.accept(
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
self.assertTrue(b"\r\nUpgrade: websocket\r\n" in sock.data)
self.assertTrue(b"\r\nConnection: Upgrade\r\n" in sock.data)
self.assertTrue(
b"\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n" in sock.data
)
def test_bad_version(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '5',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '20',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(
Exception,
ws.accept,
sock,
{"upgrade": "websocket", "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q=="},
)
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "5",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "20",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
def test_bad_upgrade(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket2',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(
Exception,
ws.accept,
sock,
{
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket2",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
def test_missing_key(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13'})
self.assertRaises(
Exception,
ws.accept,
sock,
{"upgrade": "websocket", "Sec-WebSocket-Version": "13"},
)
def test_protocol(self):
class ProtoSocket(websocket.WebSocket):
def select_subprotocol(self, protocol):
return 'gazonk'
return "gazonk"
ws = ProtoSocket()
sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'})
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
self.assertTrue(b'\r\nSec-WebSocket-Protocol: gazonk\r\n' in sock.data)
ws.accept(
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
"Sec-WebSocket-Protocol": "foobar gazonk",
},
)
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
self.assertTrue(b"\r\nSec-WebSocket-Protocol: gazonk\r\n" in sock.data)
def test_no_protocol(self):
ws = websocket.WebSocket()
sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
self.assertFalse(b'\r\nSec-WebSocket-Protocol:' in sock.data)
ws.accept(
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
self.assertFalse(b"\r\nSec-WebSocket-Protocol:" in sock.data)
def test_missing_protocol(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'})
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
"Sec-WebSocket-Protocol": "foobar gazonk",
},
)
def test_protocol(self):
class ProtoSocket(websocket.WebSocket):
def select_subprotocol(self, protocol):
return 'oddball'
return "oddball"
ws = ProtoSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'})
self.assertRaises(
Exception,
ws.accept,
sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
"Sec-WebSocket-Protocol": "foobar gazonk",
},
)
class PingPongTest(unittest.TestCase):
def setUp(self):
self.ws = websocket.WebSocket()
self.sock = FakeSocket()
self.ws.accept(self.sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertEqual(self.sock.data[:13], b'HTTP/1.1 101 ')
self.sock.data = b''
self.ws.accept(
self.sock,
{
"upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
},
)
self.assertEqual(self.sock.data[:13], b"HTTP/1.1 101 ")
self.sock.data = b""
def test_ping(self):
self.ws.ping()
self.assertEqual(self.sock.data, b'\x89\x00')
self.assertEqual(self.sock.data, b"\x89\x00")
def test_pong(self):
self.ws.pong()
self.assertEqual(self.sock.data, b'\x8a\x00')
self.assertEqual(self.sock.data, b"\x8a\x00")
def test_ping_data(self):
self.ws.ping(b'foo')
self.assertEqual(self.sock.data, b'\x89\x03foo')
self.ws.ping(b"foo")
self.assertEqual(self.sock.data, b"\x89\x03foo")
def test_pong_data(self):
self.ws.pong(b'foo')
self.assertEqual(self.sock.data, b'\x8a\x03foo')
self.ws.pong(b"foo")
self.assertEqual(self.sock.data, b"\x8a\x03foo")
class HyBiEncodeDecodeTestCase(unittest.TestCase):
def test_decode_hybi_text(self):
buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58'
buf = b"\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58"
ws = websocket.WebSocket()
res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x1)
self.assertEqual(res['masked'], True)
self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], b'Hello')
self.assertEqual(res["fin"], 1)
self.assertEqual(res["opcode"], 0x1)
self.assertEqual(res["masked"], True)
self.assertEqual(res["length"], len(buf))
self.assertEqual(res["payload"], b"Hello")
def test_decode_hybi_binary(self):
buf = b'\x82\x04\x01\x02\x03\x04'
buf = b"\x82\x04\x01\x02\x03\x04"
ws = websocket.WebSocket()
res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], b'\x01\x02\x03\x04')
self.assertEqual(res["fin"], 1)
self.assertEqual(res["opcode"], 0x2)
self.assertEqual(res["length"], len(buf))
self.assertEqual(res["payload"], b"\x01\x02\x03\x04")
def test_decode_hybi_extended_16bit_binary(self):
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
buf = b'\x82\x7e\x01\x04' + data
data = b"\x01\x02\x03\x04" * 65 # len > 126 -- len == 260
buf = b"\x82\x7e\x01\x04" + data
ws = websocket.WebSocket()
res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], data)
self.assertEqual(res["fin"], 1)
self.assertEqual(res["opcode"], 0x2)
self.assertEqual(res["length"], len(buf))
self.assertEqual(res["payload"], data)
def test_decode_hybi_extended_64bit_binary(self):
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
buf = b'\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04' + data
data = b"\x01\x02\x03\x04" * 65 # len > 126 -- len == 260
buf = b"\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04" + data
ws = websocket.WebSocket()
res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], data)
self.assertEqual(res["fin"], 1)
self.assertEqual(res["opcode"], 0x2)
self.assertEqual(res["length"], len(buf))
self.assertEqual(res["payload"], data)
def test_decode_hybi_multi(self):
buf1 = b'\x01\x03\x48\x65\x6c'
buf2 = b'\x80\x02\x6c\x6f'
buf1 = b"\x01\x03\x48\x65\x6c"
buf2 = b"\x80\x02\x6c\x6f"
ws = websocket.WebSocket()
res1 = ws._decode_hybi(buf1)
self.assertEqual(res1['fin'], 0)
self.assertEqual(res1['opcode'], 0x1)
self.assertEqual(res1['length'], len(buf1))
self.assertEqual(res1['payload'], b'Hel')
self.assertEqual(res1["fin"], 0)
self.assertEqual(res1["opcode"], 0x1)
self.assertEqual(res1["length"], len(buf1))
self.assertEqual(res1["payload"], b"Hel")
res2 = ws._decode_hybi(buf2)
self.assertEqual(res2['fin'], 1)
self.assertEqual(res2['opcode'], 0x0)
self.assertEqual(res2['length'], len(buf2))
self.assertEqual(res2['payload'], b'lo')
self.assertEqual(res2["fin"], 1)
self.assertEqual(res2["opcode"], 0x0)
self.assertEqual(res2["length"], len(buf2))
self.assertEqual(res2["payload"], b"lo")
def test_encode_hybi_basic(self):
ws = websocket.WebSocket()
res = ws._encode_hybi(0x1, b'Hello')
expected = b'\x81\x05\x48\x65\x6c\x6c\x6f'
res = ws._encode_hybi(0x1, b"Hello")
expected = b"\x81\x05\x48\x65\x6c\x6c\x6f"
self.assertEqual(res, expected)

View File

@ -14,7 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.
""" Unit tests for websocketproxy """
"""Unit tests for websocketproxy"""
import sys
import unittest
@ -30,7 +30,7 @@ from websockify import auth_plugins
class FakeSocket:
def __init__(self, data=b''):
def __init__(self, data=b""):
self._data = data
def recv(self, amt, flags=None):
@ -40,11 +40,11 @@ class FakeSocket:
return res
def makefile(self, mode='r', buffsize=None):
if 'b' in mode:
def makefile(self, mode="r", buffsize=None):
if "b" in mode:
return BytesIO(self._data)
else:
return StringIO(self._data.decode('latin_1'))
return StringIO(self._data.decode("latin_1"))
class FakeServer:
@ -58,14 +58,16 @@ class FakeServer:
self.ssl_target = None
self.unix_target = None
class ProxyRequestHandlerTestCase(unittest.TestCase):
def setUp(self):
super().setUp()
self.handler = websocketproxy.ProxyRequestHandler(
FakeSocket(), "127.0.0.1", FakeServer())
FakeSocket(), "127.0.0.1", FakeServer()
)
self.handler.path = "https://localhost:6080/websockify?token=blah"
self.handler.headers = {}
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
def tearDown(self):
patch.stopall()
@ -76,8 +78,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
def lookup(self, token):
return ("some host", "some port")
host, port = self.handler.get_target(
TestPlugin(None))
host, port = self.handler.get_target(TestPlugin(None))
self.assertEqual(host, "some host")
self.assertEqual(port, "some port")
@ -87,8 +88,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
def lookup(self, token):
return ("unix_socket", "/tmp/socket")
_, socket = self.handler.get_target(
TestPlugin(None))
_, socket = self.handler.get_target(TestPlugin(None))
self.assertEqual(socket, "/tmp/socket")
@ -100,11 +100,11 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
with self.assertRaises(FakeServer.EClose):
self.handler.get_target(TestPlugin(None))
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
@patch("websockify.websocketproxy.ProxyRequestHandler.send_auth_error", MagicMock())
def test_token_plugin(self):
class TestPlugin(token_plugins.BasePlugin):
def lookup(self, token):
return (self.source + token).split(',')
return (self.source + token).split(",")
self.handler.server.token_plugin = TestPlugin("somehost,")
self.handler.validate_connection()
@ -112,7 +112,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
self.assertEqual(self.handler.server.target_host, "somehost")
self.assertEqual(self.handler.server.target_port, "blah")
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
@patch("websockify.websocketproxy.ProxyRequestHandler.send_auth_error", MagicMock())
def test_auth_plugin(self):
class TestPlugin(auth_plugins.BasePlugin):
def authenticate(self, headers, target_host, target_port):
@ -128,4 +128,3 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
self.handler.server.target_host = "someotherhost"
self.handler.auth_connection()

View File

@ -1,5 +1,5 @@
"""Unit tests for websocketserver"""
""" Unit tests for websocketserver """
import unittest
from unittest.mock import patch, MagicMock
@ -66,4 +66,3 @@ class HttpWebSocketTest(unittest.TestCase):
# Then
req_obj.end_headers.assert_called_once_with()

View File

@ -14,7 +14,8 @@
# License for the specific language governing permissions and limitations
# under the License.
""" Unit tests for websockifyserver """
"""Unit tests for websockifyserver"""
import errno
import os
import logging
@ -36,11 +37,11 @@ from websockify import websockifyserver
def raise_oserror(*args, **kwargs):
raise OSError('fake error')
raise OSError("fake error")
class FakeSocket:
def __init__(self, data=b''):
def __init__(self, data=b""):
self._data = data
def recv(self, amt, flags=None):
@ -50,19 +51,19 @@ class FakeSocket:
return res
def makefile(self, mode='r', buffsize=None):
if 'b' in mode:
def makefile(self, mode="r", buffsize=None):
if "b" in mode:
return BytesIO(self._data)
else:
return StringIO(self._data.decode('latin_1'))
return StringIO(self._data.decode("latin_1"))
class WebSockifyRequestHandlerTestCase(unittest.TestCase):
def setUp(self):
super().setUp()
self.tmpdir = tempfile.mkdtemp('-websockify-tests')
self.tmpdir = tempfile.mkdtemp("-websockify-tests")
# Mock this out cause it screws tests up
patch('os.chdir').start()
patch("os.chdir").start()
def tearDown(self):
"""Called automatically after each test."""
@ -70,31 +71,41 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase):
os.rmdir(self.tmpdir)
super().tearDown()
def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
**kwargs):
web = kwargs.pop('web', self.tmpdir)
def _get_server(
self, handler_class=websockifyserver.WebSockifyRequestHandler, **kwargs
):
web = kwargs.pop("web", self.tmpdir)
return websockifyserver.WebSockifyServer(
handler_class, listen_host='localhost',
listen_port=80, key=self.tmpdir, web=web,
record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
**kwargs)
handler_class,
listen_host="localhost",
listen_port=80,
key=self.tmpdir,
web=web,
record=self.tmpdir,
daemon=False,
ssl_only=0,
idle_timeout=1,
**kwargs,
)
@patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error')
@patch("websockify.websockifyserver.WebSockifyRequestHandler.send_error")
def test_normal_get_with_only_upgrade_returns_error(self, send_error):
server = self._get_server(web=None)
handler = websockifyserver.WebSockifyRequestHandler(
FakeSocket(b'GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)
FakeSocket(b"GET /tmp.txt HTTP/1.1"), "127.0.0.1", server
)
handler.do_GET()
send_error.assert_called_with(405)
@patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error')
@patch("websockify.websockifyserver.WebSockifyRequestHandler.send_error")
def test_list_dir_with_file_only_returns_error(self, send_error):
server = self._get_server(file_only=True)
handler = websockifyserver.WebSockifyRequestHandler(
FakeSocket(b'GET / HTTP/1.1'), '127.0.0.1', server)
FakeSocket(b"GET / HTTP/1.1"), "127.0.0.1", server
)
handler.path = '/'
handler.path = "/"
handler.do_GET()
send_error.assert_called_with(404)
@ -102,9 +113,9 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase):
class WebSockifyServerTestCase(unittest.TestCase):
def setUp(self):
super().setUp()
self.tmpdir = tempfile.mkdtemp('-websockify-tests')
self.tmpdir = tempfile.mkdtemp("-websockify-tests")
# Mock this out cause it screws tests up
patch('os.chdir').start()
patch("os.chdir").start()
def tearDown(self):
"""Called automatically after each test."""
@ -112,32 +123,38 @@ class WebSockifyServerTestCase(unittest.TestCase):
os.rmdir(self.tmpdir)
super().tearDown()
def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
**kwargs):
def _get_server(
self, handler_class=websockifyserver.WebSockifyRequestHandler, **kwargs
):
return websockifyserver.WebSockifyServer(
handler_class, listen_host='localhost',
listen_port=80, key=self.tmpdir, web=self.tmpdir,
record=self.tmpdir, **kwargs)
handler_class,
listen_host="localhost",
listen_port=80,
key=self.tmpdir,
web=self.tmpdir,
record=self.tmpdir,
**kwargs,
)
def test_daemonize_raises_error_while_closing_fds(self):
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
patch('os.fork').start().return_value = 0
patch('signal.signal').start()
patch('os.setsid').start()
patch('os.close').start().side_effect = raise_oserror
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
patch("os.fork").start().return_value = 0
patch("signal.signal").start()
patch("os.setsid").start()
patch("os.close").start().side_effect = raise_oserror
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir="./")
def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
def raise_oserror_ebadf(fd):
raise OSError(errno.EBADF, 'fake error')
raise OSError(errno.EBADF, "fake error")
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
patch('os.fork').start().return_value = 0
patch('signal.signal').start()
patch('os.setsid').start()
patch('os.close').start().side_effect = raise_oserror_ebadf
patch('os.open').start().side_effect = raise_oserror
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
patch("os.fork").start().return_value = 0
patch("signal.signal").start()
patch("os.setsid").start()
patch("os.close").start().side_effect = raise_oserror_ebadf
patch("os.open").start().side_effect = raise_oserror
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir="./")
def test_handshake_fails_on_not_ready(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
@ -145,23 +162,29 @@ class WebSockifyServerTestCase(unittest.TestCase):
def fake_select(rlist, wlist, xlist, timeout=None):
return ([], [], [])
patch('select.select').start().side_effect = fake_select
patch("select.select").start().side_effect = fake_select
self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
FakeSocket(), '127.0.0.1')
websockifyserver.WebSockifyServer.EClose,
server.do_handshake,
FakeSocket(),
"127.0.0.1",
)
def test_empty_handshake_fails(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
sock = FakeSocket('')
sock = FakeSocket("")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
patch('select.select').start().side_effect = fake_select
patch("select.select").start().side_effect = fake_select
self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
sock, '127.0.0.1')
websockifyserver.WebSockifyServer.EClose,
server.do_handshake,
sock,
"127.0.0.1",
)
def test_handshake_policy_request(self):
# TODO(directxman12): implement
@ -170,35 +193,39 @@ class WebSockifyServerTestCase(unittest.TestCase):
def test_handshake_ssl_only_without_ssl_raises_error(self):
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
sock = FakeSocket(b'some initial data')
sock = FakeSocket(b"some initial data")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
patch('select.select').start().side_effect = fake_select
patch("select.select").start().side_effect = fake_select
self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
sock, '127.0.0.1')
websockifyserver.WebSockifyServer.EClose,
server.do_handshake,
sock,
"127.0.0.1",
)
def test_do_handshake_no_ssl(self):
class FakeHandler:
CALLED = False
def __init__(self, *args, **kwargs):
type(self).CALLED = True
FakeHandler.CALLED = False
server = self._get_server(
handler_class=FakeHandler, daemon=True,
ssl_only=0, idle_timeout=1)
handler_class=FakeHandler, daemon=True, ssl_only=0, idle_timeout=1
)
sock = FakeSocket(b'some initial data')
sock = FakeSocket(b"some initial data")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
patch('select.select').start().side_effect = fake_select
self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
patch("select.select").start().side_effect = fake_select
self.assertEqual(server.do_handshake(sock, "127.0.0.1"), sock)
self.assertTrue(FakeHandler.CALLED, True)
def test_do_handshake_ssl(self):
@ -210,18 +237,22 @@ class WebSockifyServerTestCase(unittest.TestCase):
pass
def test_do_handshake_ssl_without_cert_raises_error(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
cert='afdsfasdafdsafdsafdsafdas')
server = self._get_server(
daemon=True, ssl_only=0, idle_timeout=1, cert="afdsfasdafdsafdsafdsafdas"
)
sock = FakeSocket(b"\x16some ssl data")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
patch('select.select').start().side_effect = fake_select
patch("select.select").start().side_effect = fake_select
self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
sock, '127.0.0.1')
websockifyserver.WebSockifyServer.EClose,
server.do_handshake,
sock,
"127.0.0.1",
)
def test_do_handshake_ssl_error_eof_raises_close_error(self):
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
@ -234,58 +265,79 @@ class WebSockifyServerTestCase(unittest.TestCase):
def fake_wrap_socket(*args, **kwargs):
raise ssl.SSLError(ssl.SSL_ERROR_EOF)
class fake_create_default_context():
class fake_create_default_context:
def __init__(self, purpose):
self.verify_mode = None
self.options = 0
def load_cert_chain(self, certfile, keyfile, password):
pass
def set_default_verify_paths(self):
pass
def load_verify_locations(self, cafile):
pass
def wrap_socket(self, *args, **kwargs):
raise ssl.SSLError(ssl.SSL_ERROR_EOF)
patch('select.select').start().side_effect = fake_select
patch('ssl.create_default_context').start().side_effect = fake_create_default_context
patch("select.select").start().side_effect = fake_select
patch(
"ssl.create_default_context"
).start().side_effect = fake_create_default_context
self.assertRaises(
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
sock, '127.0.0.1')
websockifyserver.WebSockifyServer.EClose,
server.do_handshake,
sock,
"127.0.0.1",
)
def test_do_handshake_ssl_sets_ciphers(self):
test_ciphers = 'TEST-CIPHERS-1:TEST-CIPHER-2'
test_ciphers = "TEST-CIPHERS-1:TEST-CIPHER-2"
class FakeHandler:
def __init__(self, *args, **kwargs):
pass
server = self._get_server(handler_class=FakeHandler, daemon=True,
idle_timeout=1, ssl_ciphers=test_ciphers)
server = self._get_server(
handler_class=FakeHandler,
daemon=True,
idle_timeout=1,
ssl_ciphers=test_ciphers,
)
sock = FakeSocket(b"\x16some ssl data")
def fake_select(rlist, wlist, xlist, timeout=None):
return ([sock], [], [])
class fake_create_default_context():
CIPHERS = ''
class fake_create_default_context:
CIPHERS = ""
def __init__(self, purpose):
self.verify_mode = None
self.options = 0
def load_cert_chain(self, certfile, keyfile, password):
pass
def set_default_verify_paths(self):
pass
def load_verify_locations(self, cafile):
pass
def wrap_socket(self, *args, **kwargs):
pass
def set_ciphers(self, ciphers_to_set):
fake_create_default_context.CIPHERS = ciphers_to_set
patch('select.select').start().side_effect = fake_select
patch('ssl.create_default_context').start().side_effect = fake_create_default_context
server.do_handshake(sock, '127.0.0.1')
patch("select.select").start().side_effect = fake_select
patch(
"ssl.create_default_context"
).start().side_effect = fake_create_default_context
server.do_handshake(sock, "127.0.0.1")
self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers)
def test_do_handshake_ssl_sets_opions(self):
@ -295,8 +347,12 @@ class WebSockifyServerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
pass
server = self._get_server(handler_class=FakeHandler, daemon=True,
idle_timeout=1, ssl_options=test_options)
server = self._get_server(
handler_class=FakeHandler,
daemon=True,
idle_timeout=1,
ssl_options=test_options,
)
sock = FakeSocket(b"\x16some ssl data")
def fake_select(rlist, wlist, xlist, timeout=None):
@ -304,26 +360,36 @@ class WebSockifyServerTestCase(unittest.TestCase):
class fake_create_default_context:
OPTIONS = 0
def __init__(self, purpose):
self.verify_mode = None
self._options = 0
def load_cert_chain(self, certfile, keyfile, password):
pass
def set_default_verify_paths(self):
pass
def load_verify_locations(self, cafile):
pass
def wrap_socket(self, *args, **kwargs):
pass
def get_options(self):
return self._options
def set_options(self, val):
fake_create_default_context.OPTIONS = val
options = property(get_options, set_options)
patch('select.select').start().side_effect = fake_select
patch('ssl.create_default_context').start().side_effect = fake_create_default_context
server.do_handshake(sock, '127.0.0.1')
patch("select.select").start().side_effect = fake_select
patch(
"ssl.create_default_context"
).start().side_effect = fake_create_default_context
server.do_handshake(sock, "127.0.0.1")
self.assertEqual(fake_create_default_context.OPTIONS, test_options)
def test_fallback_sigchld_handler(self):
@ -332,38 +398,38 @@ class WebSockifyServerTestCase(unittest.TestCase):
def test_start_server_error(self):
server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
sock = server.socket('localhost')
sock = server.socket("localhost")
def fake_select(rlist, wlist, xlist, timeout=None):
raise Exception("fake error")
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
patch('select.select').start().side_effect = fake_select
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
patch("select.select").start().side_effect = fake_select
server.start_server()
def test_start_server_keyboardinterrupt(self):
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
sock = server.socket('localhost')
sock = server.socket("localhost")
def fake_select(rlist, wlist, xlist, timeout=None):
raise KeyboardInterrupt
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
patch('select.select').start().side_effect = fake_select
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
patch("select.select").start().side_effect = fake_select
server.start_server()
def test_start_server_systemexit(self):
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
sock = server.socket('localhost')
sock = server.socket("localhost")
def fake_select(rlist, wlist, xlist, timeout=None):
sys.exit()
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
patch('select.select').start().side_effect = fake_select
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
patch("select.select").start().side_effect = fake_select
server.start_server()
def test_socket_set_keepalive_options(self):
@ -372,29 +438,37 @@ class WebSockifyServerTestCase(unittest.TestCase):
keepintvl = 56
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
sock = server.socket('localhost',
sock = server.socket(
"localhost",
tcp_keepcnt=keepcnt,
tcp_keepidle=keepidle,
tcp_keepintvl=keepintvl)
tcp_keepintvl=keepintvl,
)
if hasattr(socket, 'TCP_KEEPCNT'):
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt)
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPIDLE), keepidle)
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPINTVL), keepintvl)
if hasattr(socket, "TCP_KEEPCNT"):
self.assertEqual(
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt
)
self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE), keepidle)
self.assertEqual(
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl
)
sock = server.socket('localhost',
sock = server.socket(
"localhost",
tcp_keepalive=False,
tcp_keepcnt=keepcnt,
tcp_keepidle=keepidle,
tcp_keepintvl=keepintvl)
tcp_keepintvl=keepintvl,
)
if hasattr(socket, 'TCP_KEEPCNT'):
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt)
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPIDLE), keepidle)
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPINTVL), keepintvl)
if hasattr(socket, "TCP_KEEPCNT"):
self.assertNotEqual(
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt
)
self.assertNotEqual(
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE), keepidle
)
self.assertNotEqual(
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl
)

View File

@ -1,4 +1,4 @@
import websockify
if __name__ == '__main__':
if __name__ == "__main__":
websockify.websocketproxy.websockify_init()

View File

@ -1,4 +1,4 @@
class BasePlugin():
class BasePlugin:
def __init__(self, src=None):
self.source = src
@ -7,7 +7,9 @@ class BasePlugin():
class AuthenticationError(Exception):
def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None):
def __init__(
self, log_msg=None, response_code=403, response_headers={}, response_msg=None
):
self.code = response_code
self.headers = response_headers
self.msg = response_msg
@ -15,7 +17,7 @@ class AuthenticationError(Exception):
if log_msg is None:
log_msg = response_msg
super().__init__('%s %s' % (self.code, log_msg))
super().__init__("%s %s" % (self.code, log_msg))
class InvalidOriginError(AuthenticationError):
@ -24,12 +26,13 @@ class InvalidOriginError(AuthenticationError):
self.actual_origin = actual
super().__init__(
response_msg='Invalid Origin',
response_msg="Invalid Origin",
log_msg="Invalid Origin Header: Expected one of "
"%s, got '%s'" % (expected, actual))
"%s, got '%s'" % (expected, actual),
)
class BasicHTTPAuth():
class BasicHTTPAuth:
"""Verifies Basic Auth headers. Specify src as username:password"""
def __init__(self, src=None):
@ -37,9 +40,10 @@ class BasicHTTPAuth():
def authenticate(self, headers, target_host, target_port):
import base64
auth_header = headers.get('Authorization')
auth_header = headers.get("Authorization")
if auth_header:
if not auth_header.startswith('Basic '):
if not auth_header.startswith("Basic "):
self.auth_error()
try:
@ -49,11 +53,11 @@ class BasicHTTPAuth():
try:
# http://stackoverflow.com/questions/7242316/what-encoding-should-i-use-for-http-basic-authentication
user_pass_as_text = user_pass_raw.decode('ISO-8859-1')
user_pass_as_text = user_pass_raw.decode("ISO-8859-1")
except UnicodeDecodeError:
self.auth_error()
user_pass = user_pass_as_text.split(':', 1)
user_pass = user_pass_as_text.split(":", 1)
if len(user_pass) != 2:
self.auth_error()
@ -64,7 +68,7 @@ class BasicHTTPAuth():
self.demand_auth()
def validate_creds(self, username, password):
if '%s:%s' % (username, password) == self.src:
if "%s:%s" % (username, password) == self.src:
return True
else:
return False
@ -73,10 +77,13 @@ class BasicHTTPAuth():
raise AuthenticationError(response_code=403)
def demand_auth(self):
raise AuthenticationError(response_code=401,
response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'})
raise AuthenticationError(
response_code=401,
response_headers={"WWW-Authenticate": 'Basic realm="Websockify"'},
)
class ExpectOrigin():
class ExpectOrigin:
def __init__(self, src=None):
if src is None:
self.source = []
@ -84,11 +91,12 @@ class ExpectOrigin():
self.source = src.split()
def authenticate(self, headers, target_host, target_port):
origin = headers.get('Origin', None)
origin = headers.get("Origin", None)
if origin is None or origin not in self.source:
raise InvalidOriginError(expected=self.source, actual=origin)
class ClientCertCNAuth():
class ClientCertCNAuth:
"""Verifies client by SSL certificate. Specify src as whitespace separated list of common names."""
def __init__(self, src=None):
@ -98,5 +106,5 @@ class ClientCertCNAuth():
self.source = src.split()
def authenticate(self, headers, target_host, target_port):
if headers.get('SSL_CLIENT_S_DN_CN', None) not in self.source:
if headers.get("SSL_CLIENT_S_DN_CN", None) not in self.source:
raise AuthenticationError(response_code=403)

View File

@ -7,23 +7,26 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
as defined by RFC 5424.
"""
_legacy_head_fmt = '<{pri}>{ident}[{pid}]: '
_rfc5424_head_fmt = '<{pri}>1 {timestamp} {hostname} {ident} {pid} - - '
_legacy_head_fmt = "<{pri}>{ident}[{pid}]: "
_rfc5424_head_fmt = "<{pri}>1 {timestamp} {hostname} {ident} {pid} - - "
_head_fmt = _rfc5424_head_fmt
_legacy = False
_timestamp_fmt = '%Y-%m-%dT%H:%M:%SZ'
_timestamp_fmt = "%Y-%m-%dT%H:%M:%SZ"
_max_hostname = 255
_max_ident = 24 #safer for old daemons
_max_ident = 24 # safer for old daemons
_send_length = False
_tail = '\n'
_tail = "\n"
ident = None
def __init__(self, address=('localhost', handlers.SYSLOG_UDP_PORT),
def __init__(
self,
address=("localhost", handlers.SYSLOG_UDP_PORT),
facility=handlers.SysLogHandler.LOG_USER,
socktype=None, ident=None, legacy=False):
socktype=None,
ident=None,
legacy=False,
):
"""
Initialize a handler.
@ -46,7 +49,6 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
super().__init__(address, facility, socktype)
def emit(self, record):
"""
Emit a record.
@ -57,46 +59,44 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
try:
# Gather info.
text = self.format(record).replace(self._tail, ' ')
text = self.format(record).replace(self._tail, " ")
if not text: # nothing to log
return
pri = self.encodePriority(self.facility,
self.mapPriority(record.levelname))
pri = self.encodePriority(self.facility, self.mapPriority(record.levelname))
timestamp = time.strftime(self._timestamp_fmt, time.gmtime());
hostname = socket.gethostname()[:self._max_hostname]
timestamp = time.strftime(self._timestamp_fmt, time.gmtime())
hostname = socket.gethostname()[: self._max_hostname]
if self.ident:
ident = self.ident[:self._max_ident]
ident = self.ident[: self._max_ident]
else:
ident = ''
ident = ""
pid = os.getpid() # shouldn't need truncation
# Format the header.
head = {
'pri': pri,
'timestamp': timestamp,
'hostname': hostname,
'ident': ident,
'pid': pid,
"pri": pri,
"timestamp": timestamp,
"hostname": hostname,
"ident": ident,
"pid": pid,
}
msg = self._head_fmt.format(**head).encode('ascii', 'ignore')
msg = self._head_fmt.format(**head).encode("ascii", "ignore")
# Encode text as plain ASCII if possible, else use UTF-8 with BOM.
try:
msg += text.encode('ascii')
msg += text.encode("ascii")
except UnicodeEncodeError:
msg += text.encode('utf-8-sig')
msg += text.encode("utf-8-sig")
# Add length or tail character, if necessary.
if self.socktype != socket.SOCK_DGRAM:
if self._send_length:
msg = ('%d ' % len(msg)).encode('ascii') + msg
msg = ("%d " % len(msg)).encode("ascii") + msg
else:
msg += self._tail.encode('ascii')
msg += self._tail.encode("ascii")
# Send the message.
if self.unixsocket:

View File

@ -10,8 +10,8 @@ logger = logging.getLogger(__name__)
_SOURCE_SPLIT_REGEX = re.compile(
r'(?<=^)"([^"]+)"(?=:|$)'
r'|(?<=:)"([^"]+)"(?=:|$)'
r'|(?<=^)([^:]*)(?=:|$)'
r'|(?<=:)([^:]*)(?=:|$)',
r"|(?<=^)([^:]*)(?=:|$)"
r"|(?<=:)([^:]*)(?=:|$)",
)
@ -26,7 +26,7 @@ def parse_source_args(src):
return [m[0] or m[1] or m[2] or m[3] for m in matches]
class BasePlugin():
class BasePlugin:
def __init__(self, src):
self.source = src
@ -44,8 +44,7 @@ class ReadOnlyTokenFile(BasePlugin):
def _load_targets(self):
if os.path.isdir(self.source):
cfg_files = [os.path.join(self.source, f) for
f in os.listdir(self.source)]
cfg_files = [os.path.join(self.source, f) for f in os.listdir(self.source)]
else:
cfg_files = [self.source]
@ -53,12 +52,14 @@ class ReadOnlyTokenFile(BasePlugin):
index = 1
for f in cfg_files:
for line in [l.strip() for l in open(f).readlines()]:
if line and not line.startswith('#'):
if line and not line.startswith("#"):
try:
tok, target = re.split(r':\s', line)
self._targets[tok] = target.strip().rsplit(':', 1)
tok, target = re.split(r":\s", line)
self._targets[tok] = target.strip().rsplit(":", 1)
except ValueError:
logger.error("Syntax error in %s on line %d" % (self.source, index))
logger.error(
"Syntax error in %s on line %d" % (self.source, index)
)
index += 1
def lookup(self, token):
@ -83,6 +84,7 @@ class TokenFile(ReadOnlyTokenFile):
return super().lookup(token)
class TokenFileName(BasePlugin):
# source is a directory
# token is filename
@ -96,7 +98,7 @@ class TokenFileName(BasePlugin):
token = os.path.basename(token)
path = os.path.join(self.source, token)
if os.path.exists(path):
return open(path).read().strip().split(':')
return open(path).read().strip().split(":")
else:
return None
@ -109,9 +111,9 @@ class BaseTokenAPI(BasePlugin):
# in this file can be used w/o unnecessary dependencies
def process_result(self, resp):
host, port = resp.text.split(':')
port = port.encode('ascii','ignore')
return [ host, port ]
host, port = resp.text.split(":")
port = port.encode("ascii", "ignore")
return [host, port]
def lookup(self, token):
import requests
@ -130,7 +132,7 @@ class JSONTokenApi(BaseTokenAPI):
def process_result(self, resp):
resp_json = resp.json()
return (resp_json['host'], resp_json['port'])
return (resp_json["host"], resp_json["port"])
class JWTTokenApi(BasePlugin):
@ -145,7 +147,7 @@ class JWTTokenApi(BasePlugin):
key = jwk.JWK()
try:
with open(self.source, 'rb') as key_file:
with open(self.source, "rb") as key_file:
key_data = key_file.read()
except Exception as e:
logger.error("Error loading key file: %s" % str(e))
@ -155,39 +157,41 @@ class JWTTokenApi(BasePlugin):
key.import_from_pem(key_data)
except:
try:
key.import_key(k=key_data.decode('utf-8'),kty='oct')
key.import_key(k=key_data.decode("utf-8"), kty="oct")
except:
logger.error('Failed to correctly parse key data!')
logger.error("Failed to correctly parse key data!")
return None
try:
token = jwt.JWT(key=key, jwt=token)
parsed_header = json.loads(token.header)
if 'enc' in parsed_header:
if "enc" in parsed_header:
# Token is encrypted, so we need to decrypt by passing the claims to a new instance
token = jwt.JWT(key=key, jwt=token.claims)
parsed = json.loads(token.claims)
if 'nbf' in parsed:
if "nbf" in parsed:
# Not Before is present, so we need to check it
if time.time() < parsed['nbf']:
logger.warning('Token can not be used yet!')
if time.time() < parsed["nbf"]:
logger.warning("Token can not be used yet!")
return None
if 'exp' in parsed:
if "exp" in parsed:
# Expiration time is present, so we need to check it
if time.time() > parsed['exp']:
logger.warning('Token has expired!')
if time.time() > parsed["exp"]:
logger.warning("Token has expired!")
return None
return (parsed['host'], parsed['port'])
return (parsed["host"], parsed["port"])
except Exception as e:
logger.error("Failed to parse token: %s" % str(e))
return None
except ImportError:
logger.error("package jwcrypto not found, are you sure you've installed it correctly?")
logger.error(
"package jwcrypto not found, are you sure you've installed it correctly?"
)
return None
@ -251,6 +255,7 @@ class TokenRedis(BasePlugin):
pip install redis
"""
def __init__(self, src):
try:
import redis
@ -285,7 +290,9 @@ class TokenRedis(BasePlugin):
if not self._password:
self._password = None
elif len(fields) == 5:
self._server, self._port, self._db, self._password, self._namespace = fields
self._server, self._port, self._db, self._password, self._namespace = (
fields
)
if not self._port:
self._port = 6379
if not self._db:
@ -301,24 +308,30 @@ class TokenRedis(BasePlugin):
if self._namespace:
self._namespace += ":"
logger.info("TokenRedis backend initialized (%s:%s)" %
(self._server, self._port))
logger.info(
"TokenRedis backend initialized (%s:%s)" % (self._server, self._port)
)
except ValueError:
logger.error("The provided --token-source='%s' is not in the "
"expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]" %
src)
logger.error(
"The provided --token-source='%s' is not in the "
"expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]"
% src
)
sys.exit()
def lookup(self, token):
try:
import redis
except ImportError:
logger.error("package redis not found, are you sure you've installed them correctly?")
logger.error(
"package redis not found, are you sure you've installed them correctly?"
)
sys.exit()
logger.info("resolving token '%s'" % token)
client = redis.Redis(host=self._server, port=self._port,
db=self._db, password=self._password)
client = redis.Redis(
host=self._server, port=self._port, db=self._db, password=self._password
)
stuff = client.get(self._namespace + token)
if stuff is None:
return None
@ -330,14 +343,14 @@ class TokenRedis(BasePlugin):
combo = json.loads(responseStr)
host, port = combo["host"].split(":")
except ValueError:
logger.error("Unable to decode JSON token: %s" %
responseStr)
logger.error("Unable to decode JSON token: %s" % responseStr)
return None
except KeyError:
logger.error("Unable to find 'host' key in JSON token: %s" %
responseStr)
logger.error(
"Unable to find 'host' key in JSON token: %s" % responseStr
)
return None
elif re.match(r'\S+:\S+', responseStr):
elif re.match(r"\S+:\S+", responseStr):
host, port = responseStr.split(":")
else:
logger.error("Unable to parse token: %s" % responseStr)
@ -368,7 +381,7 @@ class UnixDomainSocketDirectory(BasePlugin):
if not stat.S_ISSOCK(os.stat(uds_path).st_mode):
return None
return [ 'unix_socket', uds_path ]
return ["unix_socket", uds_path]
except Exception as e:
logger.error("Error finding unix domain socket: %s" % str(e))
return None

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
'''
"""
Python WebSocket library
Copyright 2011 Joel Martin
Copyright 2016 Pierre Ossman
@ -10,7 +10,7 @@ Supports following protocol versions:
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-07
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10
- http://tools.ietf.org/html/rfc6455
'''
"""
import sys
import array
@ -28,14 +28,19 @@ try:
import numpy
except ImportError:
import warnings
warnings.warn("no 'numpy' module, HyBi protocol will be slower")
numpy = None
class WebSocketWantReadError(ssl.SSLWantReadError):
pass
class WebSocketWantWriteError(ssl.SSLWantWriteError):
pass
class WebSocket:
"""WebSocket protocol socket like class.
@ -73,11 +78,11 @@ class WebSocket:
self._state = "new"
self._partial_msg = b''
self._partial_msg = b""
self._recv_buffer = b''
self._recv_buffer = b""
self._recv_queue = []
self._send_buffer = b''
self._send_buffer = b""
self._previous_sendmsg = None
@ -91,16 +96,22 @@ class WebSocket:
def __getattr__(self, name):
# These methods are just redirected to the underlying socket
if name in ["fileno",
"getpeername", "getsockname",
"getsockopt", "setsockopt",
"gettimeout", "settimeout",
"setblocking"]:
if name in [
"fileno",
"getpeername",
"getsockname",
"getsockopt",
"setsockopt",
"gettimeout",
"settimeout",
"setblocking",
]:
assert self.socket is not None
return getattr(self.socket, name)
else:
raise AttributeError("%s instance has no attribute '%s'" %
(self.__class__.__name__, name))
raise AttributeError(
"%s instance has no attribute '%s'" % (self.__class__.__name__, name)
)
def connect(self, uri, origin=None, protocols=[]):
"""Establishes a new connection to a WebSocket server.
@ -118,8 +129,7 @@ class WebSocket:
connect() must retain the same arguments.
"""
self.client = True;
self.client = True
uri = urlparse(uri)
port = uri.port
@ -140,8 +150,9 @@ class WebSocket:
if uri.scheme in ("wss", "https"):
context = ssl.create_default_context()
self.socket = context.wrap_socket(self.socket,
server_hostname=uri.hostname)
self.socket = context.wrap_socket(
self.socket, server_hostname=uri.hostname
)
self._state = "ssl_handshake"
else:
self._state = "headers"
@ -151,7 +162,7 @@ class WebSocket:
self._state = "headers"
if self._state == "headers":
self._key = ''
self._key = ""
for i in range(16):
self._key += chr(random.randrange(256))
self._key = b64encode(self._key.encode("latin-1")).decode("ascii")
@ -184,10 +195,10 @@ class WebSocket:
if not self._recv():
raise Exception("Socket closed unexpectedly")
if self._recv_buffer.find(b'\r\n\r\n') == -1:
if self._recv_buffer.find(b"\r\n\r\n") == -1:
raise WebSocketWantReadError
(request, self._recv_buffer) = self._recv_buffer.split(b'\r\n', 1)
(request, self._recv_buffer) = self._recv_buffer.split(b"\r\n", 1)
request = request.decode("latin-1")
words = request.split()
@ -196,17 +207,17 @@ class WebSocket:
if words[1] != "101":
raise Exception("WebSocket request denied: %s" % " ".join(words[1:]))
(headers, self._recv_buffer) = self._recv_buffer.split(b'\r\n\r\n', 1)
headers = headers.decode('latin-1') + '\r\n'
(headers, self._recv_buffer) = self._recv_buffer.split(b"\r\n\r\n", 1)
headers = headers.decode("latin-1") + "\r\n"
headers = email.message_from_string(headers)
if headers.get("Upgrade", "").lower() != "websocket":
print(type(headers))
raise Exception("Missing or incorrect upgrade header")
accept = headers.get('Sec-WebSocket-Accept')
accept = headers.get("Sec-WebSocket-Accept")
if accept is None:
raise Exception("Missing Sec-WebSocket-Accept header");
raise Exception("Missing Sec-WebSocket-Accept header")
expected = sha1((self._key + self.GUID).encode("ascii")).digest()
expected = b64encode(expected).decode("ascii")
@ -214,9 +225,9 @@ class WebSocket:
del self._key
if accept != expected:
raise Exception("Invalid Sec-WebSocket-Accept header");
raise Exception("Invalid Sec-WebSocket-Accept header")
self.protocol = headers.get('Sec-WebSocket-Protocol')
self.protocol = headers.get("Sec-WebSocket-Protocol")
if len(protocols) == 0:
if self.protocol is not None:
raise Exception("Unexpected Sec-WebSocket-Protocol header")
@ -256,34 +267,34 @@ class WebSocket:
if headers.get("upgrade", "").lower() != "websocket":
raise Exception("Missing or incorrect upgrade header")
ver = headers.get('Sec-WebSocket-Version')
ver = headers.get("Sec-WebSocket-Version")
if ver is None:
raise Exception("Missing Sec-WebSocket-Version header");
raise Exception("Missing Sec-WebSocket-Version header")
# HyBi-07 report version 7
# HyBi-08 - HyBi-12 report version 8
# HyBi-13 reports version 13
if ver in ['7', '8', '13']:
if ver in ["7", "8", "13"]:
self.version = "hybi-%02d" % int(ver)
else:
raise Exception("Unsupported protocol version %s" % ver)
key = headers.get('Sec-WebSocket-Key')
key = headers.get("Sec-WebSocket-Key")
if key is None:
raise Exception("Missing Sec-WebSocket-Key header");
raise Exception("Missing Sec-WebSocket-Key header")
# Generate the hash value for the accept header
accept = sha1((key + self.GUID).encode("ascii")).digest()
accept = b64encode(accept).decode("ascii")
self.protocol = ''
protocols = headers.get('Sec-WebSocket-Protocol', '').split(',')
self.protocol = ""
protocols = headers.get("Sec-WebSocket-Protocol", "").split(",")
if protocols:
self.protocol = self.select_subprotocol(protocols)
# We are required to choose one of the protocols
# presented by the client
if self.protocol not in protocols:
raise Exception('Invalid protocol selected')
raise Exception("Invalid protocol selected")
self.send_response(101, "Switching Protocols")
self.send_header("Upgrade", "websocket")
@ -461,7 +472,7 @@ class WebSocket:
def send_request(self, type, path):
self._queue_str("%s %s HTTP/1.1\r\n" % (type.upper(), path))
def ping(self, data=b''):
def ping(self, data=b""):
"""Write a ping message to the WebSocket
WebSocketWantWriteError can be raised if there is insufficient
@ -486,7 +497,7 @@ class WebSocket:
self._previous_sendmsg = data
raise
def pong(self, data=b''):
def pong(self, data=b""):
"""Write a pong message to the WebSocket
WebSocketWantWriteError can be raised if there is insufficient
@ -540,7 +551,7 @@ class WebSocket:
self._sent_close = True
msg = b''
msg = b""
if code is not None:
msg += struct.pack(">H", code)
if reason is not None:
@ -602,7 +613,7 @@ class WebSocket:
frame = self._decode_hybi(self._recv_buffer)
if frame is None:
break
self._recv_buffer = self._recv_buffer[frame['length']:]
self._recv_buffer = self._recv_buffer[frame["length"] :]
self._recv_queue.append(frame)
return True
@ -612,29 +623,39 @@ class WebSocket:
while self._recv_queue:
frame = self._recv_queue.pop(0)
if not self.client and not frame['masked']:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked")
if not self.client and not frame["masked"]:
self.shutdown(
socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked"
)
continue
if self.client and frame['masked']:
if self.client and frame["masked"]:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame masked")
continue
if frame["opcode"] == 0x0:
if not self._partial_msg:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected continuation frame")
self.shutdown(
socket.SHUT_RDWR,
1002,
"Procotol error: Unexpected continuation frame",
)
continue
self._partial_msg += frame["payload"]
if frame["fin"]:
msg = self._partial_msg
self._partial_msg = b''
self._partial_msg = b""
return msg
elif frame["opcode"] == 0x1:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported")
self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported"
)
elif frame["opcode"] == 0x2:
if self._partial_msg:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame")
self.shutdown(
socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame"
)
continue
if frame["fin"]:
@ -652,7 +673,9 @@ class WebSocket:
return None
if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close")
self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close"
)
continue
code = None
@ -664,7 +687,11 @@ class WebSocket:
try:
reason = reason.decode("UTF-8")
except UnicodeDecodeError:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Invalid UTF-8 in close")
self.shutdown(
socket.SHUT_RDWR,
1002,
"Procotol error: Invalid UTF-8 in close",
)
continue
if code is None:
@ -679,18 +706,26 @@ class WebSocket:
return None
elif frame["opcode"] == 0x9:
if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping")
self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping"
)
continue
self.handle_ping(frame["payload"])
elif frame["opcode"] == 0xA:
if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong")
self.shutdown(
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong"
)
continue
self.handle_pong(frame["payload"])
else:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Unknown opcode 0x%02x" % frame["opcode"])
self.shutdown(
socket.SHUT_RDWR,
1003,
"Unsupported: Unknown opcode 0x%02x" % frame["opcode"],
)
raise WebSocketWantReadError
@ -731,7 +766,7 @@ class WebSocket:
def _sendmsg(self, opcode, msg):
# Sends a standard data message
if self.client:
mask = b''
mask = b""
for i in range(4):
mask += random.randrange(256).to_bytes()
frame = self._encode_hybi(opcode, msg, mask)
@ -755,35 +790,36 @@ class WebSocket:
plen = len(buf)
pstart = 0
pend = plen
b = c = b''
b = c = b""
if plen >= 4:
dtype=numpy.dtype('<u4')
if sys.byteorder == 'big':
dtype = dtype.newbyteorder('>')
dtype = numpy.dtype("<u4")
if sys.byteorder == "big":
dtype = dtype.newbyteorder(">")
mask = numpy.frombuffer(mask, dtype, count=1)
data = numpy.frombuffer(buf, dtype, count=int(plen / 4))
#b = numpy.bitwise_xor(data, mask).data
# b = numpy.bitwise_xor(data, mask).data
b = numpy.bitwise_xor(data, mask).tobytes()
if plen % 4:
dtype=numpy.dtype('B')
if sys.byteorder == 'big':
dtype = dtype.newbyteorder('>')
dtype = numpy.dtype("B")
if sys.byteorder == "big":
dtype = dtype.newbyteorder(">")
mask = numpy.frombuffer(mask, dtype, count=(plen % 4))
data = numpy.frombuffer(buf, dtype,
offset=plen - (plen % 4), count=(plen % 4))
data = numpy.frombuffer(
buf, dtype, offset=plen - (plen % 4), count=(plen % 4)
)
c = numpy.bitwise_xor(data, mask).tobytes()
return b + c
else:
# Slower fallback
data = array.array('B')
data = array.array("B")
data.frombytes(buf)
for i in range(len(data)):
data[i] ^= mask[i % 4]
return data.tobytes()
def _encode_hybi(self, opcode, buf, mask_key=None, fin=True):
""" Encode a HyBi style WebSocket frame.
"""Encode a HyBi style WebSocket frame.
Optional opcode:
0x0 - continuation
0x1 - text frame
@ -793,7 +829,7 @@ class WebSocket:
0xA - pong
"""
b1 = opcode & 0x0f
b1 = opcode & 0x0F
if fin:
b1 |= 0x80
@ -804,11 +840,11 @@ class WebSocket:
payload_len = len(buf)
if payload_len <= 125:
header = struct.pack('>BB', b1, payload_len | mask_bit)
header = struct.pack(">BB", b1, payload_len | mask_bit)
elif payload_len > 125 and payload_len < 65536:
header = struct.pack('>BBH', b1, 126 | mask_bit, payload_len)
header = struct.pack(">BBH", b1, 126 | mask_bit, payload_len)
elif payload_len >= 65536:
header = struct.pack('>BBQ', b1, 127 | mask_bit, payload_len)
header = struct.pack(">BBQ", b1, 127 | mask_bit, payload_len)
if mask_key is not None:
return header + mask_key + buf
@ -816,7 +852,7 @@ class WebSocket:
return header + buf
def _decode_hybi(self, buf):
""" Decode HyBi style WebSocket packets.
"""Decode HyBi style WebSocket packets.
Returns:
{'fin' : boolean,
'opcode' : number,
@ -825,11 +861,7 @@ class WebSocket:
'payload' : decoded_buffer}
"""
f = {'fin' : 0,
'opcode' : 0,
'masked' : False,
'length' : 0,
'payload' : None}
f = {"fin": 0, "opcode": 0, "masked": False, "length": 0, "payload": None}
blen = len(buf)
hlen = 2
@ -838,39 +870,38 @@ class WebSocket:
return None
b1, b2 = struct.unpack(">BB", buf[:2])
f['opcode'] = b1 & 0x0f
f['fin'] = not not (b1 & 0x80)
f['masked'] = not not (b2 & 0x80)
f["opcode"] = b1 & 0x0F
f["fin"] = not not (b1 & 0x80)
f["masked"] = not not (b2 & 0x80)
if f['masked']:
if f["masked"]:
hlen += 4
if blen < hlen:
return None
length = b2 & 0x7f
length = b2 & 0x7F
if length == 126:
hlen += 2
if blen < hlen:
return None
length, = struct.unpack('>H', buf[2:4])
(length,) = struct.unpack(">H", buf[2:4])
elif length == 127:
hlen += 8
if blen < hlen:
return None
length, = struct.unpack('>Q', buf[2:10])
(length,) = struct.unpack(">Q", buf[2:10])
f['length'] = hlen + length
f["length"] = hlen + length
if blen < f['length']:
if blen < f["length"]:
return None
if f['masked']:
if f["masked"]:
# unmask payload
mask_key = buf[hlen-4:hlen]
f['payload'] = self._unmask(buf[hlen:(hlen+length)], mask_key)
mask_key = buf[hlen - 4 : hlen]
f["payload"] = self._unmask(buf[hlen : (hlen + length)], mask_key)
else:
f['payload'] = buf[hlen:(hlen+length)]
f["payload"] = buf[hlen : (hlen + length)]
return f

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
'''
"""
A WebSocket to TCP socket proxy with support for "wss://" encryption.
Copyright 2011 Joel Martin
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
@ -9,7 +9,7 @@ You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
"""
import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl, stat
from socketserver import ThreadingMixIn
@ -20,8 +20,8 @@ from websockify import websockifyserver
from websockify import auth_plugins as auth
from urllib.parse import parse_qs, urlparse
class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler):
class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler):
buffer_size = 65536
traffic_legend = """
@ -38,7 +38,7 @@ Traffic Legend:
def send_auth_error(self, ex):
self.send_response(ex.code, ex.msg)
self.send_header('Content-Type', 'text/html')
self.send_header("Content-Type", "text/html")
for name, val in ex.headers.items():
self.send_header(name, val)
@ -49,7 +49,7 @@ Traffic Legend:
return
host, port = self.get_target(self.server.token_plugin)
if host == 'unix_socket':
if host == "unix_socket":
self.server.unix_target = port
else:
@ -62,7 +62,7 @@ Traffic Legend:
# clear out any existing SSL_ headers that the client might
# have maliciously set
ssl_headers = [ h for h in self.headers if h.startswith('SSL_') ]
ssl_headers = [h for h in self.headers if h.startswith("SSL_")]
for h in ssl_headers:
del self.headers[h]
@ -70,19 +70,21 @@ Traffic Legend:
# get client certificate data
client_cert_data = self.request.getpeercert()
# extract subject information
client_cert_subject = client_cert_data['subject']
client_cert_subject = client_cert_data["subject"]
# flatten data structure
client_cert_subject = dict([x[0] for x in client_cert_subject])
# add common name to headers (apache +StdEnvVars style)
self.headers['SSL_CLIENT_S_DN_CN'] = client_cert_subject['commonName']
self.headers["SSL_CLIENT_S_DN_CN"] = client_cert_subject["commonName"]
except (TypeError, AttributeError, KeyError):
# not a SSL connection or client presented no certificate with valid data
pass
try:
self.server.auth_plugin.authenticate(
headers=self.headers, target_host=self.server.target_host,
target_port=self.server.target_port)
headers=self.headers,
target_host=self.server.target_host,
target_port=self.server.target_port,
)
except auth.AuthenticationError:
ex = sys.exc_info()[1]
self.send_auth_error(ex)
@ -96,26 +98,37 @@ Traffic Legend:
# Connect to the target
if self.server.wrap_cmd:
msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port)
msg = "connecting to command: '%s' (port %s)" % (
" ".join(self.server.wrap_cmd),
self.server.target_port,
)
elif self.server.unix_target:
msg = "connecting to unix socket: %s" % self.server.unix_target
else:
msg = "connecting to: %s:%s" % (
self.server.target_host, self.server.target_port)
self.server.target_host,
self.server.target_port,
)
if self.server.ssl_target:
msg += " (using SSL)"
self.log_message(msg)
try:
tsock = websockifyserver.WebSockifyServer.socket(self.server.target_host,
tsock = websockifyserver.WebSockifyServer.socket(
self.server.target_host,
self.server.target_port,
connect=True,
use_ssl=self.server.ssl_target,
unix_socket=self.server.unix_target)
unix_socket=self.server.unix_target,
)
except Exception as e:
self.log_message("Failed to connect to %s:%s: %s",
self.server.target_host, self.server.target_port, e)
self.log_message(
"Failed to connect to %s:%s: %s",
self.server.target_host,
self.server.target_port,
e,
)
raise self.CClose(1011, "Failed to connect to downstream server")
# Option unavailable when listening to unix socket
@ -134,8 +147,11 @@ Traffic Legend:
tsock.shutdown(socket.SHUT_RDWR)
tsock.close()
if self.verbose:
self.log_message("%s:%s: Closed target",
self.server.target_host, self.server.target_port)
self.log_message(
"%s:%s: Closed target",
self.server.target_host,
self.server.target_port,
)
def get_target(self, target_plugin):
"""
@ -149,20 +165,20 @@ Traffic Legend:
if self.host_token:
# Use hostname as token
token = self.headers.get('Host')
token = self.headers.get("Host")
# Remove port from hostname, as it'll always be the one where
# websockify listens (unless something between the client and
# websockify is redirecting traffic, but that's beside the point)
if token:
token = token.partition(':')[0]
token = token.partition(":")[0]
else:
# Extract the token parameter from url
args = parse_qs(urlparse(self.path)[4]) # 4 is the query from url
if 'token' in args and len(args['token']):
token = args['token'][0].rstrip('\n')
if "token" in args and len(args["token"]):
token = args["token"][0].rstrip("\n")
else:
token = None
@ -200,13 +216,15 @@ Traffic Legend:
self.heartbeat = now + self.server.heartbeat
self.send_ping()
if tqueue: wlist.append(target)
if cqueue or c_pend: wlist.append(self.request)
if tqueue:
wlist.append(target)
if cqueue or c_pend:
wlist.append(self.request)
try:
ins, outs, excepts = select.select(rlist, wlist, [], 1)
except OSError:
exc = sys.exc_info()[1]
if hasattr(exc, 'errno'):
if hasattr(exc, "errno"):
err = exc.errno
else:
err = exc[0]
@ -216,7 +234,8 @@ Traffic Legend:
else:
continue
if excepts: raise Exception("Socket exception")
if excepts:
raise Exception("Socket exception")
if self.request in outs:
# Send queued target data to the client
@ -230,8 +249,7 @@ Traffic Legend:
tqueue.extend(bufs)
if closed:
while (len(tqueue) != 0):
while len(tqueue) != 0:
# Send queued client data to the target
dat = tqueue.pop(0)
sent = target.send(dat)
@ -244,10 +262,12 @@ Traffic Legend:
# TODO: What about blocking on client socket?
if self.verbose:
self.log_message("%s:%s: Client closed connection",
self.server.target_host, self.server.target_port)
raise self.CClose(closed['code'], closed['reason'])
self.log_message(
"%s:%s: Client closed connection",
self.server.target_host,
self.server.target_port,
)
raise self.CClose(closed["code"], closed["reason"])
if target in outs:
# Send queued client data to the target
@ -260,29 +280,31 @@ Traffic Legend:
tqueue.insert(0, dat[sent:])
self.print_traffic(".>")
if target in ins:
# Receive target data, encode it and queue for client
buf = target.recv(self.buffer_size)
if len(buf) == 0:
# Target socket closed, flushing queues and closing client-side websocket
# Send queued target data to the client
if len(cqueue) != 0:
c_pend = True
while(c_pend):
while c_pend:
c_pend = self.send_frames(cqueue)
cqueue = []
if self.verbose:
self.log_message("%s:%s: Target closed connection",
self.server.target_host, self.server.target_port)
self.log_message(
"%s:%s: Target closed connection",
self.server.target_host,
self.server.target_port,
)
raise self.CClose(1000, "Target closed")
cqueue.append(buf)
self.print_traffic("{")
class WebSocketProxy(websockifyserver.WebSockifyServer):
"""
Proxy traffic to and from a WebSockets client to a normal TCP
@ -293,27 +315,29 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs):
# Save off proxy specific options
self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None)
self.heartbeat = kwargs.pop('heartbeat', None)
self.target_host = kwargs.pop("target_host", None)
self.target_port = kwargs.pop("target_port", None)
self.wrap_cmd = kwargs.pop("wrap_cmd", None)
self.wrap_mode = kwargs.pop("wrap_mode", None)
self.unix_target = kwargs.pop("unix_target", None)
self.ssl_target = kwargs.pop("ssl_target", None)
self.heartbeat = kwargs.pop("heartbeat", None)
self.token_plugin = kwargs.pop('token_plugin', None)
self.host_token = kwargs.pop('host_token', None)
self.auth_plugin = kwargs.pop('auth_plugin', None)
self.token_plugin = kwargs.pop("token_plugin", None)
self.host_token = kwargs.pop("host_token", None)
self.auth_plugin = kwargs.pop("auth_plugin", None)
# Last 3 timestamps command was run
self.wrap_times = [0, 0, 0]
if self.wrap_cmd:
wsdir = os.path.dirname(sys.argv[0])
rebinder_path = [os.path.join(wsdir, "..", "lib"),
rebinder_path = [
os.path.join(wsdir, "..", "lib"),
os.path.join(wsdir, "..", "lib", "websockify"),
os.path.join(wsdir, ".."),
wsdir]
wsdir,
]
self.rebinder = None
for rdir in rebinder_path:
@ -329,17 +353,22 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
self.target_host = "127.0.0.1" # Loopback
# Find a free high port
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(('', 0))
sock.bind(("", 0))
self.target_port = sock.getsockname()[1]
sock.close()
# Insert rebinder at the head of the (possibly empty) LD_PRELOAD pathlist
ld_preloads = filter(None, [ self.rebinder, os.environ.get("LD_PRELOAD", None) ])
ld_preloads = filter(
None, [self.rebinder, os.environ.get("LD_PRELOAD", None)]
)
os.environ.update({
os.environ.update(
{
"LD_PRELOAD": os.pathsep.join(ld_preloads),
"REBIND_OLD_PORT": str(kwargs['listen_port']),
"REBIND_NEW_PORT": str(self.target_port)})
"REBIND_OLD_PORT": str(kwargs["listen_port"]),
"REBIND_NEW_PORT": str(self.target_port),
}
)
super().__init__(RequestHandlerClass, *args, **kwargs)
@ -348,7 +377,8 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
self.wrap_times.append(time.time())
self.wrap_times.pop(0)
self.cmd = subprocess.Popen(
self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup)
self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup
)
self.spawn_message = True
def started(self):
@ -371,10 +401,11 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
if self.token_plugin:
msg = " - proxying from %s to targets generated by %s" % (
src_string, type(self.token_plugin).__name__)
src_string,
type(self.token_plugin).__name__,
)
else:
msg = " - proxying from %s to %s" % (
src_string, dst_string)
msg = " - proxying from %s to %s" % (src_string, dst_string)
if self.ssl_target:
msg += " (using SSL)"
@ -401,7 +432,7 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
sys.exit(ret)
elif self.wrap_mode == "respawn":
now = time.time()
avg = sum(self.wrap_times)/len(self.wrap_times)
avg = sum(self.wrap_times) / len(self.wrap_times)
if (now - avg) < 10:
# 3 times in the last 10 seconds
if self.spawn_message:
@ -418,15 +449,25 @@ def _subprocess_setup():
SSL_OPTIONS = {
'default': ssl.OP_ALL,
'tlsv1_1': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
ssl.OP_NO_TLSv1,
'tlsv1_2': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1,
'tlsv1_3': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2,
"default": ssl.OP_ALL,
"tlsv1_1": ssl.PROTOCOL_SSLv23
| ssl.OP_NO_SSLv2
| ssl.OP_NO_SSLv3
| ssl.OP_NO_TLSv1,
"tlsv1_2": ssl.PROTOCOL_SSLv23
| ssl.OP_NO_SSLv2
| ssl.OP_NO_SSLv3
| ssl.OP_NO_TLSv1
| ssl.OP_NO_TLSv1_1,
"tlsv1_3": ssl.PROTOCOL_SSLv23
| ssl.OP_NO_SSLv2
| ssl.OP_NO_SSLv3
| ssl.OP_NO_TLSv1
| ssl.OP_NO_TLSv1_1
| ssl.OP_NO_TLSv1_2,
}
def select_ssl_version(version):
"""Returns SSL options for the most secure TSL version available on this
Python version"""
@ -439,11 +480,11 @@ def select_ssl_version(version):
keys.sort()
fallback = keys[-1]
logger = logging.getLogger(WebSocketProxy.log_prefix)
logger.warn("TLS version %s unsupported. Falling back to %s",
version, fallback)
logger.warn("TLS version %s unsupported. Falling back to %s", version, fallback)
return SSL_OPTIONS[fallback]
def websockify_init():
# Setup basic logging to stderr.
stderr_handler = logging.StreamHandler()
@ -464,105 +505,198 @@ def websockify_init():
usage += "\n %prog [options]"
usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE"
parser = optparse.OptionParser(usage=usage)
parser.add_option("--verbose", "-v", action="store_true",
help="verbose messages")
parser.add_option("--traffic", action="store_true",
help="per frame traffic")
parser.add_option("--record",
help="record sessions to FILE.[session_number]", metavar="FILE")
parser.add_option("--daemon", "-D",
dest="daemon", action="store_true",
help="become a daemon (background process)")
parser.add_option("--run-once", action="store_true",
help="handle a single WebSocket connection and exit")
parser.add_option("--timeout", type=int, default=0,
help="after TIMEOUT seconds exit when not connected")
parser.add_option("--idle-timeout", type=int, default=0,
help="server exits after TIMEOUT seconds if there are no "
"active connections")
parser.add_option("--cert", default="self.pem",
help="SSL certificate file")
parser.add_option("--key", default=None,
help="SSL key file (if separate from cert)")
parser.add_option("--key-password", default=None,
help="SSL key password")
parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted client connections")
parser.add_option("--ssl-target", action="store_true",
help="connect to SSL target as SSL client")
parser.add_option("--verify-client", action="store_true",
parser.add_option("--verbose", "-v", action="store_true", help="verbose messages")
parser.add_option("--traffic", action="store_true", help="per frame traffic")
parser.add_option(
"--record", help="record sessions to FILE.[session_number]", metavar="FILE"
)
parser.add_option(
"--daemon",
"-D",
dest="daemon",
action="store_true",
help="become a daemon (background process)",
)
parser.add_option(
"--run-once",
action="store_true",
help="handle a single WebSocket connection and exit",
)
parser.add_option(
"--timeout",
type=int,
default=0,
help="after TIMEOUT seconds exit when not connected",
)
parser.add_option(
"--idle-timeout",
type=int,
default=0,
help="server exits after TIMEOUT seconds if there are no active connections",
)
parser.add_option("--cert", default="self.pem", help="SSL certificate file")
parser.add_option(
"--key", default=None, help="SSL key file (if separate from cert)"
)
parser.add_option("--key-password", default=None, help="SSL key password")
parser.add_option(
"--ssl-only",
action="store_true",
help="disallow non-encrypted client connections",
)
parser.add_option(
"--ssl-target", action="store_true", help="connect to SSL target as SSL client"
)
parser.add_option(
"--verify-client",
action="store_true",
help="require encrypted client to present a valid certificate "
"(needs Python 2.7.9 or newer or Python 3.4 or newer)")
parser.add_option("--cafile", metavar="FILE",
"(needs Python 2.7.9 or newer or Python 3.4 or newer)",
)
parser.add_option(
"--cafile",
metavar="FILE",
help="file of concatenated certificates of authorities trusted "
"for validating clients (only effective with --verify-client). "
"If omitted, system default list of CAs is used.")
parser.add_option("--ssl-version", type="choice", default="default",
choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"], action="store",
help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)")
parser.add_option("--ssl-ciphers", action="store",
"If omitted, system default list of CAs is used.",
)
parser.add_option(
"--ssl-version",
type="choice",
default="default",
choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"],
action="store",
help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)",
)
parser.add_option(
"--ssl-ciphers",
action="store",
help="list of ciphers allowed for connection. For a list of "
"supported ciphers run `openssl ciphers`")
parser.add_option("--unix-listen",
help="listen to unix socket", metavar="FILE", default=None)
parser.add_option("--unix-listen-mode", default=None,
help="specify mode for unix socket (defaults to 0600)")
parser.add_option("--unix-target",
help="connect to unix socket target", metavar="FILE")
parser.add_option("--inetd",
help="inetd mode, receive listening socket from stdin", action="store_true")
parser.add_option("--web", default=None, metavar="DIR",
help="run webserver on same port. Serve files from DIR.")
parser.add_option("--web-auth", action="store_true",
help="require authentication to access webserver.")
parser.add_option("--wrap-mode", default="exit", metavar="MODE",
"supported ciphers run `openssl ciphers`",
)
parser.add_option(
"--unix-listen", help="listen to unix socket", metavar="FILE", default=None
)
parser.add_option(
"--unix-listen-mode",
default=None,
help="specify mode for unix socket (defaults to 0600)",
)
parser.add_option(
"--unix-target", help="connect to unix socket target", metavar="FILE"
)
parser.add_option(
"--inetd",
help="inetd mode, receive listening socket from stdin",
action="store_true",
)
parser.add_option(
"--web",
default=None,
metavar="DIR",
help="run webserver on same port. Serve files from DIR.",
)
parser.add_option(
"--web-auth",
action="store_true",
help="require authentication to access webserver.",
)
parser.add_option(
"--wrap-mode",
default="exit",
metavar="MODE",
choices=["exit", "ignore", "respawn"],
help="action to take when the wrapped program exits "
"or daemonizes: exit (default), ignore, respawn")
parser.add_option("--prefer-ipv6", "-6",
action="store_true", dest="source_is_ipv6",
help="prefer IPv6 when resolving source_addr")
parser.add_option("--libserver", action="store_true",
help="use Python library SocketServer engine")
parser.add_option("--target-config", metavar="FILE",
"or daemonizes: exit (default), ignore, respawn",
)
parser.add_option(
"--prefer-ipv6",
"-6",
action="store_true",
dest="source_is_ipv6",
help="prefer IPv6 when resolving source_addr",
)
parser.add_option(
"--libserver",
action="store_true",
help="use Python library SocketServer engine",
)
parser.add_option(
"--target-config",
metavar="FILE",
dest="target_cfg",
help="Configuration file containing valid targets "
"in the form 'token: host:port' or, alternatively, a "
"directory containing configuration files of this form "
"(DEPRECATED: use `--token-plugin TokenFile --token-source "
" path/to/token/file` instead)")
parser.add_option("--token-plugin", default=None, metavar="CLASS",
" path/to/token/file` instead)",
)
parser.add_option(
"--token-plugin",
default=None,
metavar="CLASS",
help="use a Python class, usually one from websockify.token_plugins, "
"such as TokenFile, to process tokens into host:port pairs")
parser.add_option("--token-source", default=None, metavar="ARG",
help="an argument to be passed to the token plugin "
"on instantiation")
parser.add_option("--host-token", action="store_true",
"such as TokenFile, to process tokens into host:port pairs",
)
parser.add_option(
"--token-source",
default=None,
metavar="ARG",
help="an argument to be passed to the token plugin on instantiation",
)
parser.add_option(
"--host-token",
action="store_true",
help="use the host HTTP header as token instead of the "
"token URL query parameter")
parser.add_option("--auth-plugin", default=None, metavar="CLASS",
"token URL query parameter",
)
parser.add_option(
"--auth-plugin",
default=None,
metavar="CLASS",
help="use a Python class, usually one from websockify.auth_plugins, "
"such as BasicHTTPAuth, to determine if a connection is allowed")
parser.add_option("--auth-source", default=None, metavar="ARG",
help="an argument to be passed to the auth plugin "
"on instantiation")
parser.add_option("--heartbeat", type=int, default=0, metavar="INTERVAL",
help="send a ping to the client every INTERVAL seconds")
parser.add_option("--log-file", metavar="FILE",
"such as BasicHTTPAuth, to determine if a connection is allowed",
)
parser.add_option(
"--auth-source",
default=None,
metavar="ARG",
help="an argument to be passed to the auth plugin on instantiation",
)
parser.add_option(
"--heartbeat",
type=int,
default=0,
metavar="INTERVAL",
help="send a ping to the client every INTERVAL seconds",
)
parser.add_option(
"--log-file",
metavar="FILE",
dest="log_file",
help="File where logs will be saved")
parser.add_option("--syslog", default=None, metavar="SERVER",
help="File where logs will be saved",
)
parser.add_option(
"--syslog",
default=None,
metavar="SERVER",
help="Log to syslog server. SERVER can be local socket, "
"such as /dev/log, or a UDP host:port pair.")
parser.add_option("--legacy-syslog", action="store_true",
"such as /dev/log, or a UDP host:port pair.",
)
parser.add_option(
"--legacy-syslog",
action="store_true",
help="Use the old syslog protocol instead of RFC 5424. "
"Use this if the messages produced by websockify seem abnormal.")
parser.add_option("--file-only", action="store_true",
help="use this to disable directory listings in web server.")
"Use this if the messages produced by websockify seem abnormal.",
)
parser.add_option(
"--file-only",
action="store_true",
help="use this to disable directory listings in web server.",
)
(opts, args) = parser.parse_args()
# Validate options.
if opts.token_source and not opts.token_plugin:
@ -583,11 +717,9 @@ def websockify_init():
if opts.legacy_syslog and not opts.syslog:
parser.error("You must use --syslog to use --legacy-syslog")
opts.ssl_options = select_ssl_version(opts.ssl_version)
del opts.ssl_version
if opts.log_file:
# Setup logging to user-specified file.
opts.log_file = os.path.abspath(opts.log_file)
@ -601,9 +733,9 @@ def websockify_init():
if opts.syslog:
# Determine how to connect to syslog...
if opts.syslog.count(':'):
if opts.syslog.count(":"):
# User supplied a host:port pair.
syslog_host, syslog_port = opts.syslog.rsplit(':', 1)
syslog_host, syslog_port = opts.syslog.rsplit(":", 1)
try:
syslog_port = int(syslog_port)
except ValueError:
@ -622,10 +754,12 @@ def websockify_init():
syslog_facility = WebsockifySysLogHandler.LOG_USER
# Start logging to syslog.
syslog_handler = WebsockifySysLogHandler(address=syslog_dest,
syslog_handler = WebsockifySysLogHandler(
address=syslog_dest,
facility=syslog_facility,
ident='websockify',
legacy=opts.legacy_syslog)
ident="websockify",
legacy=opts.legacy_syslog,
)
syslog_handler.setLevel(logging.DEBUG)
syslog_handler.setFormatter(log_formatter)
root = logging.getLogger()
@ -638,24 +772,23 @@ def websockify_init():
root = logging.getLogger()
root.setLevel(logging.DEBUG)
# Transform to absolute path as daemon may chdir
if opts.target_cfg:
opts.target_cfg = os.path.abspath(opts.target_cfg)
if opts.target_cfg:
opts.token_plugin = 'TokenFile'
opts.token_plugin = "TokenFile"
opts.token_source = opts.target_cfg
del opts.target_cfg
if sys.argv.count('--'):
if sys.argv.count("--"):
opts.wrap_cmd = args[1:]
else:
opts.wrap_cmd = None
if not websockifyserver.ssl and opts.ssl_target:
parser.error("SSL target requested and Python SSL module not loaded.");
parser.error("SSL target requested and Python SSL module not loaded.")
if opts.ssl_only and not os.path.exists(opts.cert):
parser.error("SSL only and %s not found" % opts.cert)
@ -677,11 +810,11 @@ def websockify_init():
parser.error("Too few arguments")
arg = args.pop(0)
# Parse host:port and convert ports to numbers
if arg.count(':') > 0:
opts.listen_host, opts.listen_port = arg.rsplit(':', 1)
opts.listen_host = opts.listen_host.strip('[]')
if arg.count(":") > 0:
opts.listen_host, opts.listen_port = arg.rsplit(":", 1)
opts.listen_host = opts.listen_host.strip("[]")
else:
opts.listen_host, opts.listen_port = '', arg
opts.listen_host, opts.listen_port = "", arg
try:
opts.listen_port = int(opts.listen_port)
@ -697,9 +830,9 @@ def websockify_init():
if len(args) < 1:
parser.error("Too few arguments")
arg = args.pop(0)
if arg.count(':') > 0:
opts.target_host, opts.target_port = arg.rsplit(':', 1)
opts.target_host = opts.target_host.strip('[]')
if arg.count(":") > 0:
opts.target_host, opts.target_port = arg.rsplit(":", 1)
opts.target_host = opts.target_host.strip("[]")
else:
parser.error("Error parsing target")
@ -712,11 +845,10 @@ def websockify_init():
parser.error("Too many arguments")
if opts.token_plugin is not None:
if '.' not in opts.token_plugin:
opts.token_plugin = (
'websockify.token_plugins.%s' % opts.token_plugin)
if "." not in opts.token_plugin:
opts.token_plugin = "websockify.token_plugins.%s" % opts.token_plugin
token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit('.', 1)
token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit(".", 1)
__import__(token_plugin_module)
token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls)
@ -726,10 +858,10 @@ def websockify_init():
del opts.token_source
if opts.auth_plugin is not None:
if '.' not in opts.auth_plugin:
opts.auth_plugin = 'websockify.auth_plugins.%s' % opts.auth_plugin
if "." not in opts.auth_plugin:
opts.auth_plugin = "websockify.auth_plugins.%s" % opts.auth_plugin
auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit('.', 1)
auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit(".", 1)
__import__(auth_plugin_module)
auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls)
@ -759,32 +891,32 @@ class LibProxyServer(ThreadingMixIn, HTTPServer):
def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs):
# Save off proxy specific options
self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None)
self.token_plugin = kwargs.pop('token_plugin', None)
self.auth_plugin = kwargs.pop('auth_plugin', None)
self.heartbeat = kwargs.pop('heartbeat', None)
self.target_host = kwargs.pop("target_host", None)
self.target_port = kwargs.pop("target_port", None)
self.wrap_cmd = kwargs.pop("wrap_cmd", None)
self.wrap_mode = kwargs.pop("wrap_mode", None)
self.unix_target = kwargs.pop("unix_target", None)
self.ssl_target = kwargs.pop("ssl_target", None)
self.token_plugin = kwargs.pop("token_plugin", None)
self.auth_plugin = kwargs.pop("auth_plugin", None)
self.heartbeat = kwargs.pop("heartbeat", None)
self.token_plugin = None
self.auth_plugin = None
self.daemon = False
# Server configuration
listen_host = kwargs.pop('listen_host', '')
listen_port = kwargs.pop('listen_port', None)
web = kwargs.pop('web', '')
listen_host = kwargs.pop("listen_host", "")
listen_port = kwargs.pop("listen_port", None)
web = kwargs.pop("web", "")
# Configuration affecting base request handler
self.only_upgrade = not web
self.verbose = kwargs.pop('verbose', False)
record = kwargs.pop('record', '')
self.verbose = kwargs.pop("verbose", False)
record = kwargs.pop("record", "")
if record:
self.record = os.path.abspath(record)
self.run_once = kwargs.pop('run_once', False)
self.run_once = kwargs.pop("run_once", False)
self.handler_id = 0
for arg in kwargs.keys():
@ -795,12 +927,11 @@ class LibProxyServer(ThreadingMixIn, HTTPServer):
super().__init__((listen_host, listen_port), RequestHandlerClass)
def process_request(self, request, client_address):
"""Override process_request to implement a counter"""
self.handler_id += 1
super().process_request(request, client_address)
if __name__ == '__main__':
if __name__ == "__main__":
websockify_init()

View File

@ -1,19 +1,25 @@
#!/usr/bin/env python
'''
"""
Python WebSocket server base
Copyright 2011 Joel Martin
Copyright 2016-2018 Pierre Ossman
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
'''
"""
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer
from websockify.websocket import WebSocket, WebSocketWantReadError, WebSocketWantWriteError
from websockify.websocket import (
WebSocket,
WebSocketWantReadError,
WebSocketWantWriteError,
)
class HttpWebSocket(WebSocket):
"""Class to glue websocket and http request functionality together"""
def __init__(self, request_handler):
super().__init__()
@ -62,8 +68,10 @@ class WebSocketRequestHandlerMixIn:
# Checks if it is a websocket request and redirects
self.do_GET = self._real_do_GET
if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):
if (
self.headers.get("upgrade")
and self.headers.get("upgrade").lower() == "websocket"
):
self.handle_upgrade()
else:
self.do_GET()
@ -100,11 +108,13 @@ class WebSocketRequestHandlerMixIn:
"""
pass
# Convenient ready made classes
class WebSocketRequestHandler(WebSocketRequestHandlerMixIn,
BaseHTTPRequestHandler):
class WebSocketRequestHandler(WebSocketRequestHandlerMixIn, BaseHTTPRequestHandler):
pass
class WebSocketServer(HTTPServer):
pass

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
'''
"""
Python WebSocket server base with support for "wss://" encryption.
Copyright 2011 Joel Martin
Copyright 2016 Pierre Ossman
@ -10,35 +10,39 @@ You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
"""
import os, sys, time, errno, signal, socket, select, logging
import multiprocessing
from http.server import SimpleHTTPRequestHandler
# Degraded functionality if these imports are missing
for mod, msg in [('ssl', 'TLS/SSL/wss is disabled'),
('resource', 'daemonizing is disabled')]:
for mod, msg in [
("ssl", "TLS/SSL/wss is disabled"),
("resource", "daemonizing is disabled"),
]:
try:
globals()[mod] = __import__(mod)
except ImportError:
globals()[mod] = None
print("WARNING: no '%s' module, %s" % (mod, msg))
if sys.platform == 'win32':
if sys.platform == "win32":
# make sockets pickle-able/inheritable
import multiprocessing.reduction
from websockify.websocket import WebSocketWantReadError, WebSocketWantWriteError
from websockify.websocketserver import WebSocketRequestHandlerMixIn
class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass):
def select_subprotocol(self, protocols):
# Handle old websockify clients that still specify a sub-protocol
if 'binary' in protocols:
return 'binary'
if "binary" in protocols:
return "binary"
else:
return ''
return ""
# HTTP handler with WebSocket upgrade support
class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHandler):
@ -56,6 +60,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
* run_once: Handle a single request
* handler_id: A sequence number for this connection, appended to record filename
"""
server_version = "WebSockify"
protocol_version = "HTTP/1.1"
@ -87,30 +92,33 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
super().__init__(req, addr, server)
def log_message(self, format, *args):
self.logger.info("%s - - [%s] %s" % (self.client_address[0], self.log_date_time_string(), format % args))
self.logger.info(
"%s - - [%s] %s"
% (self.client_address[0], self.log_date_time_string(), format % args)
)
#
# WebSocketRequestHandler logging/output functions
#
def print_traffic(self, token="."):
""" Show traffic flow mode. """
"""Show traffic flow mode."""
if self.traffic:
sys.stdout.write(token)
sys.stdout.flush()
def msg(self, msg, *args, **kwargs):
""" Output message with handler_id prefix. """
"""Output message with handler_id prefix."""
prefix = "% 3d: " % self.handler_id
self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs)
def vmsg(self, msg, *args, **kwargs):
""" Same as msg() but as debug. """
"""Same as msg() but as debug."""
prefix = "% 3d: " % self.handler_id
self.logger.log(logging.DEBUG, "%s%s" % (prefix, msg), *args, **kwargs)
def warn(self, msg, *args, **kwargs):
""" Same as msg() but as warning. """
"""Same as msg() but as warning."""
prefix = "% 3d: " % self.handler_id
self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs)
@ -118,19 +126,24 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
# Main WebSocketRequestHandler methods
#
def send_frames(self, bufs=None):
""" Encode and send WebSocket frames. Any frames already
"""Encode and send WebSocket frames. Any frames already
queued will be sent first. If buf is not set then only queued
frames will be sent. Returns True if any frames could not be
fully sent, in which case the caller should call again when
the socket is ready. """
the socket is ready."""
tdelta = int(time.time()*1000) - self.start_time
tdelta = int(time.time() * 1000) - self.start_time
if bufs:
for buf in bufs:
if self.rec:
# Python 3 compatible conversion
bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'")
bufstr = (
buf.decode("latin1")
.encode("unicode_escape")
.decode("ascii")
.replace("'", "\\'")
)
self.rec.write("'{{{0}{{{1}',\n".format(tdelta, bufstr))
self.send_parts.append(buf)
@ -147,7 +160,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
return False
def recv_frames(self):
""" Receive and decode WebSocket frames.
"""Receive and decode WebSocket frames.
Returns:
(bufs_list, closed_string)
@ -155,7 +168,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
closed = False
bufs = []
tdelta = int(time.time()*1000) - self.start_time
tdelta = int(time.time() * 1000) - self.start_time
while True:
try:
@ -165,15 +178,22 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
break
if buf is None:
closed = {'code': self.request.close_code,
'reason': self.request.close_reason}
closed = {
"code": self.request.close_code,
"reason": self.request.close_reason,
}
return bufs, closed
self.print_traffic("}")
if self.rec:
# Python 3 compatible conversion
bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'")
bufstr = (
buf.decode("latin1")
.encode("unicode_escape")
.decode("ascii")
.replace("'", "\\'")
)
self.rec.write("'}}{0}}}{1}',\n".format(tdelta, bufstr))
bufs.append(buf)
@ -183,16 +203,16 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
return bufs, closed
def send_close(self, code=1000, reason=''):
""" Send a WebSocket orderly close frame. """
def send_close(self, code=1000, reason=""):
"""Send a WebSocket orderly close frame."""
self.request.shutdown(socket.SHUT_RDWR, code, reason)
def send_pong(self, data=b''):
""" Send a WebSocket pong frame. """
def send_pong(self, data=b""):
"""Send a WebSocket pong frame."""
self.request.pong(data)
def send_ping(self, data=b''):
""" Send a WebSocket ping frame. """
def send_ping(self, data=b""):
"""Send a WebSocket ping frame."""
self.request.ping(data)
def handle_upgrade(self):
@ -208,7 +228,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
# Initialize per client settings
self.send_parts = []
self.recv_part = None
self.start_time = int(time.time()*1000)
self.start_time = int(time.time() * 1000)
# client_address is empty with, say, UNIX domain sockets
client_addr = ""
@ -224,17 +244,15 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
else:
self.stype = "Plain non-SSL (ws://)"
self.log_message("%s: %s WebSocket connection", client_addr,
self.stype)
if self.path != '/':
self.log_message("%s: %s WebSocket connection", client_addr, self.stype)
if self.path != "/":
self.log_message("%s: Path: '%s'", client_addr, self.path)
if self.record:
# Record raw frame data as JavaScript array
fname = "%s.%s" % (self.record,
self.handler_id)
fname = "%s.%s" % (self.record, self.handler_id)
self.log_message("opening record file: %s", fname)
self.rec = open(fname, 'w+')
self.rec = open(fname, "w+")
self.rec.write("var VNC_frame_data = [\n")
try:
@ -261,15 +279,17 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
return super().list_directory(path)
def new_websocket_client(self):
""" Do something with a WebSockets client connection. """
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
"""Do something with a WebSockets client connection."""
raise Exception(
"WebSocketRequestHandler.new_websocket_client() must be overloaded"
)
def validate_connection(self):
""" Ensure that the connection has a valid token, and set the target. """
"""Ensure that the connection has a valid token, and set the target."""
pass
def auth_connection(self):
""" Ensure that the connection is authorized. """
"""Ensure that the connection is authorized."""
pass
def do_HEAD(self):
@ -296,12 +316,12 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
else:
super().handle()
def log_request(self, code='-', size='-'):
def log_request(self, code="-", size="-"):
if self.verbose:
super().log_request(code, size)
class WebSockifyServer():
class WebSockifyServer:
"""
WebSockets server class.
As an alternative, the standard library SocketServer can be used
@ -317,17 +337,38 @@ class WebSockifyServer():
class Terminate(Exception):
pass
def __init__(self, RequestHandlerClass, listen_fd=None,
listen_host='', listen_port=None, source_is_ipv6=False,
verbose=False, cert='', key='', key_password=None, ssl_only=None,
verify_client=False, cafile=None,
daemon=False, record='', web='', web_auth=False,
def __init__(
self,
RequestHandlerClass,
listen_fd=None,
listen_host="",
listen_port=None,
source_is_ipv6=False,
verbose=False,
cert="",
key="",
key_password=None,
ssl_only=None,
verify_client=False,
cafile=None,
daemon=False,
record="",
web="",
web_auth=False,
file_only=False,
run_once=False, timeout=0, idle_timeout=0, traffic=False,
tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None,
tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0,
unix_listen=None, unix_listen_mode=None):
run_once=False,
timeout=0,
idle_timeout=0,
traffic=False,
tcp_keepalive=True,
tcp_keepcnt=None,
tcp_keepidle=None,
tcp_keepintvl=None,
ssl_ciphers=None,
ssl_options=0,
unix_listen=None,
unix_listen_mode=None,
):
# settings
self.RequestHandlerClass = RequestHandlerClass
self.verbose = verbose
@ -366,7 +407,7 @@ class WebSockifyServer():
# Make paths settings absolute
self.cert = os.path.abspath(cert)
self.web = self.record = self.cafile = ''
self.web = self.record = self.cafile = ""
if key:
self.key = os.path.abspath(key)
if web:
@ -393,11 +434,12 @@ class WebSockifyServer():
elif self.unix_listen != None:
self.msg(" - Listen on unix socket %s", self.unix_listen)
else:
self.msg(" - Listen on %s:%s",
self.listen_host, self.listen_port)
self.msg(" - Listen on %s:%s", self.listen_host, self.listen_port)
if self.web:
if self.file_only:
self.msg(" - Web server (no directory listings). Web root: %s", self.web)
self.msg(
" - Web server (no directory listings). Web root: %s", self.web
)
else:
self.msg(" - Web server. Web root: %s", self.web)
if ssl:
@ -420,34 +462,45 @@ class WebSockifyServer():
@staticmethod
def get_logger():
return logging.getLogger("%s.%s" % (
WebSockifyServer.log_prefix,
WebSockifyServer.__class__.__name__))
return logging.getLogger(
"%s.%s" % (WebSockifyServer.log_prefix, WebSockifyServer.__class__.__name__)
)
@staticmethod
def socket(host, port=None, connect=False, prefer_ipv6=False,
unix_socket=None, unix_socket_mode=None, unix_socket_listen=False,
use_ssl=False, tcp_keepalive=True, tcp_keepcnt=None,
tcp_keepidle=None, tcp_keepintvl=None):
""" Resolve a host (and optional port) to an IPv4 or IPv6
def socket(
host,
port=None,
connect=False,
prefer_ipv6=False,
unix_socket=None,
unix_socket_mode=None,
unix_socket_listen=False,
use_ssl=False,
tcp_keepalive=True,
tcp_keepcnt=None,
tcp_keepidle=None,
tcp_keepintvl=None,
):
"""Resolve a host (and optional port) to an IPv4 or IPv6
address. Create a socket. Bind to it if listen is set,
otherwise connect to it. Return the socket.
"""
flags = 0
if host == '':
if host == "":
host = None
if connect and not (port or unix_socket):
raise Exception("Connect mode requires a port")
if use_ssl and not ssl:
raise Exception("SSL socket requested but Python SSL module not loaded.");
raise Exception("SSL socket requested but Python SSL module not loaded.")
if not connect and use_ssl:
raise Exception("SSL only supported in connect mode (for now)")
if not connect:
flags = flags | socket.AI_PASSIVE
if not unix_socket:
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM,
socket.IPPROTO_TCP, flags)
addrs = socket.getaddrinfo(
host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, flags
)
if not addrs:
raise Exception("Could not resolve host '%s'" % host)
addrs.sort(key=lambda x: x[0])
@ -458,14 +511,11 @@ class WebSockifyServer():
if tcp_keepalive:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if tcp_keepcnt:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT,
tcp_keepcnt)
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, tcp_keepcnt)
if tcp_keepidle:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE,
tcp_keepidle)
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, tcp_keepidle)
if tcp_keepintvl:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL,
tcp_keepintvl)
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, tcp_keepintvl)
if connect:
sock.connect(addrs[0][4])
@ -497,8 +547,7 @@ class WebSockifyServer():
return sock
@staticmethod
def daemonize(keepfd=None, chdir='/'):
def daemonize(keepfd=None, chdir="/"):
if keepfd is None:
keepfd = []
@ -506,14 +555,16 @@ class WebSockifyServer():
if chdir:
os.chdir(chdir)
else:
os.chdir('/')
os.chdir("/")
os.setgid(os.getgid()) # relinquish elevations
os.setuid(os.getuid()) # relinquish elevations
# Double fork to daemonize
if os.fork() > 0: os._exit(0) # Parent exits
if os.fork() > 0:
os._exit(0) # Parent exits
os.setsid() # Obtain new process group
if os.fork() > 0: os._exit(0) # Parent exits
if os.fork() > 0:
os._exit(0) # Parent exits
# Signal handling
signal.signal(signal.SIGTERM, signal.SIG_IGN)
@ -521,14 +572,16 @@ class WebSockifyServer():
# Close open files
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
if maxfd == resource.RLIM_INFINITY: maxfd = 256
if maxfd == resource.RLIM_INFINITY:
maxfd = 256
for fd in reversed(range(maxfd)):
try:
if fd not in keepfd:
os.close(fd)
except OSError:
_, exc, _ = sys.exc_info()
if exc.errno != errno.EBADF: raise
if exc.errno != errno.EBADF:
raise
# Redirect I/O to /dev/null
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
@ -557,7 +610,7 @@ class WebSockifyServer():
# Peek, but do not read the data so that we have a opportunity
# to SSL wrap the socket first
handshake = sock.recv(1024, socket.MSG_PEEK)
#self.msg("Handshake [%s]" % handshake)
# self.msg("Handshake [%s]" % handshake)
if not handshake:
raise self.EClose("")
@ -567,8 +620,7 @@ class WebSockifyServer():
if not ssl:
raise self.EClose("SSL connection but no 'ssl' module")
if not os.path.exists(self.cert):
raise self.EClose("SSL connection but '%s' not found"
% self.cert)
raise self.EClose("SSL connection but '%s' not found" % self.cert)
retsock = None
try:
# create new-style SSL wrapping for extended features
@ -576,16 +628,16 @@ class WebSockifyServer():
if self.ssl_ciphers is not None:
context.set_ciphers(self.ssl_ciphers)
context.options = self.ssl_options
context.load_cert_chain(certfile=self.cert, keyfile=self.key, password=self.key_password)
context.load_cert_chain(
certfile=self.cert, keyfile=self.key, password=self.key_password
)
if self.verify_client:
context.verify_mode = ssl.CERT_REQUIRED
if self.cafile:
context.load_verify_locations(cafile=self.cafile)
else:
context.set_default_verify_paths()
retsock = context.wrap_socket(
sock,
server_side=True)
retsock = context.wrap_socket(sock, server_side=True)
except ssl.SSLError:
_, x, _ = sys.exc_info()
if x.args[0] == ssl.SSL_ERROR_EOF:
@ -618,28 +670,27 @@ class WebSockifyServer():
#
def msg(self, *args, **kwargs):
""" Output message as info """
"""Output message as info"""
self.logger.log(logging.INFO, *args, **kwargs)
def vmsg(self, *args, **kwargs):
""" Same as msg() but as debug. """
"""Same as msg() but as debug."""
self.logger.log(logging.DEBUG, *args, **kwargs)
def warn(self, *args, **kwargs):
""" Same as msg() but as warning. """
"""Same as msg() but as warning."""
self.logger.log(logging.WARN, *args, **kwargs)
#
# Events that can/should be overridden in sub-classes
#
def started(self):
""" Called after WebSockets startup """
"""Called after WebSockets startup"""
self.vmsg("WebSockets server started")
def poll(self):
""" Run periodically while waiting for connections. """
#self.vmsg("Running poll()")
"""Run periodically while waiting for connections."""
# self.vmsg("Running poll()")
pass
def terminate(self):
@ -661,7 +712,7 @@ class WebSockifyServer():
while result[0]:
self.vmsg("Reaped child process %s" % result[0])
result = os.waitpid(-1, os.WNOHANG)
except (OSError):
except OSError:
pass
def do_SIGINT(self, sig, stack):
@ -675,7 +726,7 @@ class WebSockifyServer():
self.terminate()
def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """
"""Do something with a WebSockets client connection."""
# handler process
client = None
try:
@ -693,7 +744,6 @@ class WebSockifyServer():
self.msg("handler exception: %s" % str(exc))
self.vmsg("exception", exc_info=True)
finally:
if client and client != startsock:
# Close the SSL wrapped socket
# Original socket closed by caller
@ -721,19 +771,27 @@ class WebSockifyServer():
try:
if self.listen_fd != None:
lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM)
lsock = socket.fromfd(
self.listen_fd, socket.AF_INET, socket.SOCK_STREAM
)
elif self.unix_listen != None:
lsock = self.socket(host=None,
lsock = self.socket(
host=None,
unix_socket=self.unix_listen,
unix_socket_mode=self.unix_listen_mode,
unix_socket_listen=True)
unix_socket_listen=True,
)
else:
lsock = self.socket(self.listen_host, self.listen_port, False,
lsock = self.socket(
self.listen_host,
self.listen_port,
False,
self.prefer_ipv6,
tcp_keepalive=self.tcp_keepalive,
tcp_keepcnt=self.tcp_keepcnt,
tcp_keepidle=self.tcp_keepidle,
tcp_keepintvl=self.tcp_keepintvl)
tcp_keepintvl=self.tcp_keepintvl,
)
except OSError as e:
self.msg("Openening socket failed: %s", str(e))
self.vmsg("exception", exc_info=True)
@ -751,13 +809,13 @@ class WebSockifyServer():
signal.SIGINT: signal.getsignal(signal.SIGINT),
signal.SIGTERM: signal.getsignal(signal.SIGTERM),
}
if getattr(signal, 'SIGCHLD', None) is not None:
if getattr(signal, "SIGCHLD", None) is not None:
original_signals[signal.SIGCHLD] = signal.getsignal(signal.SIGCHLD)
signal.signal(signal.SIGINT, self.do_SIGINT)
signal.signal(signal.SIGTERM, self.do_SIGTERM)
# make sure that _cleanup is called when children die
# by calling active_children on SIGCHLD
if getattr(signal, 'SIGCHLD', None) is not None:
if getattr(signal, "SIGCHLD", None) is not None:
signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD)
last_active_time = self.launch_time
@ -774,8 +832,7 @@ class WebSockifyServer():
time_elapsed = time.time() - self.launch_time
if self.timeout and time_elapsed > self.timeout:
self.msg('listener exit due to --timeout %s'
% self.timeout)
self.msg("listener exit due to --timeout %s" % self.timeout)
break
if self.idle_timeout:
@ -787,8 +844,10 @@ class WebSockifyServer():
last_active_time = time.time()
if idle_time > self.idle_timeout and child_count == 0:
self.msg('listener exit due to --idle-timeout %s'
% self.idle_timeout)
self.msg(
"listener exit due to --idle-timeout %s"
% self.idle_timeout
)
break
try:
@ -799,16 +858,16 @@ class WebSockifyServer():
startsock, address = lsock.accept()
# Unix Socket will not report address (empty string), but address[0] is logged a bunch
if self.unix_listen != None:
address = [ self.unix_listen ]
address = [self.unix_listen]
else:
continue
except self.Terminate:
raise
except Exception:
_, exc, _ = sys.exc_info()
if hasattr(exc, 'errno'):
if hasattr(exc, "errno"):
err = exc.errno
elif hasattr(exc, 'args'):
elif hasattr(exc, "args"):
err = exc.args[0]
else:
err = exc[0]
@ -821,15 +880,14 @@ class WebSockifyServer():
if self.run_once:
# Run in same process if run_once
self.top_new_client(startsock, address)
if self.ws_connection :
self.msg('%s: exiting due to --run-once'
% address[0])
if self.ws_connection:
self.msg("%s: exiting due to --run-once" % address[0])
break
else:
self.vmsg('%s: new handler Process' % address[0])
self.vmsg("%s: new handler Process" % address[0])
p = multiprocessing.Process(
target=self.top_new_client,
args=(startsock, address))
target=self.top_new_client, args=(startsock, address)
)
p.start()
# child will not return
@ -857,12 +915,11 @@ class WebSockifyServer():
startsock.close()
finally:
# Close listen port
self.vmsg("Closing socket listening at %s:%s",
self.listen_host, self.listen_port)
self.vmsg(
"Closing socket listening at %s:%s", self.listen_host, self.listen_port
)
lsock.close()
# Restore signals
for sig, func in original_signals.items():
signal.signal(sig, func)