Merge 3eca8ad195 into 15489fb72e
This commit is contained in:
commit
c57dc765d3
|
|
@ -0,0 +1,31 @@
|
|||
name: Lint
|
||||
on: [push, pull_request]
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version:
|
||||
- 3.12
|
||||
job:
|
||||
- mypy .
|
||||
- ruff format --check .
|
||||
- ruff check .
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --all-extras --all-groups
|
||||
|
||||
- run: uv run ${{ matrix.job }}
|
||||
|
|
@ -10,3 +10,6 @@ target.cfg.d
|
|||
build/
|
||||
dist/
|
||||
*.egg-info
|
||||
|
||||
# Intellij IDEs
|
||||
.idea
|
||||
|
|
|
|||
|
|
@ -0,0 +1,51 @@
|
|||
[project]
|
||||
name = "websockify"
|
||||
version = "0.13.0"
|
||||
description = "Websockify."
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Joel Martin", email = "github@martintribe.org" }
|
||||
]
|
||||
requires-python = ">=3.6"
|
||||
dependencies = [
|
||||
"jwcrypto>=1.5.1",
|
||||
"numpy>=1.19.0",
|
||||
"redis>=4.3.0",
|
||||
"requests>=2.27.0",
|
||||
]
|
||||
classifiers = [
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
keywords = ["noVNC", "websockify"]
|
||||
license = { text = "LGPLv3" }
|
||||
|
||||
[project.urls]
|
||||
Repository ="https://github.com/novnc/websockify"
|
||||
|
||||
[project.scripts]
|
||||
websockify = "websockify.websocketproxy:websockify_init"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = [
|
||||
"websockify",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"mypy>=0.971",
|
||||
"ruff>=0.0.17",
|
||||
]
|
||||
80
setup.py
80
setup.py
|
|
@ -1,44 +1,44 @@
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
version = '0.13.0'
|
||||
name = 'websockify'
|
||||
long_description = open("README.md").read() + "\n" + \
|
||||
open("CHANGES.txt").read() + "\n"
|
||||
version = "0.13.0"
|
||||
name = "websockify"
|
||||
long_description = open("README.md").read() + "\n" + open("CHANGES.txt").read() + "\n"
|
||||
|
||||
setup(name=name,
|
||||
version=version,
|
||||
description="Websockify.",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
classifiers=[
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
],
|
||||
keywords='noVNC websockify',
|
||||
license='LGPLv3',
|
||||
url="https://github.com/novnc/websockify",
|
||||
author="Joel Martin",
|
||||
author_email="github@martintribe.org",
|
||||
|
||||
packages=['websockify'],
|
||||
include_package_data=True,
|
||||
install_requires=[
|
||||
'numpy', 'requests',
|
||||
'jwcrypto',
|
||||
'redis',
|
||||
],
|
||||
zip_safe=False,
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'websockify = websockify.websocketproxy:websockify_init',
|
||||
setup(
|
||||
name=name,
|
||||
version=version,
|
||||
description="Websockify.",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
classifiers=[
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
],
|
||||
keywords="noVNC websockify",
|
||||
license="LGPLv3",
|
||||
url="https://github.com/novnc/websockify",
|
||||
author="Joel Martin",
|
||||
author_email="github@martintribe.org",
|
||||
packages=["websockify"],
|
||||
include_package_data=True,
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"requests",
|
||||
"jwcrypto",
|
||||
"redis",
|
||||
],
|
||||
zip_safe=False,
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"websockify = websockify.websocketproxy:websockify_init",
|
||||
]
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
'''
|
||||
"""
|
||||
A WebSocket server that echos back whatever it receives from the client.
|
||||
Copyright 2010 Joel Martin
|
||||
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
|
||||
|
|
@ -8,16 +8,19 @@ Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
|
|||
You can make a cert/key with openssl using:
|
||||
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
|
||||
as taken from http://docs.python.org/dev/library/ssl.html#certificates
|
||||
'''
|
||||
"""
|
||||
|
||||
import os, sys, select, optparse, logging
|
||||
sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler
|
||||
|
||||
|
||||
class WebSocketEcho(WebSockifyRequestHandler):
|
||||
"""
|
||||
WebSockets server that echos back whatever is received from the
|
||||
client. """
|
||||
client."""
|
||||
|
||||
buffer_size = 8096
|
||||
|
||||
def new_websocket_client(self):
|
||||
|
|
@ -33,9 +36,11 @@ class WebSocketEcho(WebSockifyRequestHandler):
|
|||
while True:
|
||||
wlist = []
|
||||
|
||||
if cqueue or c_pend: wlist.append(self.request)
|
||||
if cqueue or c_pend:
|
||||
wlist.append(self.request)
|
||||
ins, outs, excepts = select.select(rlist, wlist, [], 1)
|
||||
if excepts: raise Exception("Socket exception")
|
||||
if excepts:
|
||||
raise Exception("Socket exception")
|
||||
|
||||
if self.request in outs:
|
||||
# Send queued target data to the client
|
||||
|
|
@ -50,20 +55,27 @@ class WebSocketEcho(WebSockifyRequestHandler):
|
|||
if closed:
|
||||
break
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = optparse.OptionParser(usage="%prog [options] listen_port")
|
||||
parser.add_option("--verbose", "-v", action="store_true",
|
||||
help="verbose messages and per frame traffic")
|
||||
parser.add_option("--cert", default="self.pem",
|
||||
help="SSL certificate file")
|
||||
parser.add_option("--key", default=None,
|
||||
help="SSL key file (if separate from cert)")
|
||||
parser.add_option("--ssl-only", action="store_true",
|
||||
help="disallow non-encrypted connections")
|
||||
parser.add_option(
|
||||
"--verbose",
|
||||
"-v",
|
||||
action="store_true",
|
||||
help="verbose messages and per frame traffic",
|
||||
)
|
||||
parser.add_option("--cert", default="self.pem", help="SSL certificate file")
|
||||
parser.add_option(
|
||||
"--key", default=None, help="SSL key file (if separate from cert)"
|
||||
)
|
||||
parser.add_option(
|
||||
"--ssl-only", action="store_true", help="disallow non-encrypted connections"
|
||||
)
|
||||
(opts, args) = parser.parse_args()
|
||||
|
||||
try:
|
||||
if len(args) != 1: raise ValueError
|
||||
if len(args) != 1:
|
||||
raise ValueError
|
||||
opts.listen_port = int(args[0])
|
||||
except ValueError:
|
||||
parser.error("Invalid arguments")
|
||||
|
|
@ -73,4 +85,3 @@ if __name__ == '__main__':
|
|||
opts.web = "."
|
||||
server = WebSockifyServer(WebSocketEcho, **opts.__dict__)
|
||||
server.start_server()
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,12 @@ import sys
|
|||
import optparse
|
||||
import select
|
||||
|
||||
sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
|
||||
from websockify.websocket import WebSocket, \
|
||||
WebSocketWantReadError, WebSocketWantWriteError
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
from websockify.websocket import (
|
||||
WebSocket,
|
||||
WebSocketWantReadError,
|
||||
WebSocketWantWriteError,
|
||||
)
|
||||
|
||||
parser = optparse.OptionParser(usage="%prog URL")
|
||||
(opts, args) = parser.parse_args()
|
||||
|
|
@ -22,19 +25,23 @@ print("Connecting to %s..." % URL)
|
|||
sock.connect(URL)
|
||||
print("Connected.")
|
||||
|
||||
|
||||
def send(msg):
|
||||
while True:
|
||||
try:
|
||||
sock.sendmsg(msg)
|
||||
break
|
||||
except WebSocketWantReadError:
|
||||
msg = ''
|
||||
msg = ""
|
||||
ins, outs, excepts = select.select([sock], [], [])
|
||||
if excepts: raise Exception("Socket exception")
|
||||
if excepts:
|
||||
raise Exception("Socket exception")
|
||||
except WebSocketWantWriteError:
|
||||
msg = ''
|
||||
msg = ""
|
||||
ins, outs, excepts = select.select([], [sock], [])
|
||||
if excepts: raise Exception("Socket exception")
|
||||
if excepts:
|
||||
raise Exception("Socket exception")
|
||||
|
||||
|
||||
def read():
|
||||
while True:
|
||||
|
|
@ -42,10 +49,13 @@ def read():
|
|||
return sock.recvmsg()
|
||||
except WebSocketWantReadError:
|
||||
ins, outs, excepts = select.select([sock], [], [])
|
||||
if excepts: raise Exception("Socket exception")
|
||||
if excepts:
|
||||
raise Exception("Socket exception")
|
||||
except WebSocketWantWriteError:
|
||||
ins, outs, excepts = select.select([], [sock], [])
|
||||
if excepts: raise Exception("Socket exception")
|
||||
if excepts:
|
||||
raise Exception("Socket exception")
|
||||
|
||||
|
||||
counter = 1
|
||||
while True:
|
||||
|
|
@ -56,7 +66,8 @@ while True:
|
|||
|
||||
while True:
|
||||
ins, outs, excepts = select.select([sock], [], [], 1.0)
|
||||
if excepts: raise Exception("Socket exception")
|
||||
if excepts:
|
||||
raise Exception("Socket exception")
|
||||
|
||||
if ins == []:
|
||||
break
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
echo.py
|
||||
echo.py
|
||||
|
|
|
|||
|
|
@ -1,32 +1,32 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
'''
|
||||
"""
|
||||
WebSocket server-side load test program. Sends and receives traffic
|
||||
that has a random payload (length and content) that is checksummed and
|
||||
given a sequence number. Any errors are reported and counted.
|
||||
'''
|
||||
"""
|
||||
|
||||
import sys, os, select, random, time, optparse, logging
|
||||
sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler
|
||||
|
||||
class WebSocketLoadServer(WebSockifyServer):
|
||||
|
||||
class WebSocketLoadServer(WebSockifyServer):
|
||||
recv_cnt = 0
|
||||
send_cnt = 0
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.delay = kwargs.pop('delay')
|
||||
self.delay = kwargs.pop("delay")
|
||||
|
||||
WebSockifyServer.__init__(self, *args, **kwargs)
|
||||
|
||||
|
||||
class WebSocketLoad(WebSockifyRequestHandler):
|
||||
|
||||
max_packet_size = 10000
|
||||
|
||||
def new_websocket_client(self):
|
||||
print "Prepopulating random array"
|
||||
print("Prepopulating random array")
|
||||
self.rand_array = []
|
||||
for i in range(0, self.max_packet_size):
|
||||
self.rand_array.append(random.randint(0, 9))
|
||||
|
|
@ -37,7 +37,7 @@ class WebSocketLoad(WebSockifyRequestHandler):
|
|||
|
||||
self.responder(self.request)
|
||||
|
||||
print "accumulated errors:", self.errors
|
||||
print("accumulated errors:", self.errors)
|
||||
self.errors = 0
|
||||
|
||||
def responder(self, client):
|
||||
|
|
@ -49,7 +49,8 @@ class WebSocketLoad(WebSockifyRequestHandler):
|
|||
|
||||
while True:
|
||||
ins, outs, excepts = select.select(socks, socks, socks, 1)
|
||||
if excepts: raise Exception("Socket exception")
|
||||
if excepts:
|
||||
raise Exception("Socket exception")
|
||||
|
||||
if client in ins:
|
||||
frames, closed = self.recv_frames()
|
||||
|
|
@ -57,7 +58,7 @@ class WebSocketLoad(WebSockifyRequestHandler):
|
|||
err = self.check(frames)
|
||||
if err:
|
||||
self.errors = self.errors + 1
|
||||
print err
|
||||
print(err)
|
||||
|
||||
if closed:
|
||||
break
|
||||
|
|
@ -73,24 +74,22 @@ class WebSocketLoad(WebSockifyRequestHandler):
|
|||
|
||||
def generate(self):
|
||||
length = random.randint(10, self.max_packet_size)
|
||||
numlist = self.rand_array[self.max_packet_size-length:]
|
||||
numlist = self.rand_array[self.max_packet_size - length :]
|
||||
# Error in length
|
||||
#numlist.append(5)
|
||||
# numlist.append(5)
|
||||
chksum = sum(numlist)
|
||||
# Error in checksum
|
||||
#numlist[0] = 5
|
||||
nums = "".join( [str(n) for n in numlist] )
|
||||
# numlist[0] = 5
|
||||
nums = "".join([str(n) for n in numlist])
|
||||
data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums)
|
||||
self.send_cnt += 1
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def check(self, frames):
|
||||
|
||||
err = ""
|
||||
for data in frames:
|
||||
if data.count('$') > 1:
|
||||
if data.count("$") > 1:
|
||||
raise Exception("Multiple parts within single packet")
|
||||
if len(data) == 0:
|
||||
self.traffic("_")
|
||||
|
|
@ -101,12 +100,12 @@ class WebSocketLoad(WebSockifyRequestHandler):
|
|||
continue
|
||||
|
||||
try:
|
||||
cnt, length, chksum, nums = data[1:-1].split(':')
|
||||
cnt = int(cnt)
|
||||
cnt, length, chksum, nums = data[1:-1].split(":")
|
||||
cnt = int(cnt)
|
||||
length = int(length)
|
||||
chksum = int(chksum)
|
||||
except ValueError:
|
||||
print "\n<BOF>" + repr(data) + "<EOF>"
|
||||
print("\n<BOF>" + repr(data) + "<EOF>")
|
||||
err += "Invalid data format\n"
|
||||
continue
|
||||
|
||||
|
|
@ -131,27 +130,37 @@ class WebSocketLoad(WebSockifyRequestHandler):
|
|||
real_chksum += int(num)
|
||||
|
||||
if real_chksum != chksum:
|
||||
err += "Expected checksum %d but real chksum is %d\n" % (chksum, real_chksum)
|
||||
err += "Expected checksum %d but real chksum is %d\n" % (
|
||||
chksum,
|
||||
real_chksum,
|
||||
)
|
||||
return err
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = optparse.OptionParser(usage="%prog [options] listen_port")
|
||||
parser.add_option("--verbose", "-v", action="store_true",
|
||||
help="verbose messages and per frame traffic")
|
||||
parser.add_option("--cert", default="self.pem",
|
||||
help="SSL certificate file")
|
||||
parser.add_option("--key", default=None,
|
||||
help="SSL key file (if separate from cert)")
|
||||
parser.add_option("--ssl-only", action="store_true",
|
||||
help="disallow non-encrypted connections")
|
||||
parser.add_option(
|
||||
"--verbose",
|
||||
"-v",
|
||||
action="store_true",
|
||||
help="verbose messages and per frame traffic",
|
||||
)
|
||||
parser.add_option("--cert", default="self.pem", help="SSL certificate file")
|
||||
parser.add_option(
|
||||
"--key", default=None, help="SSL key file (if separate from cert)"
|
||||
)
|
||||
parser.add_option(
|
||||
"--ssl-only", action="store_true", help="disallow non-encrypted connections"
|
||||
)
|
||||
(opts, args) = parser.parse_args()
|
||||
|
||||
try:
|
||||
if len(args) != 1: raise ValueError
|
||||
if len(args) != 1:
|
||||
raise ValueError
|
||||
opts.listen_port = int(args[0])
|
||||
|
||||
if len(args) not in [1,2]: raise ValueError
|
||||
if len(args) not in [1, 2]:
|
||||
raise ValueError
|
||||
opts.listen_port = int(args[0])
|
||||
if len(args) == 2:
|
||||
opts.delay = int(args[1])
|
||||
|
|
@ -165,4 +174,3 @@ if __name__ == '__main__':
|
|||
opts.web = "."
|
||||
server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__)
|
||||
server.start_server()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,28 +1,33 @@
|
|||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
""" Unit tests for Authentication plugins"""
|
||||
"""Unit tests for Authentication plugins"""
|
||||
|
||||
from websockify.auth_plugins import BasicHTTPAuth, AuthenticationError
|
||||
import unittest
|
||||
|
||||
|
||||
class BasicHTTPAuthTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.plugin = BasicHTTPAuth('Aladdin:open sesame')
|
||||
self.plugin = BasicHTTPAuth("Aladdin:open sesame")
|
||||
|
||||
def test_no_auth(self):
|
||||
headers = {}
|
||||
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234')
|
||||
self.assertRaises(
|
||||
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
|
||||
)
|
||||
|
||||
def test_invalid_password(self):
|
||||
headers = {'Authorization': 'Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0'}
|
||||
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234')
|
||||
headers = {"Authorization": "Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0"}
|
||||
self.assertRaises(
|
||||
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
|
||||
)
|
||||
|
||||
def test_valid_password(self):
|
||||
headers = {'Authorization': 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='}
|
||||
self.plugin.authenticate(headers, 'localhost', '1234')
|
||||
headers = {"Authorization": "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="}
|
||||
self.plugin.authenticate(headers, "localhost", "1234")
|
||||
|
||||
def test_garbage_auth(self):
|
||||
headers = {'Authorization': 'Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx'}
|
||||
self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234')
|
||||
headers = {"Authorization": "Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx"}
|
||||
self.assertRaises(
|
||||
AuthenticationError, self.plugin.authenticate, headers, "localhost", "1234"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,80 +1,93 @@
|
|||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
""" Unit tests for Token plugins"""
|
||||
"""Unit tests for Token plugins"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch, mock_open, MagicMock
|
||||
from jwcrypto import jwt, jwk
|
||||
|
||||
from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis
|
||||
from websockify.token_plugins import (
|
||||
parse_source_args,
|
||||
ReadOnlyTokenFile,
|
||||
JWTTokenApi,
|
||||
TokenRedis,
|
||||
)
|
||||
|
||||
|
||||
class ParseSourceArgumentsTestCase(unittest.TestCase):
|
||||
def test_parameterized(self):
|
||||
params = [
|
||||
('', ['']),
|
||||
(':', ['', '']),
|
||||
('::', ['', '', '']),
|
||||
("", [""]),
|
||||
(":", ["", ""]),
|
||||
("::", ["", "", ""]),
|
||||
('"', ['"']),
|
||||
('""', ['""']),
|
||||
('"""', ['"""']),
|
||||
('"localhost"', ['localhost']),
|
||||
('"localhost":', ['localhost', '']),
|
||||
('"localhost"::', ['localhost', '', '']),
|
||||
('"local:host"', ['local:host']),
|
||||
('"local:host:"pass"', ['"local', 'host', "pass"]),
|
||||
('"local":"host"', ['local', 'host']),
|
||||
('"local":host"', ['local', 'host"']),
|
||||
('localhost:6379:1:pass"word:"my-app-namespace:dev"',
|
||||
['localhost', '6379', '1', 'pass"word', 'my-app-namespace:dev']),
|
||||
('"localhost"', ["localhost"]),
|
||||
('"localhost":', ["localhost", ""]),
|
||||
('"localhost"::', ["localhost", "", ""]),
|
||||
('"local:host"', ["local:host"]),
|
||||
('"local:host:"pass"', ['"local', "host", "pass"]),
|
||||
('"local":"host"', ["local", "host"]),
|
||||
('"local":host"', ["local", 'host"']),
|
||||
(
|
||||
'localhost:6379:1:pass"word:"my-app-namespace:dev"',
|
||||
["localhost", "6379", "1", 'pass"word', "my-app-namespace:dev"],
|
||||
),
|
||||
]
|
||||
for src, args in params:
|
||||
self.assertEqual(args, parse_source_args(src))
|
||||
|
||||
|
||||
class ReadOnlyTokenFileTestCase(unittest.TestCase):
|
||||
patch('os.path.isdir', MagicMock(return_value=False))
|
||||
patch("os.path.isdir", MagicMock(return_value=False))
|
||||
|
||||
def test_empty(self):
|
||||
plugin = ReadOnlyTokenFile('configfile')
|
||||
plugin = ReadOnlyTokenFile("configfile")
|
||||
|
||||
config = ""
|
||||
pyopen = mock_open(read_data=config)
|
||||
|
||||
with patch("websockify.token_plugins.open", pyopen, create=True):
|
||||
result = plugin.lookup('testhost')
|
||||
result = plugin.lookup("testhost")
|
||||
|
||||
pyopen.assert_called_once_with('configfile')
|
||||
pyopen.assert_called_once_with("configfile")
|
||||
self.assertIsNone(result)
|
||||
|
||||
patch('os.path.isdir', MagicMock(return_value=False))
|
||||
patch("os.path.isdir", MagicMock(return_value=False))
|
||||
|
||||
def test_simple(self):
|
||||
plugin = ReadOnlyTokenFile('configfile')
|
||||
plugin = ReadOnlyTokenFile("configfile")
|
||||
|
||||
config = "testhost: remote_host:remote_port"
|
||||
pyopen = mock_open(read_data=config)
|
||||
|
||||
with patch("websockify.token_plugins.open", pyopen, create=True):
|
||||
result = plugin.lookup('testhost')
|
||||
result = plugin.lookup("testhost")
|
||||
|
||||
pyopen.assert_called_once_with('configfile')
|
||||
pyopen.assert_called_once_with("configfile")
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result[0], "remote_host")
|
||||
self.assertEqual(result[1], "remote_port")
|
||||
|
||||
patch('os.path.isdir', MagicMock(return_value=False))
|
||||
patch("os.path.isdir", MagicMock(return_value=False))
|
||||
|
||||
def test_tabs(self):
|
||||
plugin = ReadOnlyTokenFile('configfile')
|
||||
plugin = ReadOnlyTokenFile("configfile")
|
||||
|
||||
config = "testhost:\tremote_host:remote_port"
|
||||
pyopen = mock_open(read_data=config)
|
||||
|
||||
with patch("websockify.token_plugins.open", pyopen, create=True):
|
||||
result = plugin.lookup('testhost')
|
||||
result = plugin.lookup("testhost")
|
||||
|
||||
pyopen.assert_called_once_with('configfile')
|
||||
pyopen.assert_called_once_with("configfile")
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result[0], "remote_host")
|
||||
self.assertEqual(result[1], "remote_port")
|
||||
|
||||
|
||||
class JWSTokenTestCase(unittest.TestCase):
|
||||
def test_asymmetric_jws_token_plugin(self):
|
||||
plugin = JWTTokenApi("./tests/fixtures/public.pem")
|
||||
|
|
@ -82,7 +95,9 @@ class JWSTokenTestCase(unittest.TestCase):
|
|||
key = jwk.JWK()
|
||||
private_key = open("./tests/fixtures/private.pem", "rb").read()
|
||||
key.import_from_pem(private_key)
|
||||
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"})
|
||||
jwt_token = jwt.JWT(
|
||||
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
|
||||
)
|
||||
jwt_token.make_signed_token(key)
|
||||
|
||||
result = plugin.lookup(jwt_token.serialize())
|
||||
|
|
@ -97,21 +112,26 @@ class JWSTokenTestCase(unittest.TestCase):
|
|||
key = jwk.JWK()
|
||||
private_key = open("./tests/fixtures/private.pem", "rb").read()
|
||||
key.import_from_pem(private_key)
|
||||
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"})
|
||||
jwt_token = jwt.JWT(
|
||||
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
|
||||
)
|
||||
jwt_token.make_signed_token(key)
|
||||
|
||||
result = plugin.lookup(jwt_token.serialize())
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch('time.time')
|
||||
@patch("time.time")
|
||||
def test_jwt_valid_time(self, mock_time):
|
||||
plugin = JWTTokenApi("./tests/fixtures/public.pem")
|
||||
|
||||
key = jwk.JWK()
|
||||
private_key = open("./tests/fixtures/private.pem", "rb").read()
|
||||
key.import_from_pem(private_key)
|
||||
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 })
|
||||
jwt_token = jwt.JWT(
|
||||
{"alg": "RS256"},
|
||||
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
|
||||
)
|
||||
jwt_token.make_signed_token(key)
|
||||
mock_time.return_value = 150
|
||||
|
||||
|
|
@ -121,14 +141,17 @@ class JWSTokenTestCase(unittest.TestCase):
|
|||
self.assertEqual(result[0], "remote_host")
|
||||
self.assertEqual(result[1], "remote_port")
|
||||
|
||||
@patch('time.time')
|
||||
@patch("time.time")
|
||||
def test_jwt_early_time(self, mock_time):
|
||||
plugin = JWTTokenApi("./tests/fixtures/public.pem")
|
||||
|
||||
key = jwk.JWK()
|
||||
private_key = open("./tests/fixtures/private.pem", "rb").read()
|
||||
key.import_from_pem(private_key)
|
||||
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 })
|
||||
jwt_token = jwt.JWT(
|
||||
{"alg": "RS256"},
|
||||
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
|
||||
)
|
||||
jwt_token.make_signed_token(key)
|
||||
mock_time.return_value = 50
|
||||
|
||||
|
|
@ -136,14 +159,17 @@ class JWSTokenTestCase(unittest.TestCase):
|
|||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch('time.time')
|
||||
@patch("time.time")
|
||||
def test_jwt_late_time(self, mock_time):
|
||||
plugin = JWTTokenApi("./tests/fixtures/public.pem")
|
||||
|
||||
key = jwk.JWK()
|
||||
private_key = open("./tests/fixtures/private.pem", "rb").read()
|
||||
key.import_from_pem(private_key)
|
||||
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 })
|
||||
jwt_token = jwt.JWT(
|
||||
{"alg": "RS256"},
|
||||
{"host": "remote_host", "port": "remote_port", "nbf": 100, "exp": 200},
|
||||
)
|
||||
jwt_token.make_signed_token(key)
|
||||
mock_time.return_value = 250
|
||||
|
||||
|
|
@ -156,8 +182,10 @@ class JWSTokenTestCase(unittest.TestCase):
|
|||
|
||||
secret = open("./tests/fixtures/symmetric.key").read()
|
||||
key = jwk.JWK()
|
||||
key.import_key(kty="oct",k=secret)
|
||||
jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"})
|
||||
key.import_key(kty="oct", k=secret)
|
||||
jwt_token = jwt.JWT(
|
||||
{"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"}
|
||||
)
|
||||
jwt_token.make_signed_token(key)
|
||||
|
||||
result = plugin.lookup(jwt_token.serialize())
|
||||
|
|
@ -171,8 +199,10 @@ class JWSTokenTestCase(unittest.TestCase):
|
|||
|
||||
secret = open("./tests/fixtures/symmetric.key").read()
|
||||
key = jwk.JWK()
|
||||
key.import_key(kty="oct",k=secret)
|
||||
jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"})
|
||||
key.import_key(kty="oct", k=secret)
|
||||
jwt_token = jwt.JWT(
|
||||
{"alg": "HS256"}, {"host": "remote_host", "port": "remote_port"}
|
||||
)
|
||||
jwt_token.make_signed_token(key)
|
||||
|
||||
result = plugin.lookup(jwt_token.serialize())
|
||||
|
|
@ -188,10 +218,14 @@ class JWSTokenTestCase(unittest.TestCase):
|
|||
public_key_data = open("./tests/fixtures/public.pem", "rb").read()
|
||||
private_key.import_from_pem(private_key_data)
|
||||
public_key.import_from_pem(public_key_data)
|
||||
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"})
|
||||
jwt_token = jwt.JWT(
|
||||
{"alg": "RS256"}, {"host": "remote_host", "port": "remote_port"}
|
||||
)
|
||||
jwt_token.make_signed_token(private_key)
|
||||
jwe_token = jwt.JWT(header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"},
|
||||
claims=jwt_token.serialize())
|
||||
jwe_token = jwt.JWT(
|
||||
header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"},
|
||||
claims=jwt_token.serialize(),
|
||||
)
|
||||
jwe_token.make_encrypted_token(public_key)
|
||||
|
||||
result = plugin.lookup(jwt_token.serialize())
|
||||
|
|
@ -200,103 +234,104 @@ class JWSTokenTestCase(unittest.TestCase):
|
|||
self.assertEqual(result[0], "remote_host")
|
||||
self.assertEqual(result[1], "remote_port")
|
||||
|
||||
|
||||
class TokenRedisTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
patcher = patch.dict(sys.modules, {'redis': MagicMock()})
|
||||
patcher = patch.dict(sys.modules, {"redis": MagicMock()})
|
||||
patcher.start()
|
||||
self.addCleanup(patcher.stop)
|
||||
|
||||
@patch('redis.Redis')
|
||||
@patch("redis.Redis")
|
||||
def test_empty(self, mock_redis):
|
||||
plugin = TokenRedis('127.0.0.1:1234')
|
||||
plugin = TokenRedis("127.0.0.1:1234")
|
||||
|
||||
instance = mock_redis.return_value
|
||||
instance.get.return_value = None
|
||||
|
||||
result = plugin.lookup('testhost')
|
||||
result = plugin.lookup("testhost")
|
||||
|
||||
instance.get.assert_called_once_with('testhost')
|
||||
instance.get.assert_called_once_with("testhost")
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch('redis.Redis')
|
||||
@patch("redis.Redis")
|
||||
def test_simple(self, mock_redis):
|
||||
plugin = TokenRedis('127.0.0.1:1234')
|
||||
plugin = TokenRedis("127.0.0.1:1234")
|
||||
|
||||
instance = mock_redis.return_value
|
||||
instance.get.return_value = b'{"host": "remote_host:remote_port"}'
|
||||
|
||||
result = plugin.lookup('testhost')
|
||||
result = plugin.lookup("testhost")
|
||||
|
||||
instance.get.assert_called_once_with('testhost')
|
||||
instance.get.assert_called_once_with("testhost")
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result[0], 'remote_host')
|
||||
self.assertEqual(result[1], 'remote_port')
|
||||
self.assertEqual(result[0], "remote_host")
|
||||
self.assertEqual(result[1], "remote_port")
|
||||
|
||||
@patch('redis.Redis')
|
||||
@patch("redis.Redis")
|
||||
def test_json_token_with_spaces(self, mock_redis):
|
||||
plugin = TokenRedis('127.0.0.1:1234')
|
||||
plugin = TokenRedis("127.0.0.1:1234")
|
||||
|
||||
instance = mock_redis.return_value
|
||||
instance.get.return_value = b' {"host": "remote_host:remote_port"} '
|
||||
|
||||
result = plugin.lookup('testhost')
|
||||
result = plugin.lookup("testhost")
|
||||
|
||||
instance.get.assert_called_once_with('testhost')
|
||||
instance.get.assert_called_once_with("testhost")
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result[0], 'remote_host')
|
||||
self.assertEqual(result[1], 'remote_port')
|
||||
self.assertEqual(result[0], "remote_host")
|
||||
self.assertEqual(result[1], "remote_port")
|
||||
|
||||
@patch('redis.Redis')
|
||||
@patch("redis.Redis")
|
||||
def test_text_token(self, mock_redis):
|
||||
plugin = TokenRedis('127.0.0.1:1234')
|
||||
plugin = TokenRedis("127.0.0.1:1234")
|
||||
|
||||
instance = mock_redis.return_value
|
||||
instance.get.return_value = b'remote_host:remote_port'
|
||||
instance.get.return_value = b"remote_host:remote_port"
|
||||
|
||||
result = plugin.lookup('testhost')
|
||||
result = plugin.lookup("testhost")
|
||||
|
||||
instance.get.assert_called_once_with('testhost')
|
||||
instance.get.assert_called_once_with("testhost")
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result[0], 'remote_host')
|
||||
self.assertEqual(result[1], 'remote_port')
|
||||
self.assertEqual(result[0], "remote_host")
|
||||
self.assertEqual(result[1], "remote_port")
|
||||
|
||||
@patch('redis.Redis')
|
||||
@patch("redis.Redis")
|
||||
def test_text_token_with_spaces(self, mock_redis):
|
||||
plugin = TokenRedis('127.0.0.1:1234')
|
||||
plugin = TokenRedis("127.0.0.1:1234")
|
||||
|
||||
instance = mock_redis.return_value
|
||||
instance.get.return_value = b' remote_host:remote_port '
|
||||
instance.get.return_value = b" remote_host:remote_port "
|
||||
|
||||
result = plugin.lookup('testhost')
|
||||
result = plugin.lookup("testhost")
|
||||
|
||||
instance.get.assert_called_once_with('testhost')
|
||||
instance.get.assert_called_once_with("testhost")
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result[0], 'remote_host')
|
||||
self.assertEqual(result[1], 'remote_port')
|
||||
self.assertEqual(result[0], "remote_host")
|
||||
self.assertEqual(result[1], "remote_port")
|
||||
|
||||
@patch('redis.Redis')
|
||||
@patch("redis.Redis")
|
||||
def test_invalid_token(self, mock_redis):
|
||||
plugin = TokenRedis('127.0.0.1:1234')
|
||||
plugin = TokenRedis("127.0.0.1:1234")
|
||||
|
||||
instance = mock_redis.return_value
|
||||
instance.get.return_value = b'{"host": "remote_host:remote_port" '
|
||||
|
||||
result = plugin.lookup('testhost')
|
||||
result = plugin.lookup("testhost")
|
||||
|
||||
instance.get.assert_called_once_with('testhost')
|
||||
instance.get.assert_called_once_with("testhost")
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch('redis.Redis')
|
||||
@patch("redis.Redis")
|
||||
def test_token_without_namespace(self, mock_redis):
|
||||
plugin = TokenRedis('127.0.0.1:1234')
|
||||
token = 'testhost'
|
||||
plugin = TokenRedis("127.0.0.1:1234")
|
||||
token = "testhost"
|
||||
|
||||
def mock_redis_get(key):
|
||||
self.assertEqual(key, token)
|
||||
return b'remote_host:remote_port'
|
||||
return b"remote_host:remote_port"
|
||||
|
||||
instance = mock_redis.return_value
|
||||
instance.get = mock_redis_get
|
||||
|
|
@ -304,17 +339,17 @@ class TokenRedisTestCase(unittest.TestCase):
|
|||
result = plugin.lookup(token)
|
||||
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result[0], 'remote_host')
|
||||
self.assertEqual(result[1], 'remote_port')
|
||||
self.assertEqual(result[0], "remote_host")
|
||||
self.assertEqual(result[1], "remote_port")
|
||||
|
||||
@patch('redis.Redis')
|
||||
@patch("redis.Redis")
|
||||
def test_token_with_namespace(self, mock_redis):
|
||||
plugin = TokenRedis('127.0.0.1:1234:::namespace')
|
||||
token = 'testhost'
|
||||
plugin = TokenRedis("127.0.0.1:1234:::namespace")
|
||||
token = "testhost"
|
||||
|
||||
def mock_redis_get(key):
|
||||
self.assertEqual(key, "namespace:" + token)
|
||||
return b'remote_host:remote_port'
|
||||
return b"remote_host:remote_port"
|
||||
|
||||
instance = mock_redis.return_value
|
||||
instance.get = mock_redis_get
|
||||
|
|
@ -322,103 +357,103 @@ class TokenRedisTestCase(unittest.TestCase):
|
|||
result = plugin.lookup(token)
|
||||
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result[0], 'remote_host')
|
||||
self.assertEqual(result[1], 'remote_port')
|
||||
self.assertEqual(result[0], "remote_host")
|
||||
self.assertEqual(result[1], "remote_port")
|
||||
|
||||
def test_src_only_host(self):
|
||||
plugin = TokenRedis('127.0.0.1')
|
||||
plugin = TokenRedis("127.0.0.1")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 6379)
|
||||
self.assertEqual(plugin._db, 0)
|
||||
self.assertEqual(plugin._password, None)
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_port(self):
|
||||
plugin = TokenRedis('127.0.0.1:1234')
|
||||
plugin = TokenRedis("127.0.0.1:1234")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 1234)
|
||||
self.assertEqual(plugin._db, 0)
|
||||
self.assertEqual(plugin._password, None)
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_port_db(self):
|
||||
plugin = TokenRedis('127.0.0.1:1234:2')
|
||||
plugin = TokenRedis("127.0.0.1:1234:2")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 1234)
|
||||
self.assertEqual(plugin._db, 2)
|
||||
self.assertEqual(plugin._password, None)
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_port_db_pass(self):
|
||||
plugin = TokenRedis('127.0.0.1:1234:2:verysecret')
|
||||
plugin = TokenRedis("127.0.0.1:1234:2:verysecret")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 1234)
|
||||
self.assertEqual(plugin._db, 2)
|
||||
self.assertEqual(plugin._password, 'verysecret')
|
||||
self.assertEqual(plugin._password, "verysecret")
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_port_db_pass_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1:1234:2:verysecret:namespace')
|
||||
plugin = TokenRedis("127.0.0.1:1234:2:verysecret:namespace")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 1234)
|
||||
self.assertEqual(plugin._db, 2)
|
||||
self.assertEqual(plugin._password, 'verysecret')
|
||||
self.assertEqual(plugin._password, "verysecret")
|
||||
self.assertEqual(plugin._namespace, "namespace:")
|
||||
|
||||
def test_src_with_host_empty_port_empty_db_pass_no_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1:::verysecret')
|
||||
plugin = TokenRedis("127.0.0.1:::verysecret")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 6379)
|
||||
self.assertEqual(plugin._db, 0)
|
||||
self.assertEqual(plugin._password, 'verysecret')
|
||||
self.assertEqual(plugin._password, "verysecret")
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_empty_port_empty_db_empty_pass_empty_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1::::')
|
||||
plugin = TokenRedis("127.0.0.1::::")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 6379)
|
||||
self.assertEqual(plugin._db, 0)
|
||||
self.assertEqual(plugin._password, None)
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_empty_port_empty_db_empty_pass_no_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1:::')
|
||||
plugin = TokenRedis("127.0.0.1:::")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 6379)
|
||||
self.assertEqual(plugin._db, 0)
|
||||
self.assertEqual(plugin._password, None)
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_empty_port_empty_db_no_pass_no_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1::')
|
||||
plugin = TokenRedis("127.0.0.1::")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 6379)
|
||||
self.assertEqual(plugin._db, 0)
|
||||
self.assertEqual(plugin._password, None)
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_empty_port_no_db_no_pass_no_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1:')
|
||||
plugin = TokenRedis("127.0.0.1:")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 6379)
|
||||
self.assertEqual(plugin._db, 0)
|
||||
self.assertEqual(plugin._password, None)
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_empty_port_empty_db_empty_pass_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1::::namespace')
|
||||
plugin = TokenRedis("127.0.0.1::::namespace")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 6379)
|
||||
self.assertEqual(plugin._db, 0)
|
||||
self.assertEqual(plugin._password, None)
|
||||
|
|
@ -427,43 +462,43 @@ class TokenRedisTestCase(unittest.TestCase):
|
|||
def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1::::"ns1:ns2"')
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 6379)
|
||||
self.assertEqual(plugin._db, 0)
|
||||
self.assertEqual(plugin._password, None)
|
||||
self.assertEqual(plugin._namespace, "ns1:ns2:")
|
||||
|
||||
def test_src_with_host_empty_port_db_no_pass_no_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1::2')
|
||||
plugin = TokenRedis("127.0.0.1::2")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 6379)
|
||||
self.assertEqual(plugin._db, 2)
|
||||
self.assertEqual(plugin._password, None)
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_port_empty_db_pass_no_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1:1234::verysecret')
|
||||
plugin = TokenRedis("127.0.0.1:1234::verysecret")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 1234)
|
||||
self.assertEqual(plugin._db, 0)
|
||||
self.assertEqual(plugin._password, 'verysecret')
|
||||
self.assertEqual(plugin._password, "verysecret")
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_empty_port_db_pass_no_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1::2:verysecret')
|
||||
plugin = TokenRedis("127.0.0.1::2:verysecret")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 6379)
|
||||
self.assertEqual(plugin._db, 2)
|
||||
self.assertEqual(plugin._password, 'verysecret')
|
||||
self.assertEqual(plugin._password, "verysecret")
|
||||
self.assertEqual(plugin._namespace, "")
|
||||
|
||||
def test_src_with_host_empty_port_db_empty_pass_no_namespace(self):
|
||||
plugin = TokenRedis('127.0.0.1::2:')
|
||||
plugin = TokenRedis("127.0.0.1::2:")
|
||||
|
||||
self.assertEqual(plugin._server, '127.0.0.1')
|
||||
self.assertEqual(plugin._server, "127.0.0.1")
|
||||
self.assertEqual(plugin._port, 6379)
|
||||
self.assertEqual(plugin._db, 2)
|
||||
self.assertEqual(plugin._password, None)
|
||||
|
|
|
|||
|
|
@ -14,199 +14,268 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
""" Unit tests for websocket """
|
||||
"""Unit tests for websocket"""
|
||||
|
||||
import unittest
|
||||
from websockify import websocket
|
||||
|
||||
|
||||
class FakeSocket:
|
||||
def __init__(self):
|
||||
self.data = b''
|
||||
self.data = b""
|
||||
|
||||
def send(self, buf):
|
||||
self.data += buf
|
||||
return len(buf)
|
||||
|
||||
|
||||
class AcceptTestCase(unittest.TestCase):
|
||||
def test_success(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
ws.accept(sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
|
||||
self.assertTrue(b'\r\nUpgrade: websocket\r\n' in sock.data)
|
||||
self.assertTrue(b'\r\nConnection: Upgrade\r\n' in sock.data)
|
||||
self.assertTrue(b'\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n' in sock.data)
|
||||
ws.accept(
|
||||
sock,
|
||||
{
|
||||
"upgrade": "websocket",
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
|
||||
},
|
||||
)
|
||||
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
|
||||
self.assertTrue(b"\r\nUpgrade: websocket\r\n" in sock.data)
|
||||
self.assertTrue(b"\r\nConnection: Upgrade\r\n" in sock.data)
|
||||
self.assertTrue(
|
||||
b"\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n" in sock.data
|
||||
)
|
||||
|
||||
def test_bad_version(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '5',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '20',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
ws.accept,
|
||||
sock,
|
||||
{"upgrade": "websocket", "Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q=="},
|
||||
)
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
ws.accept,
|
||||
sock,
|
||||
{
|
||||
"upgrade": "websocket",
|
||||
"Sec-WebSocket-Version": "5",
|
||||
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
|
||||
},
|
||||
)
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
ws.accept,
|
||||
sock,
|
||||
{
|
||||
"upgrade": "websocket",
|
||||
"Sec-WebSocket-Version": "20",
|
||||
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
|
||||
},
|
||||
)
|
||||
|
||||
def test_bad_upgrade(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket2',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
ws.accept,
|
||||
sock,
|
||||
{
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
|
||||
},
|
||||
)
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
ws.accept,
|
||||
sock,
|
||||
{
|
||||
"upgrade": "websocket2",
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
|
||||
},
|
||||
)
|
||||
|
||||
def test_missing_key(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13'})
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
ws.accept,
|
||||
sock,
|
||||
{"upgrade": "websocket", "Sec-WebSocket-Version": "13"},
|
||||
)
|
||||
|
||||
def test_protocol(self):
|
||||
class ProtoSocket(websocket.WebSocket):
|
||||
def select_subprotocol(self, protocol):
|
||||
return 'gazonk'
|
||||
return "gazonk"
|
||||
|
||||
ws = ProtoSocket()
|
||||
sock = FakeSocket()
|
||||
ws.accept(sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
|
||||
'Sec-WebSocket-Protocol': 'foobar gazonk'})
|
||||
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
|
||||
self.assertTrue(b'\r\nSec-WebSocket-Protocol: gazonk\r\n' in sock.data)
|
||||
ws.accept(
|
||||
sock,
|
||||
{
|
||||
"upgrade": "websocket",
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
|
||||
"Sec-WebSocket-Protocol": "foobar gazonk",
|
||||
},
|
||||
)
|
||||
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
|
||||
self.assertTrue(b"\r\nSec-WebSocket-Protocol: gazonk\r\n" in sock.data)
|
||||
|
||||
def test_no_protocol(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
ws.accept(sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
|
||||
self.assertFalse(b'\r\nSec-WebSocket-Protocol:' in sock.data)
|
||||
ws.accept(
|
||||
sock,
|
||||
{
|
||||
"upgrade": "websocket",
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
|
||||
},
|
||||
)
|
||||
self.assertEqual(sock.data[:13], b"HTTP/1.1 101 ")
|
||||
self.assertFalse(b"\r\nSec-WebSocket-Protocol:" in sock.data)
|
||||
|
||||
def test_missing_protocol(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
|
||||
'Sec-WebSocket-Protocol': 'foobar gazonk'})
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
ws.accept,
|
||||
sock,
|
||||
{
|
||||
"upgrade": "websocket",
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
|
||||
"Sec-WebSocket-Protocol": "foobar gazonk",
|
||||
},
|
||||
)
|
||||
|
||||
def test_protocol(self):
|
||||
class ProtoSocket(websocket.WebSocket):
|
||||
def select_subprotocol(self, protocol):
|
||||
return 'oddball'
|
||||
return "oddball"
|
||||
|
||||
ws = ProtoSocket()
|
||||
sock = FakeSocket()
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
|
||||
'Sec-WebSocket-Protocol': 'foobar gazonk'})
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
ws.accept,
|
||||
sock,
|
||||
{
|
||||
"upgrade": "websocket",
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
|
||||
"Sec-WebSocket-Protocol": "foobar gazonk",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class PingPongTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.ws = websocket.WebSocket()
|
||||
self.sock = FakeSocket()
|
||||
self.ws.accept(self.sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertEqual(self.sock.data[:13], b'HTTP/1.1 101 ')
|
||||
self.sock.data = b''
|
||||
self.ws.accept(
|
||||
self.sock,
|
||||
{
|
||||
"upgrade": "websocket",
|
||||
"Sec-WebSocket-Version": "13",
|
||||
"Sec-WebSocket-Key": "DKURYVK9cRFul1vOZVA56Q==",
|
||||
},
|
||||
)
|
||||
self.assertEqual(self.sock.data[:13], b"HTTP/1.1 101 ")
|
||||
self.sock.data = b""
|
||||
|
||||
def test_ping(self):
|
||||
self.ws.ping()
|
||||
self.assertEqual(self.sock.data, b'\x89\x00')
|
||||
self.assertEqual(self.sock.data, b"\x89\x00")
|
||||
|
||||
def test_pong(self):
|
||||
self.ws.pong()
|
||||
self.assertEqual(self.sock.data, b'\x8a\x00')
|
||||
self.assertEqual(self.sock.data, b"\x8a\x00")
|
||||
|
||||
def test_ping_data(self):
|
||||
self.ws.ping(b'foo')
|
||||
self.assertEqual(self.sock.data, b'\x89\x03foo')
|
||||
self.ws.ping(b"foo")
|
||||
self.assertEqual(self.sock.data, b"\x89\x03foo")
|
||||
|
||||
def test_pong_data(self):
|
||||
self.ws.pong(b'foo')
|
||||
self.assertEqual(self.sock.data, b'\x8a\x03foo')
|
||||
self.ws.pong(b"foo")
|
||||
self.assertEqual(self.sock.data, b"\x8a\x03foo")
|
||||
|
||||
|
||||
class HyBiEncodeDecodeTestCase(unittest.TestCase):
|
||||
def test_decode_hybi_text(self):
|
||||
buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58'
|
||||
buf = b"\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58"
|
||||
ws = websocket.WebSocket()
|
||||
res = ws._decode_hybi(buf)
|
||||
|
||||
self.assertEqual(res['fin'], 1)
|
||||
self.assertEqual(res['opcode'], 0x1)
|
||||
self.assertEqual(res['masked'], True)
|
||||
self.assertEqual(res['length'], len(buf))
|
||||
self.assertEqual(res['payload'], b'Hello')
|
||||
self.assertEqual(res["fin"], 1)
|
||||
self.assertEqual(res["opcode"], 0x1)
|
||||
self.assertEqual(res["masked"], True)
|
||||
self.assertEqual(res["length"], len(buf))
|
||||
self.assertEqual(res["payload"], b"Hello")
|
||||
|
||||
def test_decode_hybi_binary(self):
|
||||
buf = b'\x82\x04\x01\x02\x03\x04'
|
||||
buf = b"\x82\x04\x01\x02\x03\x04"
|
||||
ws = websocket.WebSocket()
|
||||
res = ws._decode_hybi(buf)
|
||||
|
||||
self.assertEqual(res['fin'], 1)
|
||||
self.assertEqual(res['opcode'], 0x2)
|
||||
self.assertEqual(res['length'], len(buf))
|
||||
self.assertEqual(res['payload'], b'\x01\x02\x03\x04')
|
||||
self.assertEqual(res["fin"], 1)
|
||||
self.assertEqual(res["opcode"], 0x2)
|
||||
self.assertEqual(res["length"], len(buf))
|
||||
self.assertEqual(res["payload"], b"\x01\x02\x03\x04")
|
||||
|
||||
def test_decode_hybi_extended_16bit_binary(self):
|
||||
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
|
||||
buf = b'\x82\x7e\x01\x04' + data
|
||||
data = b"\x01\x02\x03\x04" * 65 # len > 126 -- len == 260
|
||||
buf = b"\x82\x7e\x01\x04" + data
|
||||
ws = websocket.WebSocket()
|
||||
res = ws._decode_hybi(buf)
|
||||
|
||||
self.assertEqual(res['fin'], 1)
|
||||
self.assertEqual(res['opcode'], 0x2)
|
||||
self.assertEqual(res['length'], len(buf))
|
||||
self.assertEqual(res['payload'], data)
|
||||
self.assertEqual(res["fin"], 1)
|
||||
self.assertEqual(res["opcode"], 0x2)
|
||||
self.assertEqual(res["length"], len(buf))
|
||||
self.assertEqual(res["payload"], data)
|
||||
|
||||
def test_decode_hybi_extended_64bit_binary(self):
|
||||
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
|
||||
buf = b'\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04' + data
|
||||
data = b"\x01\x02\x03\x04" * 65 # len > 126 -- len == 260
|
||||
buf = b"\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04" + data
|
||||
ws = websocket.WebSocket()
|
||||
res = ws._decode_hybi(buf)
|
||||
|
||||
self.assertEqual(res['fin'], 1)
|
||||
self.assertEqual(res['opcode'], 0x2)
|
||||
self.assertEqual(res['length'], len(buf))
|
||||
self.assertEqual(res['payload'], data)
|
||||
self.assertEqual(res["fin"], 1)
|
||||
self.assertEqual(res["opcode"], 0x2)
|
||||
self.assertEqual(res["length"], len(buf))
|
||||
self.assertEqual(res["payload"], data)
|
||||
|
||||
def test_decode_hybi_multi(self):
|
||||
buf1 = b'\x01\x03\x48\x65\x6c'
|
||||
buf2 = b'\x80\x02\x6c\x6f'
|
||||
buf1 = b"\x01\x03\x48\x65\x6c"
|
||||
buf2 = b"\x80\x02\x6c\x6f"
|
||||
|
||||
ws = websocket.WebSocket()
|
||||
|
||||
res1 = ws._decode_hybi(buf1)
|
||||
self.assertEqual(res1['fin'], 0)
|
||||
self.assertEqual(res1['opcode'], 0x1)
|
||||
self.assertEqual(res1['length'], len(buf1))
|
||||
self.assertEqual(res1['payload'], b'Hel')
|
||||
self.assertEqual(res1["fin"], 0)
|
||||
self.assertEqual(res1["opcode"], 0x1)
|
||||
self.assertEqual(res1["length"], len(buf1))
|
||||
self.assertEqual(res1["payload"], b"Hel")
|
||||
|
||||
res2 = ws._decode_hybi(buf2)
|
||||
self.assertEqual(res2['fin'], 1)
|
||||
self.assertEqual(res2['opcode'], 0x0)
|
||||
self.assertEqual(res2['length'], len(buf2))
|
||||
self.assertEqual(res2['payload'], b'lo')
|
||||
self.assertEqual(res2["fin"], 1)
|
||||
self.assertEqual(res2["opcode"], 0x0)
|
||||
self.assertEqual(res2["length"], len(buf2))
|
||||
self.assertEqual(res2["payload"], b"lo")
|
||||
|
||||
def test_encode_hybi_basic(self):
|
||||
ws = websocket.WebSocket()
|
||||
res = ws._encode_hybi(0x1, b'Hello')
|
||||
expected = b'\x81\x05\x48\x65\x6c\x6c\x6f'
|
||||
res = ws._encode_hybi(0x1, b"Hello")
|
||||
expected = b"\x81\x05\x48\x65\x6c\x6c\x6f"
|
||||
|
||||
self.assertEqual(res, expected)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
""" Unit tests for websocketproxy """
|
||||
"""Unit tests for websocketproxy"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
|
@ -30,7 +30,7 @@ from websockify import auth_plugins
|
|||
|
||||
|
||||
class FakeSocket:
|
||||
def __init__(self, data=b''):
|
||||
def __init__(self, data=b""):
|
||||
self._data = data
|
||||
|
||||
def recv(self, amt, flags=None):
|
||||
|
|
@ -40,11 +40,11 @@ class FakeSocket:
|
|||
|
||||
return res
|
||||
|
||||
def makefile(self, mode='r', buffsize=None):
|
||||
if 'b' in mode:
|
||||
def makefile(self, mode="r", buffsize=None):
|
||||
if "b" in mode:
|
||||
return BytesIO(self._data)
|
||||
else:
|
||||
return StringIO(self._data.decode('latin_1'))
|
||||
return StringIO(self._data.decode("latin_1"))
|
||||
|
||||
|
||||
class FakeServer:
|
||||
|
|
@ -58,14 +58,16 @@ class FakeServer:
|
|||
self.ssl_target = None
|
||||
self.unix_target = None
|
||||
|
||||
|
||||
class ProxyRequestHandlerTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.handler = websocketproxy.ProxyRequestHandler(
|
||||
FakeSocket(), "127.0.0.1", FakeServer())
|
||||
FakeSocket(), "127.0.0.1", FakeServer()
|
||||
)
|
||||
self.handler.path = "https://localhost:6080/websockify?token=blah"
|
||||
self.handler.headers = {}
|
||||
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
|
||||
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
|
||||
|
||||
def tearDown(self):
|
||||
patch.stopall()
|
||||
|
|
@ -76,8 +78,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
|
|||
def lookup(self, token):
|
||||
return ("some host", "some port")
|
||||
|
||||
host, port = self.handler.get_target(
|
||||
TestPlugin(None))
|
||||
host, port = self.handler.get_target(TestPlugin(None))
|
||||
|
||||
self.assertEqual(host, "some host")
|
||||
self.assertEqual(port, "some port")
|
||||
|
|
@ -87,8 +88,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
|
|||
def lookup(self, token):
|
||||
return ("unix_socket", "/tmp/socket")
|
||||
|
||||
_, socket = self.handler.get_target(
|
||||
TestPlugin(None))
|
||||
_, socket = self.handler.get_target(TestPlugin(None))
|
||||
|
||||
self.assertEqual(socket, "/tmp/socket")
|
||||
|
||||
|
|
@ -100,11 +100,11 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
|
|||
with self.assertRaises(FakeServer.EClose):
|
||||
self.handler.get_target(TestPlugin(None))
|
||||
|
||||
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
|
||||
@patch("websockify.websocketproxy.ProxyRequestHandler.send_auth_error", MagicMock())
|
||||
def test_token_plugin(self):
|
||||
class TestPlugin(token_plugins.BasePlugin):
|
||||
def lookup(self, token):
|
||||
return (self.source + token).split(',')
|
||||
return (self.source + token).split(",")
|
||||
|
||||
self.handler.server.token_plugin = TestPlugin("somehost,")
|
||||
self.handler.validate_connection()
|
||||
|
|
@ -112,7 +112,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
|
|||
self.assertEqual(self.handler.server.target_host, "somehost")
|
||||
self.assertEqual(self.handler.server.target_port, "blah")
|
||||
|
||||
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
|
||||
@patch("websockify.websocketproxy.ProxyRequestHandler.send_auth_error", MagicMock())
|
||||
def test_auth_plugin(self):
|
||||
class TestPlugin(auth_plugins.BasePlugin):
|
||||
def authenticate(self, headers, target_host, target_port):
|
||||
|
|
@ -128,4 +128,3 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
|
|||
|
||||
self.handler.server.target_host = "someotherhost"
|
||||
self.handler.auth_connection()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""Unit tests for websocketserver"""
|
||||
|
||||
""" Unit tests for websocketserver """
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
|
@ -66,4 +66,3 @@ class HttpWebSocketTest(unittest.TestCase):
|
|||
|
||||
# Then
|
||||
req_obj.end_headers.assert_called_once_with()
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
""" Unit tests for websockifyserver """
|
||||
"""Unit tests for websockifyserver"""
|
||||
|
||||
import errno
|
||||
import os
|
||||
import logging
|
||||
|
|
@ -36,11 +37,11 @@ from websockify import websockifyserver
|
|||
|
||||
|
||||
def raise_oserror(*args, **kwargs):
|
||||
raise OSError('fake error')
|
||||
raise OSError("fake error")
|
||||
|
||||
|
||||
class FakeSocket:
|
||||
def __init__(self, data=b''):
|
||||
def __init__(self, data=b""):
|
||||
self._data = data
|
||||
|
||||
def recv(self, amt, flags=None):
|
||||
|
|
@ -50,19 +51,19 @@ class FakeSocket:
|
|||
|
||||
return res
|
||||
|
||||
def makefile(self, mode='r', buffsize=None):
|
||||
if 'b' in mode:
|
||||
def makefile(self, mode="r", buffsize=None):
|
||||
if "b" in mode:
|
||||
return BytesIO(self._data)
|
||||
else:
|
||||
return StringIO(self._data.decode('latin_1'))
|
||||
return StringIO(self._data.decode("latin_1"))
|
||||
|
||||
|
||||
class WebSockifyRequestHandlerTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.tmpdir = tempfile.mkdtemp('-websockify-tests')
|
||||
self.tmpdir = tempfile.mkdtemp("-websockify-tests")
|
||||
# Mock this out cause it screws tests up
|
||||
patch('os.chdir').start()
|
||||
patch("os.chdir").start()
|
||||
|
||||
def tearDown(self):
|
||||
"""Called automatically after each test."""
|
||||
|
|
@ -70,31 +71,41 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase):
|
|||
os.rmdir(self.tmpdir)
|
||||
super().tearDown()
|
||||
|
||||
def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
|
||||
**kwargs):
|
||||
web = kwargs.pop('web', self.tmpdir)
|
||||
def _get_server(
|
||||
self, handler_class=websockifyserver.WebSockifyRequestHandler, **kwargs
|
||||
):
|
||||
web = kwargs.pop("web", self.tmpdir)
|
||||
return websockifyserver.WebSockifyServer(
|
||||
handler_class, listen_host='localhost',
|
||||
listen_port=80, key=self.tmpdir, web=web,
|
||||
record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
|
||||
**kwargs)
|
||||
handler_class,
|
||||
listen_host="localhost",
|
||||
listen_port=80,
|
||||
key=self.tmpdir,
|
||||
web=web,
|
||||
record=self.tmpdir,
|
||||
daemon=False,
|
||||
ssl_only=0,
|
||||
idle_timeout=1,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error')
|
||||
@patch("websockify.websockifyserver.WebSockifyRequestHandler.send_error")
|
||||
def test_normal_get_with_only_upgrade_returns_error(self, send_error):
|
||||
server = self._get_server(web=None)
|
||||
handler = websockifyserver.WebSockifyRequestHandler(
|
||||
FakeSocket(b'GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)
|
||||
FakeSocket(b"GET /tmp.txt HTTP/1.1"), "127.0.0.1", server
|
||||
)
|
||||
|
||||
handler.do_GET()
|
||||
send_error.assert_called_with(405)
|
||||
|
||||
@patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error')
|
||||
@patch("websockify.websockifyserver.WebSockifyRequestHandler.send_error")
|
||||
def test_list_dir_with_file_only_returns_error(self, send_error):
|
||||
server = self._get_server(file_only=True)
|
||||
handler = websockifyserver.WebSockifyRequestHandler(
|
||||
FakeSocket(b'GET / HTTP/1.1'), '127.0.0.1', server)
|
||||
FakeSocket(b"GET / HTTP/1.1"), "127.0.0.1", server
|
||||
)
|
||||
|
||||
handler.path = '/'
|
||||
handler.path = "/"
|
||||
handler.do_GET()
|
||||
send_error.assert_called_with(404)
|
||||
|
||||
|
|
@ -102,9 +113,9 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase):
|
|||
class WebSockifyServerTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.tmpdir = tempfile.mkdtemp('-websockify-tests')
|
||||
self.tmpdir = tempfile.mkdtemp("-websockify-tests")
|
||||
# Mock this out cause it screws tests up
|
||||
patch('os.chdir').start()
|
||||
patch("os.chdir").start()
|
||||
|
||||
def tearDown(self):
|
||||
"""Called automatically after each test."""
|
||||
|
|
@ -112,32 +123,38 @@ class WebSockifyServerTestCase(unittest.TestCase):
|
|||
os.rmdir(self.tmpdir)
|
||||
super().tearDown()
|
||||
|
||||
def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
|
||||
**kwargs):
|
||||
def _get_server(
|
||||
self, handler_class=websockifyserver.WebSockifyRequestHandler, **kwargs
|
||||
):
|
||||
return websockifyserver.WebSockifyServer(
|
||||
handler_class, listen_host='localhost',
|
||||
listen_port=80, key=self.tmpdir, web=self.tmpdir,
|
||||
record=self.tmpdir, **kwargs)
|
||||
handler_class,
|
||||
listen_host="localhost",
|
||||
listen_port=80,
|
||||
key=self.tmpdir,
|
||||
web=self.tmpdir,
|
||||
record=self.tmpdir,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def test_daemonize_raises_error_while_closing_fds(self):
|
||||
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
|
||||
patch('os.fork').start().return_value = 0
|
||||
patch('signal.signal').start()
|
||||
patch('os.setsid').start()
|
||||
patch('os.close').start().side_effect = raise_oserror
|
||||
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
|
||||
patch("os.fork").start().return_value = 0
|
||||
patch("signal.signal").start()
|
||||
patch("os.setsid").start()
|
||||
patch("os.close").start().side_effect = raise_oserror
|
||||
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir="./")
|
||||
|
||||
def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
|
||||
def raise_oserror_ebadf(fd):
|
||||
raise OSError(errno.EBADF, 'fake error')
|
||||
raise OSError(errno.EBADF, "fake error")
|
||||
|
||||
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
|
||||
patch('os.fork').start().return_value = 0
|
||||
patch('signal.signal').start()
|
||||
patch('os.setsid').start()
|
||||
patch('os.close').start().side_effect = raise_oserror_ebadf
|
||||
patch('os.open').start().side_effect = raise_oserror
|
||||
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
|
||||
patch("os.fork").start().return_value = 0
|
||||
patch("signal.signal").start()
|
||||
patch("os.setsid").start()
|
||||
patch("os.close").start().side_effect = raise_oserror_ebadf
|
||||
patch("os.open").start().side_effect = raise_oserror
|
||||
self.assertRaises(OSError, server.daemonize, keepfd=None, chdir="./")
|
||||
|
||||
def test_handshake_fails_on_not_ready(self):
|
||||
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
|
||||
|
|
@ -145,23 +162,29 @@ class WebSockifyServerTestCase(unittest.TestCase):
|
|||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([], [], [])
|
||||
|
||||
patch('select.select').start().side_effect = fake_select
|
||||
patch("select.select").start().side_effect = fake_select
|
||||
self.assertRaises(
|
||||
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
|
||||
FakeSocket(), '127.0.0.1')
|
||||
websockifyserver.WebSockifyServer.EClose,
|
||||
server.do_handshake,
|
||||
FakeSocket(),
|
||||
"127.0.0.1",
|
||||
)
|
||||
|
||||
def test_empty_handshake_fails(self):
|
||||
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
|
||||
|
||||
sock = FakeSocket('')
|
||||
sock = FakeSocket("")
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([sock], [], [])
|
||||
|
||||
patch('select.select').start().side_effect = fake_select
|
||||
patch("select.select").start().side_effect = fake_select
|
||||
self.assertRaises(
|
||||
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
|
||||
sock, '127.0.0.1')
|
||||
websockifyserver.WebSockifyServer.EClose,
|
||||
server.do_handshake,
|
||||
sock,
|
||||
"127.0.0.1",
|
||||
)
|
||||
|
||||
def test_handshake_policy_request(self):
|
||||
# TODO(directxman12): implement
|
||||
|
|
@ -170,35 +193,39 @@ class WebSockifyServerTestCase(unittest.TestCase):
|
|||
def test_handshake_ssl_only_without_ssl_raises_error(self):
|
||||
server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
|
||||
|
||||
sock = FakeSocket(b'some initial data')
|
||||
sock = FakeSocket(b"some initial data")
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([sock], [], [])
|
||||
|
||||
patch('select.select').start().side_effect = fake_select
|
||||
patch("select.select").start().side_effect = fake_select
|
||||
self.assertRaises(
|
||||
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
|
||||
sock, '127.0.0.1')
|
||||
websockifyserver.WebSockifyServer.EClose,
|
||||
server.do_handshake,
|
||||
sock,
|
||||
"127.0.0.1",
|
||||
)
|
||||
|
||||
def test_do_handshake_no_ssl(self):
|
||||
class FakeHandler:
|
||||
CALLED = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
type(self).CALLED = True
|
||||
|
||||
FakeHandler.CALLED = False
|
||||
|
||||
server = self._get_server(
|
||||
handler_class=FakeHandler, daemon=True,
|
||||
ssl_only=0, idle_timeout=1)
|
||||
handler_class=FakeHandler, daemon=True, ssl_only=0, idle_timeout=1
|
||||
)
|
||||
|
||||
sock = FakeSocket(b'some initial data')
|
||||
sock = FakeSocket(b"some initial data")
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([sock], [], [])
|
||||
|
||||
patch('select.select').start().side_effect = fake_select
|
||||
self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
|
||||
patch("select.select").start().side_effect = fake_select
|
||||
self.assertEqual(server.do_handshake(sock, "127.0.0.1"), sock)
|
||||
self.assertTrue(FakeHandler.CALLED, True)
|
||||
|
||||
def test_do_handshake_ssl(self):
|
||||
|
|
@ -210,18 +237,22 @@ class WebSockifyServerTestCase(unittest.TestCase):
|
|||
pass
|
||||
|
||||
def test_do_handshake_ssl_without_cert_raises_error(self):
|
||||
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
|
||||
cert='afdsfasdafdsafdsafdsafdas')
|
||||
server = self._get_server(
|
||||
daemon=True, ssl_only=0, idle_timeout=1, cert="afdsfasdafdsafdsafdsafdas"
|
||||
)
|
||||
|
||||
sock = FakeSocket(b"\x16some ssl data")
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([sock], [], [])
|
||||
|
||||
patch('select.select').start().side_effect = fake_select
|
||||
patch("select.select").start().side_effect = fake_select
|
||||
self.assertRaises(
|
||||
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
|
||||
sock, '127.0.0.1')
|
||||
websockifyserver.WebSockifyServer.EClose,
|
||||
server.do_handshake,
|
||||
sock,
|
||||
"127.0.0.1",
|
||||
)
|
||||
|
||||
def test_do_handshake_ssl_error_eof_raises_close_error(self):
|
||||
server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
|
||||
|
|
@ -234,58 +265,79 @@ class WebSockifyServerTestCase(unittest.TestCase):
|
|||
def fake_wrap_socket(*args, **kwargs):
|
||||
raise ssl.SSLError(ssl.SSL_ERROR_EOF)
|
||||
|
||||
class fake_create_default_context():
|
||||
class fake_create_default_context:
|
||||
def __init__(self, purpose):
|
||||
self.verify_mode = None
|
||||
self.options = 0
|
||||
|
||||
def load_cert_chain(self, certfile, keyfile, password):
|
||||
pass
|
||||
|
||||
def set_default_verify_paths(self):
|
||||
pass
|
||||
|
||||
def load_verify_locations(self, cafile):
|
||||
pass
|
||||
|
||||
def wrap_socket(self, *args, **kwargs):
|
||||
raise ssl.SSLError(ssl.SSL_ERROR_EOF)
|
||||
|
||||
patch('select.select').start().side_effect = fake_select
|
||||
patch('ssl.create_default_context').start().side_effect = fake_create_default_context
|
||||
patch("select.select").start().side_effect = fake_select
|
||||
patch(
|
||||
"ssl.create_default_context"
|
||||
).start().side_effect = fake_create_default_context
|
||||
self.assertRaises(
|
||||
websockifyserver.WebSockifyServer.EClose, server.do_handshake,
|
||||
sock, '127.0.0.1')
|
||||
websockifyserver.WebSockifyServer.EClose,
|
||||
server.do_handshake,
|
||||
sock,
|
||||
"127.0.0.1",
|
||||
)
|
||||
|
||||
def test_do_handshake_ssl_sets_ciphers(self):
|
||||
test_ciphers = 'TEST-CIPHERS-1:TEST-CIPHER-2'
|
||||
test_ciphers = "TEST-CIPHERS-1:TEST-CIPHER-2"
|
||||
|
||||
class FakeHandler:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
server = self._get_server(handler_class=FakeHandler, daemon=True,
|
||||
idle_timeout=1, ssl_ciphers=test_ciphers)
|
||||
server = self._get_server(
|
||||
handler_class=FakeHandler,
|
||||
daemon=True,
|
||||
idle_timeout=1,
|
||||
ssl_ciphers=test_ciphers,
|
||||
)
|
||||
sock = FakeSocket(b"\x16some ssl data")
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
return ([sock], [], [])
|
||||
|
||||
class fake_create_default_context():
|
||||
CIPHERS = ''
|
||||
class fake_create_default_context:
|
||||
CIPHERS = ""
|
||||
|
||||
def __init__(self, purpose):
|
||||
self.verify_mode = None
|
||||
self.options = 0
|
||||
|
||||
def load_cert_chain(self, certfile, keyfile, password):
|
||||
pass
|
||||
|
||||
def set_default_verify_paths(self):
|
||||
pass
|
||||
|
||||
def load_verify_locations(self, cafile):
|
||||
pass
|
||||
|
||||
def wrap_socket(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def set_ciphers(self, ciphers_to_set):
|
||||
fake_create_default_context.CIPHERS = ciphers_to_set
|
||||
|
||||
patch('select.select').start().side_effect = fake_select
|
||||
patch('ssl.create_default_context').start().side_effect = fake_create_default_context
|
||||
server.do_handshake(sock, '127.0.0.1')
|
||||
patch("select.select").start().side_effect = fake_select
|
||||
patch(
|
||||
"ssl.create_default_context"
|
||||
).start().side_effect = fake_create_default_context
|
||||
server.do_handshake(sock, "127.0.0.1")
|
||||
self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers)
|
||||
|
||||
def test_do_handshake_ssl_sets_opions(self):
|
||||
|
|
@ -295,8 +347,12 @@ class WebSockifyServerTestCase(unittest.TestCase):
|
|||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
server = self._get_server(handler_class=FakeHandler, daemon=True,
|
||||
idle_timeout=1, ssl_options=test_options)
|
||||
server = self._get_server(
|
||||
handler_class=FakeHandler,
|
||||
daemon=True,
|
||||
idle_timeout=1,
|
||||
ssl_options=test_options,
|
||||
)
|
||||
sock = FakeSocket(b"\x16some ssl data")
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
|
|
@ -304,26 +360,36 @@ class WebSockifyServerTestCase(unittest.TestCase):
|
|||
|
||||
class fake_create_default_context:
|
||||
OPTIONS = 0
|
||||
|
||||
def __init__(self, purpose):
|
||||
self.verify_mode = None
|
||||
self._options = 0
|
||||
|
||||
def load_cert_chain(self, certfile, keyfile, password):
|
||||
pass
|
||||
|
||||
def set_default_verify_paths(self):
|
||||
pass
|
||||
|
||||
def load_verify_locations(self, cafile):
|
||||
pass
|
||||
|
||||
def wrap_socket(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def get_options(self):
|
||||
return self._options
|
||||
|
||||
def set_options(self, val):
|
||||
fake_create_default_context.OPTIONS = val
|
||||
|
||||
options = property(get_options, set_options)
|
||||
|
||||
patch('select.select').start().side_effect = fake_select
|
||||
patch('ssl.create_default_context').start().side_effect = fake_create_default_context
|
||||
server.do_handshake(sock, '127.0.0.1')
|
||||
patch("select.select").start().side_effect = fake_select
|
||||
patch(
|
||||
"ssl.create_default_context"
|
||||
).start().side_effect = fake_create_default_context
|
||||
server.do_handshake(sock, "127.0.0.1")
|
||||
self.assertEqual(fake_create_default_context.OPTIONS, test_options)
|
||||
|
||||
def test_fallback_sigchld_handler(self):
|
||||
|
|
@ -332,38 +398,38 @@ class WebSockifyServerTestCase(unittest.TestCase):
|
|||
|
||||
def test_start_server_error(self):
|
||||
server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
|
||||
sock = server.socket('localhost')
|
||||
sock = server.socket("localhost")
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
raise Exception("fake error")
|
||||
|
||||
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
|
||||
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
|
||||
patch('select.select').start().side_effect = fake_select
|
||||
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
|
||||
patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
|
||||
patch("select.select").start().side_effect = fake_select
|
||||
server.start_server()
|
||||
|
||||
def test_start_server_keyboardinterrupt(self):
|
||||
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
|
||||
sock = server.socket('localhost')
|
||||
sock = server.socket("localhost")
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
raise KeyboardInterrupt
|
||||
|
||||
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
|
||||
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
|
||||
patch('select.select').start().side_effect = fake_select
|
||||
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
|
||||
patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
|
||||
patch("select.select").start().side_effect = fake_select
|
||||
server.start_server()
|
||||
|
||||
def test_start_server_systemexit(self):
|
||||
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
|
||||
sock = server.socket('localhost')
|
||||
sock = server.socket("localhost")
|
||||
|
||||
def fake_select(rlist, wlist, xlist, timeout=None):
|
||||
sys.exit()
|
||||
|
||||
patch('websockify.websockifyserver.WebSockifyServer.socket').start()
|
||||
patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
|
||||
patch('select.select').start().side_effect = fake_select
|
||||
patch("websockify.websockifyserver.WebSockifyServer.socket").start()
|
||||
patch("websockify.websockifyserver.WebSockifyServer.daemonize").start()
|
||||
patch("select.select").start().side_effect = fake_select
|
||||
server.start_server()
|
||||
|
||||
def test_socket_set_keepalive_options(self):
|
||||
|
|
@ -372,29 +438,37 @@ class WebSockifyServerTestCase(unittest.TestCase):
|
|||
keepintvl = 56
|
||||
|
||||
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
|
||||
sock = server.socket('localhost',
|
||||
tcp_keepcnt=keepcnt,
|
||||
tcp_keepidle=keepidle,
|
||||
tcp_keepintvl=keepintvl)
|
||||
sock = server.socket(
|
||||
"localhost",
|
||||
tcp_keepcnt=keepcnt,
|
||||
tcp_keepidle=keepidle,
|
||||
tcp_keepintvl=keepintvl,
|
||||
)
|
||||
|
||||
if hasattr(socket, 'TCP_KEEPCNT'):
|
||||
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
|
||||
socket.TCP_KEEPCNT), keepcnt)
|
||||
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
|
||||
socket.TCP_KEEPIDLE), keepidle)
|
||||
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
|
||||
socket.TCP_KEEPINTVL), keepintvl)
|
||||
if hasattr(socket, "TCP_KEEPCNT"):
|
||||
self.assertEqual(
|
||||
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt
|
||||
)
|
||||
self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE), keepidle)
|
||||
self.assertEqual(
|
||||
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl
|
||||
)
|
||||
|
||||
sock = server.socket('localhost',
|
||||
tcp_keepalive=False,
|
||||
tcp_keepcnt=keepcnt,
|
||||
tcp_keepidle=keepidle,
|
||||
tcp_keepintvl=keepintvl)
|
||||
sock = server.socket(
|
||||
"localhost",
|
||||
tcp_keepalive=False,
|
||||
tcp_keepcnt=keepcnt,
|
||||
tcp_keepidle=keepidle,
|
||||
tcp_keepintvl=keepintvl,
|
||||
)
|
||||
|
||||
if hasattr(socket, 'TCP_KEEPCNT'):
|
||||
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
|
||||
socket.TCP_KEEPCNT), keepcnt)
|
||||
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
|
||||
socket.TCP_KEEPIDLE), keepidle)
|
||||
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
|
||||
socket.TCP_KEEPINTVL), keepintvl)
|
||||
if hasattr(socket, "TCP_KEEPCNT"):
|
||||
self.assertNotEqual(
|
||||
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt
|
||||
)
|
||||
self.assertNotEqual(
|
||||
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE), keepidle
|
||||
)
|
||||
self.assertNotEqual(
|
||||
sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
run
|
||||
run
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import websockify
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
websockify.websocketproxy.websockify_init()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
class BasePlugin():
|
||||
class BasePlugin:
|
||||
def __init__(self, src=None):
|
||||
self.source = src
|
||||
|
||||
|
|
@ -7,7 +7,9 @@ class BasePlugin():
|
|||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None):
|
||||
def __init__(
|
||||
self, log_msg=None, response_code=403, response_headers={}, response_msg=None
|
||||
):
|
||||
self.code = response_code
|
||||
self.headers = response_headers
|
||||
self.msg = response_msg
|
||||
|
|
@ -15,7 +17,7 @@ class AuthenticationError(Exception):
|
|||
if log_msg is None:
|
||||
log_msg = response_msg
|
||||
|
||||
super().__init__('%s %s' % (self.code, log_msg))
|
||||
super().__init__("%s %s" % (self.code, log_msg))
|
||||
|
||||
|
||||
class InvalidOriginError(AuthenticationError):
|
||||
|
|
@ -24,12 +26,13 @@ class InvalidOriginError(AuthenticationError):
|
|||
self.actual_origin = actual
|
||||
|
||||
super().__init__(
|
||||
response_msg='Invalid Origin',
|
||||
response_msg="Invalid Origin",
|
||||
log_msg="Invalid Origin Header: Expected one of "
|
||||
"%s, got '%s'" % (expected, actual))
|
||||
"%s, got '%s'" % (expected, actual),
|
||||
)
|
||||
|
||||
|
||||
class BasicHTTPAuth():
|
||||
class BasicHTTPAuth:
|
||||
"""Verifies Basic Auth headers. Specify src as username:password"""
|
||||
|
||||
def __init__(self, src=None):
|
||||
|
|
@ -37,9 +40,10 @@ class BasicHTTPAuth():
|
|||
|
||||
def authenticate(self, headers, target_host, target_port):
|
||||
import base64
|
||||
auth_header = headers.get('Authorization')
|
||||
|
||||
auth_header = headers.get("Authorization")
|
||||
if auth_header:
|
||||
if not auth_header.startswith('Basic '):
|
||||
if not auth_header.startswith("Basic "):
|
||||
self.auth_error()
|
||||
|
||||
try:
|
||||
|
|
@ -49,11 +53,11 @@ class BasicHTTPAuth():
|
|||
|
||||
try:
|
||||
# http://stackoverflow.com/questions/7242316/what-encoding-should-i-use-for-http-basic-authentication
|
||||
user_pass_as_text = user_pass_raw.decode('ISO-8859-1')
|
||||
user_pass_as_text = user_pass_raw.decode("ISO-8859-1")
|
||||
except UnicodeDecodeError:
|
||||
self.auth_error()
|
||||
|
||||
user_pass = user_pass_as_text.split(':', 1)
|
||||
user_pass = user_pass_as_text.split(":", 1)
|
||||
if len(user_pass) != 2:
|
||||
self.auth_error()
|
||||
|
||||
|
|
@ -64,7 +68,7 @@ class BasicHTTPAuth():
|
|||
self.demand_auth()
|
||||
|
||||
def validate_creds(self, username, password):
|
||||
if '%s:%s' % (username, password) == self.src:
|
||||
if "%s:%s" % (username, password) == self.src:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
|
@ -73,10 +77,13 @@ class BasicHTTPAuth():
|
|||
raise AuthenticationError(response_code=403)
|
||||
|
||||
def demand_auth(self):
|
||||
raise AuthenticationError(response_code=401,
|
||||
response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'})
|
||||
raise AuthenticationError(
|
||||
response_code=401,
|
||||
response_headers={"WWW-Authenticate": 'Basic realm="Websockify"'},
|
||||
)
|
||||
|
||||
class ExpectOrigin():
|
||||
|
||||
class ExpectOrigin:
|
||||
def __init__(self, src=None):
|
||||
if src is None:
|
||||
self.source = []
|
||||
|
|
@ -84,11 +91,12 @@ class ExpectOrigin():
|
|||
self.source = src.split()
|
||||
|
||||
def authenticate(self, headers, target_host, target_port):
|
||||
origin = headers.get('Origin', None)
|
||||
origin = headers.get("Origin", None)
|
||||
if origin is None or origin not in self.source:
|
||||
raise InvalidOriginError(expected=self.source, actual=origin)
|
||||
|
||||
class ClientCertCNAuth():
|
||||
|
||||
class ClientCertCNAuth:
|
||||
"""Verifies client by SSL certificate. Specify src as whitespace separated list of common names."""
|
||||
|
||||
def __init__(self, src=None):
|
||||
|
|
@ -98,5 +106,5 @@ class ClientCertCNAuth():
|
|||
self.source = src.split()
|
||||
|
||||
def authenticate(self, headers, target_host, target_port):
|
||||
if headers.get('SSL_CLIENT_S_DN_CN', None) not in self.source:
|
||||
if headers.get("SSL_CLIENT_S_DN_CN", None) not in self.source:
|
||||
raise AuthenticationError(response_code=403)
|
||||
|
|
|
|||
|
|
@ -7,23 +7,26 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
|
|||
as defined by RFC 5424.
|
||||
"""
|
||||
|
||||
_legacy_head_fmt = '<{pri}>{ident}[{pid}]: '
|
||||
_rfc5424_head_fmt = '<{pri}>1 {timestamp} {hostname} {ident} {pid} - - '
|
||||
_legacy_head_fmt = "<{pri}>{ident}[{pid}]: "
|
||||
_rfc5424_head_fmt = "<{pri}>1 {timestamp} {hostname} {ident} {pid} - - "
|
||||
_head_fmt = _rfc5424_head_fmt
|
||||
_legacy = False
|
||||
_timestamp_fmt = '%Y-%m-%dT%H:%M:%SZ'
|
||||
_timestamp_fmt = "%Y-%m-%dT%H:%M:%SZ"
|
||||
_max_hostname = 255
|
||||
_max_ident = 24 #safer for old daemons
|
||||
_max_ident = 24 # safer for old daemons
|
||||
_send_length = False
|
||||
_tail = '\n'
|
||||
|
||||
_tail = "\n"
|
||||
|
||||
ident = None
|
||||
|
||||
|
||||
def __init__(self, address=('localhost', handlers.SYSLOG_UDP_PORT),
|
||||
facility=handlers.SysLogHandler.LOG_USER,
|
||||
socktype=None, ident=None, legacy=False):
|
||||
def __init__(
|
||||
self,
|
||||
address=("localhost", handlers.SYSLOG_UDP_PORT),
|
||||
facility=handlers.SysLogHandler.LOG_USER,
|
||||
socktype=None,
|
||||
ident=None,
|
||||
legacy=False,
|
||||
):
|
||||
"""
|
||||
Initialize a handler.
|
||||
|
||||
|
|
@ -46,7 +49,6 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
|
|||
|
||||
super().__init__(address, facility, socktype)
|
||||
|
||||
|
||||
def emit(self, record):
|
||||
"""
|
||||
Emit a record.
|
||||
|
|
@ -57,46 +59,44 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
|
|||
|
||||
try:
|
||||
# Gather info.
|
||||
text = self.format(record).replace(self._tail, ' ')
|
||||
if not text: # nothing to log
|
||||
text = self.format(record).replace(self._tail, " ")
|
||||
if not text: # nothing to log
|
||||
return
|
||||
|
||||
pri = self.encodePriority(self.facility,
|
||||
self.mapPriority(record.levelname))
|
||||
pri = self.encodePriority(self.facility, self.mapPriority(record.levelname))
|
||||
|
||||
timestamp = time.strftime(self._timestamp_fmt, time.gmtime());
|
||||
|
||||
hostname = socket.gethostname()[:self._max_hostname]
|
||||
timestamp = time.strftime(self._timestamp_fmt, time.gmtime())
|
||||
hostname = socket.gethostname()[: self._max_hostname]
|
||||
|
||||
if self.ident:
|
||||
ident = self.ident[:self._max_ident]
|
||||
ident = self.ident[: self._max_ident]
|
||||
else:
|
||||
ident = ''
|
||||
ident = ""
|
||||
|
||||
pid = os.getpid() # shouldn't need truncation
|
||||
pid = os.getpid() # shouldn't need truncation
|
||||
|
||||
# Format the header.
|
||||
head = {
|
||||
'pri': pri,
|
||||
'timestamp': timestamp,
|
||||
'hostname': hostname,
|
||||
'ident': ident,
|
||||
'pid': pid,
|
||||
"pri": pri,
|
||||
"timestamp": timestamp,
|
||||
"hostname": hostname,
|
||||
"ident": ident,
|
||||
"pid": pid,
|
||||
}
|
||||
msg = self._head_fmt.format(**head).encode('ascii', 'ignore')
|
||||
msg = self._head_fmt.format(**head).encode("ascii", "ignore")
|
||||
|
||||
# Encode text as plain ASCII if possible, else use UTF-8 with BOM.
|
||||
try:
|
||||
msg += text.encode('ascii')
|
||||
msg += text.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
msg += text.encode('utf-8-sig')
|
||||
msg += text.encode("utf-8-sig")
|
||||
|
||||
# Add length or tail character, if necessary.
|
||||
if self.socktype != socket.SOCK_DGRAM:
|
||||
if self._send_length:
|
||||
msg = ('%d ' % len(msg)).encode('ascii') + msg
|
||||
msg = ("%d " % len(msg)).encode("ascii") + msg
|
||||
else:
|
||||
msg += self._tail.encode('ascii')
|
||||
msg += self._tail.encode("ascii")
|
||||
|
||||
# Send the message.
|
||||
if self.unixsocket:
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ logger = logging.getLogger(__name__)
|
|||
_SOURCE_SPLIT_REGEX = re.compile(
|
||||
r'(?<=^)"([^"]+)"(?=:|$)'
|
||||
r'|(?<=:)"([^"]+)"(?=:|$)'
|
||||
r'|(?<=^)([^:]*)(?=:|$)'
|
||||
r'|(?<=:)([^:]*)(?=:|$)',
|
||||
r"|(?<=^)([^:]*)(?=:|$)"
|
||||
r"|(?<=:)([^:]*)(?=:|$)",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -26,7 +26,7 @@ def parse_source_args(src):
|
|||
return [m[0] or m[1] or m[2] or m[3] for m in matches]
|
||||
|
||||
|
||||
class BasePlugin():
|
||||
class BasePlugin:
|
||||
def __init__(self, src):
|
||||
self.source = src
|
||||
|
||||
|
|
@ -44,8 +44,7 @@ class ReadOnlyTokenFile(BasePlugin):
|
|||
|
||||
def _load_targets(self):
|
||||
if os.path.isdir(self.source):
|
||||
cfg_files = [os.path.join(self.source, f) for
|
||||
f in os.listdir(self.source)]
|
||||
cfg_files = [os.path.join(self.source, f) for f in os.listdir(self.source)]
|
||||
else:
|
||||
cfg_files = [self.source]
|
||||
|
||||
|
|
@ -53,12 +52,14 @@ class ReadOnlyTokenFile(BasePlugin):
|
|||
index = 1
|
||||
for f in cfg_files:
|
||||
for line in [l.strip() for l in open(f).readlines()]:
|
||||
if line and not line.startswith('#'):
|
||||
if line and not line.startswith("#"):
|
||||
try:
|
||||
tok, target = re.split(r':\s', line)
|
||||
self._targets[tok] = target.strip().rsplit(':', 1)
|
||||
tok, target = re.split(r":\s", line)
|
||||
self._targets[tok] = target.strip().rsplit(":", 1)
|
||||
except ValueError:
|
||||
logger.error("Syntax error in %s on line %d" % (self.source, index))
|
||||
logger.error(
|
||||
"Syntax error in %s on line %d" % (self.source, index)
|
||||
)
|
||||
index += 1
|
||||
|
||||
def lookup(self, token):
|
||||
|
|
@ -83,6 +84,7 @@ class TokenFile(ReadOnlyTokenFile):
|
|||
|
||||
return super().lookup(token)
|
||||
|
||||
|
||||
class TokenFileName(BasePlugin):
|
||||
# source is a directory
|
||||
# token is filename
|
||||
|
|
@ -91,12 +93,12 @@ class TokenFileName(BasePlugin):
|
|||
super().__init__(src)
|
||||
if not os.path.isdir(src):
|
||||
raise Exception("TokenFileName plugin requires a directory")
|
||||
|
||||
|
||||
def lookup(self, token):
|
||||
token = os.path.basename(token)
|
||||
path = os.path.join(self.source, token)
|
||||
if os.path.exists(path):
|
||||
return open(path).read().strip().split(':')
|
||||
return open(path).read().strip().split(":")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
|
@ -109,9 +111,9 @@ class BaseTokenAPI(BasePlugin):
|
|||
# in this file can be used w/o unnecessary dependencies
|
||||
|
||||
def process_result(self, resp):
|
||||
host, port = resp.text.split(':')
|
||||
port = port.encode('ascii','ignore')
|
||||
return [ host, port ]
|
||||
host, port = resp.text.split(":")
|
||||
port = port.encode("ascii", "ignore")
|
||||
return [host, port]
|
||||
|
||||
def lookup(self, token):
|
||||
import requests
|
||||
|
|
@ -130,7 +132,7 @@ class JSONTokenApi(BaseTokenAPI):
|
|||
|
||||
def process_result(self, resp):
|
||||
resp_json = resp.json()
|
||||
return (resp_json['host'], resp_json['port'])
|
||||
return (resp_json["host"], resp_json["port"])
|
||||
|
||||
|
||||
class JWTTokenApi(BasePlugin):
|
||||
|
|
@ -145,7 +147,7 @@ class JWTTokenApi(BasePlugin):
|
|||
key = jwk.JWK()
|
||||
|
||||
try:
|
||||
with open(self.source, 'rb') as key_file:
|
||||
with open(self.source, "rb") as key_file:
|
||||
key_data = key_file.read()
|
||||
except Exception as e:
|
||||
logger.error("Error loading key file: %s" % str(e))
|
||||
|
|
@ -155,39 +157,41 @@ class JWTTokenApi(BasePlugin):
|
|||
key.import_from_pem(key_data)
|
||||
except:
|
||||
try:
|
||||
key.import_key(k=key_data.decode('utf-8'),kty='oct')
|
||||
key.import_key(k=key_data.decode("utf-8"), kty="oct")
|
||||
except:
|
||||
logger.error('Failed to correctly parse key data!')
|
||||
logger.error("Failed to correctly parse key data!")
|
||||
return None
|
||||
|
||||
try:
|
||||
token = jwt.JWT(key=key, jwt=token)
|
||||
parsed_header = json.loads(token.header)
|
||||
|
||||
if 'enc' in parsed_header:
|
||||
if "enc" in parsed_header:
|
||||
# Token is encrypted, so we need to decrypt by passing the claims to a new instance
|
||||
token = jwt.JWT(key=key, jwt=token.claims)
|
||||
|
||||
parsed = json.loads(token.claims)
|
||||
|
||||
if 'nbf' in parsed:
|
||||
if "nbf" in parsed:
|
||||
# Not Before is present, so we need to check it
|
||||
if time.time() < parsed['nbf']:
|
||||
logger.warning('Token can not be used yet!')
|
||||
if time.time() < parsed["nbf"]:
|
||||
logger.warning("Token can not be used yet!")
|
||||
return None
|
||||
|
||||
if 'exp' in parsed:
|
||||
if "exp" in parsed:
|
||||
# Expiration time is present, so we need to check it
|
||||
if time.time() > parsed['exp']:
|
||||
logger.warning('Token has expired!')
|
||||
if time.time() > parsed["exp"]:
|
||||
logger.warning("Token has expired!")
|
||||
return None
|
||||
|
||||
return (parsed['host'], parsed['port'])
|
||||
return (parsed["host"], parsed["port"])
|
||||
except Exception as e:
|
||||
logger.error("Failed to parse token: %s" % str(e))
|
||||
return None
|
||||
except ImportError:
|
||||
logger.error("package jwcrypto not found, are you sure you've installed it correctly?")
|
||||
logger.error(
|
||||
"package jwcrypto not found, are you sure you've installed it correctly?"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -251,6 +255,7 @@ class TokenRedis(BasePlugin):
|
|||
|
||||
pip install redis
|
||||
"""
|
||||
|
||||
def __init__(self, src):
|
||||
try:
|
||||
import redis
|
||||
|
|
@ -285,7 +290,9 @@ class TokenRedis(BasePlugin):
|
|||
if not self._password:
|
||||
self._password = None
|
||||
elif len(fields) == 5:
|
||||
self._server, self._port, self._db, self._password, self._namespace = fields
|
||||
self._server, self._port, self._db, self._password, self._namespace = (
|
||||
fields
|
||||
)
|
||||
if not self._port:
|
||||
self._port = 6379
|
||||
if not self._db:
|
||||
|
|
@ -301,24 +308,30 @@ class TokenRedis(BasePlugin):
|
|||
if self._namespace:
|
||||
self._namespace += ":"
|
||||
|
||||
logger.info("TokenRedis backend initialized (%s:%s)" %
|
||||
(self._server, self._port))
|
||||
logger.info(
|
||||
"TokenRedis backend initialized (%s:%s)" % (self._server, self._port)
|
||||
)
|
||||
except ValueError:
|
||||
logger.error("The provided --token-source='%s' is not in the "
|
||||
"expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]" %
|
||||
src)
|
||||
logger.error(
|
||||
"The provided --token-source='%s' is not in the "
|
||||
"expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]"
|
||||
% src
|
||||
)
|
||||
sys.exit()
|
||||
|
||||
def lookup(self, token):
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
logger.error("package redis not found, are you sure you've installed them correctly?")
|
||||
logger.error(
|
||||
"package redis not found, are you sure you've installed them correctly?"
|
||||
)
|
||||
sys.exit()
|
||||
|
||||
logger.info("resolving token '%s'" % token)
|
||||
client = redis.Redis(host=self._server, port=self._port,
|
||||
db=self._db, password=self._password)
|
||||
client = redis.Redis(
|
||||
host=self._server, port=self._port, db=self._db, password=self._password
|
||||
)
|
||||
stuff = client.get(self._namespace + token)
|
||||
if stuff is None:
|
||||
return None
|
||||
|
|
@ -330,14 +343,14 @@ class TokenRedis(BasePlugin):
|
|||
combo = json.loads(responseStr)
|
||||
host, port = combo["host"].split(":")
|
||||
except ValueError:
|
||||
logger.error("Unable to decode JSON token: %s" %
|
||||
responseStr)
|
||||
logger.error("Unable to decode JSON token: %s" % responseStr)
|
||||
return None
|
||||
except KeyError:
|
||||
logger.error("Unable to find 'host' key in JSON token: %s" %
|
||||
responseStr)
|
||||
logger.error(
|
||||
"Unable to find 'host' key in JSON token: %s" % responseStr
|
||||
)
|
||||
return None
|
||||
elif re.match(r'\S+:\S+', responseStr):
|
||||
elif re.match(r"\S+:\S+", responseStr):
|
||||
host, port = responseStr.split(":")
|
||||
else:
|
||||
logger.error("Unable to parse token: %s" % responseStr)
|
||||
|
|
@ -368,7 +381,7 @@ class UnixDomainSocketDirectory(BasePlugin):
|
|||
if not stat.S_ISSOCK(os.stat(uds_path).st_mode):
|
||||
return None
|
||||
|
||||
return [ 'unix_socket', uds_path ]
|
||||
return ["unix_socket", uds_path]
|
||||
except Exception as e:
|
||||
logger.error("Error finding unix domain socket: %s" % str(e))
|
||||
return None
|
||||
logger.error("Error finding unix domain socket: %s" % str(e))
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
'''
|
||||
"""
|
||||
Python WebSocket library
|
||||
Copyright 2011 Joel Martin
|
||||
Copyright 2016 Pierre Ossman
|
||||
|
|
@ -10,7 +10,7 @@ Supports following protocol versions:
|
|||
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-07
|
||||
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10
|
||||
- http://tools.ietf.org/html/rfc6455
|
||||
'''
|
||||
"""
|
||||
|
||||
import sys
|
||||
import array
|
||||
|
|
@ -28,14 +28,19 @@ try:
|
|||
import numpy
|
||||
except ImportError:
|
||||
import warnings
|
||||
|
||||
warnings.warn("no 'numpy' module, HyBi protocol will be slower")
|
||||
numpy = None
|
||||
|
||||
|
||||
class WebSocketWantReadError(ssl.SSLWantReadError):
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketWantWriteError(ssl.SSLWantWriteError):
|
||||
pass
|
||||
|
||||
|
||||
class WebSocket:
|
||||
"""WebSocket protocol socket like class.
|
||||
|
||||
|
|
@ -73,11 +78,11 @@ class WebSocket:
|
|||
|
||||
self._state = "new"
|
||||
|
||||
self._partial_msg = b''
|
||||
self._partial_msg = b""
|
||||
|
||||
self._recv_buffer = b''
|
||||
self._recv_buffer = b""
|
||||
self._recv_queue = []
|
||||
self._send_buffer = b''
|
||||
self._send_buffer = b""
|
||||
|
||||
self._previous_sendmsg = None
|
||||
|
||||
|
|
@ -91,16 +96,22 @@ class WebSocket:
|
|||
|
||||
def __getattr__(self, name):
|
||||
# These methods are just redirected to the underlying socket
|
||||
if name in ["fileno",
|
||||
"getpeername", "getsockname",
|
||||
"getsockopt", "setsockopt",
|
||||
"gettimeout", "settimeout",
|
||||
"setblocking"]:
|
||||
if name in [
|
||||
"fileno",
|
||||
"getpeername",
|
||||
"getsockname",
|
||||
"getsockopt",
|
||||
"setsockopt",
|
||||
"gettimeout",
|
||||
"settimeout",
|
||||
"setblocking",
|
||||
]:
|
||||
assert self.socket is not None
|
||||
return getattr(self.socket, name)
|
||||
else:
|
||||
raise AttributeError("%s instance has no attribute '%s'" %
|
||||
(self.__class__.__name__, name))
|
||||
raise AttributeError(
|
||||
"%s instance has no attribute '%s'" % (self.__class__.__name__, name)
|
||||
)
|
||||
|
||||
def connect(self, uri, origin=None, protocols=[]):
|
||||
"""Establishes a new connection to a WebSocket server.
|
||||
|
|
@ -118,8 +129,7 @@ class WebSocket:
|
|||
connect() must retain the same arguments.
|
||||
"""
|
||||
|
||||
self.client = True;
|
||||
|
||||
self.client = True
|
||||
uri = urlparse(uri)
|
||||
|
||||
port = uri.port
|
||||
|
|
@ -140,8 +150,9 @@ class WebSocket:
|
|||
|
||||
if uri.scheme in ("wss", "https"):
|
||||
context = ssl.create_default_context()
|
||||
self.socket = context.wrap_socket(self.socket,
|
||||
server_hostname=uri.hostname)
|
||||
self.socket = context.wrap_socket(
|
||||
self.socket, server_hostname=uri.hostname
|
||||
)
|
||||
self._state = "ssl_handshake"
|
||||
else:
|
||||
self._state = "headers"
|
||||
|
|
@ -151,7 +162,7 @@ class WebSocket:
|
|||
self._state = "headers"
|
||||
|
||||
if self._state == "headers":
|
||||
self._key = ''
|
||||
self._key = ""
|
||||
for i in range(16):
|
||||
self._key += chr(random.randrange(256))
|
||||
self._key = b64encode(self._key.encode("latin-1")).decode("ascii")
|
||||
|
|
@ -184,10 +195,10 @@ class WebSocket:
|
|||
if not self._recv():
|
||||
raise Exception("Socket closed unexpectedly")
|
||||
|
||||
if self._recv_buffer.find(b'\r\n\r\n') == -1:
|
||||
if self._recv_buffer.find(b"\r\n\r\n") == -1:
|
||||
raise WebSocketWantReadError
|
||||
|
||||
(request, self._recv_buffer) = self._recv_buffer.split(b'\r\n', 1)
|
||||
(request, self._recv_buffer) = self._recv_buffer.split(b"\r\n", 1)
|
||||
request = request.decode("latin-1")
|
||||
|
||||
words = request.split()
|
||||
|
|
@ -196,17 +207,17 @@ class WebSocket:
|
|||
if words[1] != "101":
|
||||
raise Exception("WebSocket request denied: %s" % " ".join(words[1:]))
|
||||
|
||||
(headers, self._recv_buffer) = self._recv_buffer.split(b'\r\n\r\n', 1)
|
||||
headers = headers.decode('latin-1') + '\r\n'
|
||||
(headers, self._recv_buffer) = self._recv_buffer.split(b"\r\n\r\n", 1)
|
||||
headers = headers.decode("latin-1") + "\r\n"
|
||||
headers = email.message_from_string(headers)
|
||||
|
||||
if headers.get("Upgrade", "").lower() != "websocket":
|
||||
print(type(headers))
|
||||
raise Exception("Missing or incorrect upgrade header")
|
||||
|
||||
accept = headers.get('Sec-WebSocket-Accept')
|
||||
accept = headers.get("Sec-WebSocket-Accept")
|
||||
if accept is None:
|
||||
raise Exception("Missing Sec-WebSocket-Accept header");
|
||||
raise Exception("Missing Sec-WebSocket-Accept header")
|
||||
|
||||
expected = sha1((self._key + self.GUID).encode("ascii")).digest()
|
||||
expected = b64encode(expected).decode("ascii")
|
||||
|
|
@ -214,9 +225,9 @@ class WebSocket:
|
|||
del self._key
|
||||
|
||||
if accept != expected:
|
||||
raise Exception("Invalid Sec-WebSocket-Accept header");
|
||||
raise Exception("Invalid Sec-WebSocket-Accept header")
|
||||
|
||||
self.protocol = headers.get('Sec-WebSocket-Protocol')
|
||||
self.protocol = headers.get("Sec-WebSocket-Protocol")
|
||||
if len(protocols) == 0:
|
||||
if self.protocol is not None:
|
||||
raise Exception("Unexpected Sec-WebSocket-Protocol header")
|
||||
|
|
@ -256,34 +267,34 @@ class WebSocket:
|
|||
if headers.get("upgrade", "").lower() != "websocket":
|
||||
raise Exception("Missing or incorrect upgrade header")
|
||||
|
||||
ver = headers.get('Sec-WebSocket-Version')
|
||||
ver = headers.get("Sec-WebSocket-Version")
|
||||
if ver is None:
|
||||
raise Exception("Missing Sec-WebSocket-Version header");
|
||||
raise Exception("Missing Sec-WebSocket-Version header")
|
||||
|
||||
# HyBi-07 report version 7
|
||||
# HyBi-08 - HyBi-12 report version 8
|
||||
# HyBi-13 reports version 13
|
||||
if ver in ['7', '8', '13']:
|
||||
if ver in ["7", "8", "13"]:
|
||||
self.version = "hybi-%02d" % int(ver)
|
||||
else:
|
||||
raise Exception("Unsupported protocol version %s" % ver)
|
||||
|
||||
key = headers.get('Sec-WebSocket-Key')
|
||||
key = headers.get("Sec-WebSocket-Key")
|
||||
if key is None:
|
||||
raise Exception("Missing Sec-WebSocket-Key header");
|
||||
raise Exception("Missing Sec-WebSocket-Key header")
|
||||
|
||||
# Generate the hash value for the accept header
|
||||
accept = sha1((key + self.GUID).encode("ascii")).digest()
|
||||
accept = b64encode(accept).decode("ascii")
|
||||
|
||||
self.protocol = ''
|
||||
protocols = headers.get('Sec-WebSocket-Protocol', '').split(',')
|
||||
self.protocol = ""
|
||||
protocols = headers.get("Sec-WebSocket-Protocol", "").split(",")
|
||||
if protocols:
|
||||
self.protocol = self.select_subprotocol(protocols)
|
||||
# We are required to choose one of the protocols
|
||||
# presented by the client
|
||||
if self.protocol not in protocols:
|
||||
raise Exception('Invalid protocol selected')
|
||||
raise Exception("Invalid protocol selected")
|
||||
|
||||
self.send_response(101, "Switching Protocols")
|
||||
self.send_header("Upgrade", "websocket")
|
||||
|
|
@ -461,7 +472,7 @@ class WebSocket:
|
|||
def send_request(self, type, path):
|
||||
self._queue_str("%s %s HTTP/1.1\r\n" % (type.upper(), path))
|
||||
|
||||
def ping(self, data=b''):
|
||||
def ping(self, data=b""):
|
||||
"""Write a ping message to the WebSocket
|
||||
|
||||
WebSocketWantWriteError can be raised if there is insufficient
|
||||
|
|
@ -486,7 +497,7 @@ class WebSocket:
|
|||
self._previous_sendmsg = data
|
||||
raise
|
||||
|
||||
def pong(self, data=b''):
|
||||
def pong(self, data=b""):
|
||||
"""Write a pong message to the WebSocket
|
||||
|
||||
WebSocketWantWriteError can be raised if there is insufficient
|
||||
|
|
@ -540,7 +551,7 @@ class WebSocket:
|
|||
|
||||
self._sent_close = True
|
||||
|
||||
msg = b''
|
||||
msg = b""
|
||||
if code is not None:
|
||||
msg += struct.pack(">H", code)
|
||||
if reason is not None:
|
||||
|
|
@ -602,7 +613,7 @@ class WebSocket:
|
|||
frame = self._decode_hybi(self._recv_buffer)
|
||||
if frame is None:
|
||||
break
|
||||
self._recv_buffer = self._recv_buffer[frame['length']:]
|
||||
self._recv_buffer = self._recv_buffer[frame["length"] :]
|
||||
self._recv_queue.append(frame)
|
||||
|
||||
return True
|
||||
|
|
@ -612,29 +623,39 @@ class WebSocket:
|
|||
while self._recv_queue:
|
||||
frame = self._recv_queue.pop(0)
|
||||
|
||||
if not self.client and not frame['masked']:
|
||||
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked")
|
||||
if not self.client and not frame["masked"]:
|
||||
self.shutdown(
|
||||
socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked"
|
||||
)
|
||||
continue
|
||||
if self.client and frame['masked']:
|
||||
if self.client and frame["masked"]:
|
||||
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame masked")
|
||||
continue
|
||||
|
||||
if frame["opcode"] == 0x0:
|
||||
if not self._partial_msg:
|
||||
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected continuation frame")
|
||||
self.shutdown(
|
||||
socket.SHUT_RDWR,
|
||||
1002,
|
||||
"Procotol error: Unexpected continuation frame",
|
||||
)
|
||||
continue
|
||||
|
||||
self._partial_msg += frame["payload"]
|
||||
|
||||
if frame["fin"]:
|
||||
msg = self._partial_msg
|
||||
self._partial_msg = b''
|
||||
self._partial_msg = b""
|
||||
return msg
|
||||
elif frame["opcode"] == 0x1:
|
||||
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported")
|
||||
self.shutdown(
|
||||
socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported"
|
||||
)
|
||||
elif frame["opcode"] == 0x2:
|
||||
if self._partial_msg:
|
||||
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame")
|
||||
self.shutdown(
|
||||
socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame"
|
||||
)
|
||||
continue
|
||||
|
||||
if frame["fin"]:
|
||||
|
|
@ -652,7 +673,9 @@ class WebSocket:
|
|||
return None
|
||||
|
||||
if not frame["fin"]:
|
||||
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close")
|
||||
self.shutdown(
|
||||
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close"
|
||||
)
|
||||
continue
|
||||
|
||||
code = None
|
||||
|
|
@ -664,7 +687,11 @@ class WebSocket:
|
|||
try:
|
||||
reason = reason.decode("UTF-8")
|
||||
except UnicodeDecodeError:
|
||||
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Invalid UTF-8 in close")
|
||||
self.shutdown(
|
||||
socket.SHUT_RDWR,
|
||||
1002,
|
||||
"Procotol error: Invalid UTF-8 in close",
|
||||
)
|
||||
continue
|
||||
|
||||
if code is None:
|
||||
|
|
@ -679,18 +706,26 @@ class WebSocket:
|
|||
return None
|
||||
elif frame["opcode"] == 0x9:
|
||||
if not frame["fin"]:
|
||||
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping")
|
||||
self.shutdown(
|
||||
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping"
|
||||
)
|
||||
continue
|
||||
|
||||
self.handle_ping(frame["payload"])
|
||||
elif frame["opcode"] == 0xA:
|
||||
if not frame["fin"]:
|
||||
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong")
|
||||
self.shutdown(
|
||||
socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong"
|
||||
)
|
||||
continue
|
||||
|
||||
self.handle_pong(frame["payload"])
|
||||
else:
|
||||
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Unknown opcode 0x%02x" % frame["opcode"])
|
||||
self.shutdown(
|
||||
socket.SHUT_RDWR,
|
||||
1003,
|
||||
"Unsupported: Unknown opcode 0x%02x" % frame["opcode"],
|
||||
)
|
||||
|
||||
raise WebSocketWantReadError
|
||||
|
||||
|
|
@ -731,7 +766,7 @@ class WebSocket:
|
|||
def _sendmsg(self, opcode, msg):
|
||||
# Sends a standard data message
|
||||
if self.client:
|
||||
mask = b''
|
||||
mask = b""
|
||||
for i in range(4):
|
||||
mask += random.randrange(256).to_bytes()
|
||||
frame = self._encode_hybi(opcode, msg, mask)
|
||||
|
|
@ -755,35 +790,36 @@ class WebSocket:
|
|||
plen = len(buf)
|
||||
pstart = 0
|
||||
pend = plen
|
||||
b = c = b''
|
||||
b = c = b""
|
||||
if plen >= 4:
|
||||
dtype=numpy.dtype('<u4')
|
||||
if sys.byteorder == 'big':
|
||||
dtype = dtype.newbyteorder('>')
|
||||
dtype = numpy.dtype("<u4")
|
||||
if sys.byteorder == "big":
|
||||
dtype = dtype.newbyteorder(">")
|
||||
mask = numpy.frombuffer(mask, dtype, count=1)
|
||||
data = numpy.frombuffer(buf, dtype, count=int(plen / 4))
|
||||
#b = numpy.bitwise_xor(data, mask).data
|
||||
# b = numpy.bitwise_xor(data, mask).data
|
||||
b = numpy.bitwise_xor(data, mask).tobytes()
|
||||
|
||||
if plen % 4:
|
||||
dtype=numpy.dtype('B')
|
||||
if sys.byteorder == 'big':
|
||||
dtype = dtype.newbyteorder('>')
|
||||
dtype = numpy.dtype("B")
|
||||
if sys.byteorder == "big":
|
||||
dtype = dtype.newbyteorder(">")
|
||||
mask = numpy.frombuffer(mask, dtype, count=(plen % 4))
|
||||
data = numpy.frombuffer(buf, dtype,
|
||||
offset=plen - (plen % 4), count=(plen % 4))
|
||||
data = numpy.frombuffer(
|
||||
buf, dtype, offset=plen - (plen % 4), count=(plen % 4)
|
||||
)
|
||||
c = numpy.bitwise_xor(data, mask).tobytes()
|
||||
return b + c
|
||||
else:
|
||||
# Slower fallback
|
||||
data = array.array('B')
|
||||
data = array.array("B")
|
||||
data.frombytes(buf)
|
||||
for i in range(len(data)):
|
||||
data[i] ^= mask[i % 4]
|
||||
return data.tobytes()
|
||||
|
||||
def _encode_hybi(self, opcode, buf, mask_key=None, fin=True):
|
||||
""" Encode a HyBi style WebSocket frame.
|
||||
"""Encode a HyBi style WebSocket frame.
|
||||
Optional opcode:
|
||||
0x0 - continuation
|
||||
0x1 - text frame
|
||||
|
|
@ -793,7 +829,7 @@ class WebSocket:
|
|||
0xA - pong
|
||||
"""
|
||||
|
||||
b1 = opcode & 0x0f
|
||||
b1 = opcode & 0x0F
|
||||
if fin:
|
||||
b1 |= 0x80
|
||||
|
||||
|
|
@ -804,11 +840,11 @@ class WebSocket:
|
|||
|
||||
payload_len = len(buf)
|
||||
if payload_len <= 125:
|
||||
header = struct.pack('>BB', b1, payload_len | mask_bit)
|
||||
header = struct.pack(">BB", b1, payload_len | mask_bit)
|
||||
elif payload_len > 125 and payload_len < 65536:
|
||||
header = struct.pack('>BBH', b1, 126 | mask_bit, payload_len)
|
||||
header = struct.pack(">BBH", b1, 126 | mask_bit, payload_len)
|
||||
elif payload_len >= 65536:
|
||||
header = struct.pack('>BBQ', b1, 127 | mask_bit, payload_len)
|
||||
header = struct.pack(">BBQ", b1, 127 | mask_bit, payload_len)
|
||||
|
||||
if mask_key is not None:
|
||||
return header + mask_key + buf
|
||||
|
|
@ -816,7 +852,7 @@ class WebSocket:
|
|||
return header + buf
|
||||
|
||||
def _decode_hybi(self, buf):
|
||||
""" Decode HyBi style WebSocket packets.
|
||||
"""Decode HyBi style WebSocket packets.
|
||||
Returns:
|
||||
{'fin' : boolean,
|
||||
'opcode' : number,
|
||||
|
|
@ -825,11 +861,7 @@ class WebSocket:
|
|||
'payload' : decoded_buffer}
|
||||
"""
|
||||
|
||||
f = {'fin' : 0,
|
||||
'opcode' : 0,
|
||||
'masked' : False,
|
||||
'length' : 0,
|
||||
'payload' : None}
|
||||
f = {"fin": 0, "opcode": 0, "masked": False, "length": 0, "payload": None}
|
||||
|
||||
blen = len(buf)
|
||||
hlen = 2
|
||||
|
|
@ -838,39 +870,38 @@ class WebSocket:
|
|||
return None
|
||||
|
||||
b1, b2 = struct.unpack(">BB", buf[:2])
|
||||
f['opcode'] = b1 & 0x0f
|
||||
f['fin'] = not not (b1 & 0x80)
|
||||
f['masked'] = not not (b2 & 0x80)
|
||||
f["opcode"] = b1 & 0x0F
|
||||
f["fin"] = not not (b1 & 0x80)
|
||||
f["masked"] = not not (b2 & 0x80)
|
||||
|
||||
if f['masked']:
|
||||
if f["masked"]:
|
||||
hlen += 4
|
||||
if blen < hlen:
|
||||
return None
|
||||
|
||||
length = b2 & 0x7f
|
||||
length = b2 & 0x7F
|
||||
|
||||
if length == 126:
|
||||
hlen += 2
|
||||
if blen < hlen:
|
||||
return None
|
||||
length, = struct.unpack('>H', buf[2:4])
|
||||
(length,) = struct.unpack(">H", buf[2:4])
|
||||
elif length == 127:
|
||||
hlen += 8
|
||||
if blen < hlen:
|
||||
return None
|
||||
length, = struct.unpack('>Q', buf[2:10])
|
||||
(length,) = struct.unpack(">Q", buf[2:10])
|
||||
|
||||
f['length'] = hlen + length
|
||||
f["length"] = hlen + length
|
||||
|
||||
if blen < f['length']:
|
||||
if blen < f["length"]:
|
||||
return None
|
||||
|
||||
if f['masked']:
|
||||
if f["masked"]:
|
||||
# unmask payload
|
||||
mask_key = buf[hlen-4:hlen]
|
||||
f['payload'] = self._unmask(buf[hlen:(hlen+length)], mask_key)
|
||||
mask_key = buf[hlen - 4 : hlen]
|
||||
f["payload"] = self._unmask(buf[hlen : (hlen + length)], mask_key)
|
||||
else:
|
||||
f['payload'] = buf[hlen:(hlen+length)]
|
||||
f["payload"] = buf[hlen : (hlen + length)]
|
||||
|
||||
return f
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
'''
|
||||
"""
|
||||
A WebSocket to TCP socket proxy with support for "wss://" encryption.
|
||||
Copyright 2011 Joel Martin
|
||||
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
|
||||
|
|
@ -9,7 +9,7 @@ You can make a cert/key with openssl using:
|
|||
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
|
||||
as taken from http://docs.python.org/dev/library/ssl.html#certificates
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl, stat
|
||||
from socketserver import ThreadingMixIn
|
||||
|
|
@ -20,8 +20,8 @@ from websockify import websockifyserver
|
|||
from websockify import auth_plugins as auth
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler):
|
||||
|
||||
class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler):
|
||||
buffer_size = 65536
|
||||
|
||||
traffic_legend = """
|
||||
|
|
@ -38,7 +38,7 @@ Traffic Legend:
|
|||
|
||||
def send_auth_error(self, ex):
|
||||
self.send_response(ex.code, ex.msg)
|
||||
self.send_header('Content-Type', 'text/html')
|
||||
self.send_header("Content-Type", "text/html")
|
||||
for name, val in ex.headers.items():
|
||||
self.send_header(name, val)
|
||||
|
||||
|
|
@ -49,7 +49,7 @@ Traffic Legend:
|
|||
return
|
||||
|
||||
host, port = self.get_target(self.server.token_plugin)
|
||||
if host == 'unix_socket':
|
||||
if host == "unix_socket":
|
||||
self.server.unix_target = port
|
||||
|
||||
else:
|
||||
|
|
@ -62,7 +62,7 @@ Traffic Legend:
|
|||
|
||||
# clear out any existing SSL_ headers that the client might
|
||||
# have maliciously set
|
||||
ssl_headers = [ h for h in self.headers if h.startswith('SSL_') ]
|
||||
ssl_headers = [h for h in self.headers if h.startswith("SSL_")]
|
||||
for h in ssl_headers:
|
||||
del self.headers[h]
|
||||
|
||||
|
|
@ -70,19 +70,21 @@ Traffic Legend:
|
|||
# get client certificate data
|
||||
client_cert_data = self.request.getpeercert()
|
||||
# extract subject information
|
||||
client_cert_subject = client_cert_data['subject']
|
||||
client_cert_subject = client_cert_data["subject"]
|
||||
# flatten data structure
|
||||
client_cert_subject = dict([x[0] for x in client_cert_subject])
|
||||
# add common name to headers (apache +StdEnvVars style)
|
||||
self.headers['SSL_CLIENT_S_DN_CN'] = client_cert_subject['commonName']
|
||||
self.headers["SSL_CLIENT_S_DN_CN"] = client_cert_subject["commonName"]
|
||||
except (TypeError, AttributeError, KeyError):
|
||||
# not a SSL connection or client presented no certificate with valid data
|
||||
pass
|
||||
|
||||
try:
|
||||
self.server.auth_plugin.authenticate(
|
||||
headers=self.headers, target_host=self.server.target_host,
|
||||
target_port=self.server.target_port)
|
||||
headers=self.headers,
|
||||
target_host=self.server.target_host,
|
||||
target_port=self.server.target_port,
|
||||
)
|
||||
except auth.AuthenticationError:
|
||||
ex = sys.exc_info()[1]
|
||||
self.send_auth_error(ex)
|
||||
|
|
@ -96,26 +98,37 @@ Traffic Legend:
|
|||
|
||||
# Connect to the target
|
||||
if self.server.wrap_cmd:
|
||||
msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port)
|
||||
msg = "connecting to command: '%s' (port %s)" % (
|
||||
" ".join(self.server.wrap_cmd),
|
||||
self.server.target_port,
|
||||
)
|
||||
elif self.server.unix_target:
|
||||
msg = "connecting to unix socket: %s" % self.server.unix_target
|
||||
else:
|
||||
msg = "connecting to: %s:%s" % (
|
||||
self.server.target_host, self.server.target_port)
|
||||
self.server.target_host,
|
||||
self.server.target_port,
|
||||
)
|
||||
|
||||
if self.server.ssl_target:
|
||||
msg += " (using SSL)"
|
||||
self.log_message(msg)
|
||||
|
||||
try:
|
||||
tsock = websockifyserver.WebSockifyServer.socket(self.server.target_host,
|
||||
self.server.target_port,
|
||||
connect=True,
|
||||
use_ssl=self.server.ssl_target,
|
||||
unix_socket=self.server.unix_target)
|
||||
tsock = websockifyserver.WebSockifyServer.socket(
|
||||
self.server.target_host,
|
||||
self.server.target_port,
|
||||
connect=True,
|
||||
use_ssl=self.server.ssl_target,
|
||||
unix_socket=self.server.unix_target,
|
||||
)
|
||||
except Exception as e:
|
||||
self.log_message("Failed to connect to %s:%s: %s",
|
||||
self.server.target_host, self.server.target_port, e)
|
||||
self.log_message(
|
||||
"Failed to connect to %s:%s: %s",
|
||||
self.server.target_host,
|
||||
self.server.target_port,
|
||||
e,
|
||||
)
|
||||
raise self.CClose(1011, "Failed to connect to downstream server")
|
||||
|
||||
# Option unavailable when listening to unix socket
|
||||
|
|
@ -134,8 +147,11 @@ Traffic Legend:
|
|||
tsock.shutdown(socket.SHUT_RDWR)
|
||||
tsock.close()
|
||||
if self.verbose:
|
||||
self.log_message("%s:%s: Closed target",
|
||||
self.server.target_host, self.server.target_port)
|
||||
self.log_message(
|
||||
"%s:%s: Closed target",
|
||||
self.server.target_host,
|
||||
self.server.target_port,
|
||||
)
|
||||
|
||||
def get_target(self, target_plugin):
|
||||
"""
|
||||
|
|
@ -149,20 +165,20 @@ Traffic Legend:
|
|||
|
||||
if self.host_token:
|
||||
# Use hostname as token
|
||||
token = self.headers.get('Host')
|
||||
token = self.headers.get("Host")
|
||||
|
||||
# Remove port from hostname, as it'll always be the one where
|
||||
# websockify listens (unless something between the client and
|
||||
# websockify is redirecting traffic, but that's beside the point)
|
||||
if token:
|
||||
token = token.partition(':')[0]
|
||||
token = token.partition(":")[0]
|
||||
|
||||
else:
|
||||
# Extract the token parameter from url
|
||||
args = parse_qs(urlparse(self.path)[4]) # 4 is the query from url
|
||||
args = parse_qs(urlparse(self.path)[4]) # 4 is the query from url
|
||||
|
||||
if 'token' in args and len(args['token']):
|
||||
token = args['token'][0].rstrip('\n')
|
||||
if "token" in args and len(args["token"]):
|
||||
token = args["token"][0].rstrip("\n")
|
||||
else:
|
||||
token = None
|
||||
|
||||
|
|
@ -200,13 +216,15 @@ Traffic Legend:
|
|||
self.heartbeat = now + self.server.heartbeat
|
||||
self.send_ping()
|
||||
|
||||
if tqueue: wlist.append(target)
|
||||
if cqueue or c_pend: wlist.append(self.request)
|
||||
if tqueue:
|
||||
wlist.append(target)
|
||||
if cqueue or c_pend:
|
||||
wlist.append(self.request)
|
||||
try:
|
||||
ins, outs, excepts = select.select(rlist, wlist, [], 1)
|
||||
except OSError:
|
||||
exc = sys.exc_info()[1]
|
||||
if hasattr(exc, 'errno'):
|
||||
if hasattr(exc, "errno"):
|
||||
err = exc.errno
|
||||
else:
|
||||
err = exc[0]
|
||||
|
|
@ -216,7 +234,8 @@ Traffic Legend:
|
|||
else:
|
||||
continue
|
||||
|
||||
if excepts: raise Exception("Socket exception")
|
||||
if excepts:
|
||||
raise Exception("Socket exception")
|
||||
|
||||
if self.request in outs:
|
||||
# Send queued target data to the client
|
||||
|
|
@ -230,8 +249,7 @@ Traffic Legend:
|
|||
tqueue.extend(bufs)
|
||||
|
||||
if closed:
|
||||
|
||||
while (len(tqueue) != 0):
|
||||
while len(tqueue) != 0:
|
||||
# Send queued client data to the target
|
||||
dat = tqueue.pop(0)
|
||||
sent = target.send(dat)
|
||||
|
|
@ -244,10 +262,12 @@ Traffic Legend:
|
|||
|
||||
# TODO: What about blocking on client socket?
|
||||
if self.verbose:
|
||||
self.log_message("%s:%s: Client closed connection",
|
||||
self.server.target_host, self.server.target_port)
|
||||
raise self.CClose(closed['code'], closed['reason'])
|
||||
|
||||
self.log_message(
|
||||
"%s:%s: Client closed connection",
|
||||
self.server.target_host,
|
||||
self.server.target_port,
|
||||
)
|
||||
raise self.CClose(closed["code"], closed["reason"])
|
||||
|
||||
if target in outs:
|
||||
# Send queued client data to the target
|
||||
|
|
@ -260,29 +280,31 @@ Traffic Legend:
|
|||
tqueue.insert(0, dat[sent:])
|
||||
self.print_traffic(".>")
|
||||
|
||||
|
||||
if target in ins:
|
||||
# Receive target data, encode it and queue for client
|
||||
buf = target.recv(self.buffer_size)
|
||||
if len(buf) == 0:
|
||||
|
||||
# Target socket closed, flushing queues and closing client-side websocket
|
||||
# Send queued target data to the client
|
||||
if len(cqueue) != 0:
|
||||
c_pend = True
|
||||
while(c_pend):
|
||||
while c_pend:
|
||||
c_pend = self.send_frames(cqueue)
|
||||
|
||||
cqueue = []
|
||||
|
||||
if self.verbose:
|
||||
self.log_message("%s:%s: Target closed connection",
|
||||
self.server.target_host, self.server.target_port)
|
||||
self.log_message(
|
||||
"%s:%s: Target closed connection",
|
||||
self.server.target_host,
|
||||
self.server.target_port,
|
||||
)
|
||||
raise self.CClose(1000, "Target closed")
|
||||
|
||||
cqueue.append(buf)
|
||||
self.print_traffic("{")
|
||||
|
||||
|
||||
class WebSocketProxy(websockifyserver.WebSockifyServer):
|
||||
"""
|
||||
Proxy traffic to and from a WebSockets client to a normal TCP
|
||||
|
|
@ -293,27 +315,29 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
|
|||
|
||||
def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs):
|
||||
# Save off proxy specific options
|
||||
self.target_host = kwargs.pop('target_host', None)
|
||||
self.target_port = kwargs.pop('target_port', None)
|
||||
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
|
||||
self.wrap_mode = kwargs.pop('wrap_mode', None)
|
||||
self.unix_target = kwargs.pop('unix_target', None)
|
||||
self.ssl_target = kwargs.pop('ssl_target', None)
|
||||
self.heartbeat = kwargs.pop('heartbeat', None)
|
||||
self.target_host = kwargs.pop("target_host", None)
|
||||
self.target_port = kwargs.pop("target_port", None)
|
||||
self.wrap_cmd = kwargs.pop("wrap_cmd", None)
|
||||
self.wrap_mode = kwargs.pop("wrap_mode", None)
|
||||
self.unix_target = kwargs.pop("unix_target", None)
|
||||
self.ssl_target = kwargs.pop("ssl_target", None)
|
||||
self.heartbeat = kwargs.pop("heartbeat", None)
|
||||
|
||||
self.token_plugin = kwargs.pop('token_plugin', None)
|
||||
self.host_token = kwargs.pop('host_token', None)
|
||||
self.auth_plugin = kwargs.pop('auth_plugin', None)
|
||||
self.token_plugin = kwargs.pop("token_plugin", None)
|
||||
self.host_token = kwargs.pop("host_token", None)
|
||||
self.auth_plugin = kwargs.pop("auth_plugin", None)
|
||||
|
||||
# Last 3 timestamps command was run
|
||||
self.wrap_times = [0, 0, 0]
|
||||
self.wrap_times = [0, 0, 0]
|
||||
|
||||
if self.wrap_cmd:
|
||||
wsdir = os.path.dirname(sys.argv[0])
|
||||
rebinder_path = [os.path.join(wsdir, "..", "lib"),
|
||||
os.path.join(wsdir, "..", "lib", "websockify"),
|
||||
os.path.join(wsdir, ".."),
|
||||
wsdir]
|
||||
rebinder_path = [
|
||||
os.path.join(wsdir, "..", "lib"),
|
||||
os.path.join(wsdir, "..", "lib", "websockify"),
|
||||
os.path.join(wsdir, ".."),
|
||||
wsdir,
|
||||
]
|
||||
self.rebinder = None
|
||||
|
||||
for rdir in rebinder_path:
|
||||
|
|
@ -329,17 +353,22 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
|
|||
self.target_host = "127.0.0.1" # Loopback
|
||||
# Find a free high port
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.bind(('', 0))
|
||||
sock.bind(("", 0))
|
||||
self.target_port = sock.getsockname()[1]
|
||||
sock.close()
|
||||
|
||||
# Insert rebinder at the head of the (possibly empty) LD_PRELOAD pathlist
|
||||
ld_preloads = filter(None, [ self.rebinder, os.environ.get("LD_PRELOAD", None) ])
|
||||
ld_preloads = filter(
|
||||
None, [self.rebinder, os.environ.get("LD_PRELOAD", None)]
|
||||
)
|
||||
|
||||
os.environ.update({
|
||||
"LD_PRELOAD": os.pathsep.join(ld_preloads),
|
||||
"REBIND_OLD_PORT": str(kwargs['listen_port']),
|
||||
"REBIND_NEW_PORT": str(self.target_port)})
|
||||
os.environ.update(
|
||||
{
|
||||
"LD_PRELOAD": os.pathsep.join(ld_preloads),
|
||||
"REBIND_OLD_PORT": str(kwargs["listen_port"]),
|
||||
"REBIND_NEW_PORT": str(self.target_port),
|
||||
}
|
||||
)
|
||||
|
||||
super().__init__(RequestHandlerClass, *args, **kwargs)
|
||||
|
||||
|
|
@ -348,7 +377,8 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
|
|||
self.wrap_times.append(time.time())
|
||||
self.wrap_times.pop(0)
|
||||
self.cmd = subprocess.Popen(
|
||||
self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup)
|
||||
self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup
|
||||
)
|
||||
self.spawn_message = True
|
||||
|
||||
def started(self):
|
||||
|
|
@ -371,10 +401,11 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
|
|||
|
||||
if self.token_plugin:
|
||||
msg = " - proxying from %s to targets generated by %s" % (
|
||||
src_string, type(self.token_plugin).__name__)
|
||||
src_string,
|
||||
type(self.token_plugin).__name__,
|
||||
)
|
||||
else:
|
||||
msg = " - proxying from %s to %s" % (
|
||||
src_string, dst_string)
|
||||
msg = " - proxying from %s to %s" % (src_string, dst_string)
|
||||
|
||||
if self.ssl_target:
|
||||
msg += " (using SSL)"
|
||||
|
|
@ -401,7 +432,7 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
|
|||
sys.exit(ret)
|
||||
elif self.wrap_mode == "respawn":
|
||||
now = time.time()
|
||||
avg = sum(self.wrap_times)/len(self.wrap_times)
|
||||
avg = sum(self.wrap_times) / len(self.wrap_times)
|
||||
if (now - avg) < 10:
|
||||
# 3 times in the last 10 seconds
|
||||
if self.spawn_message:
|
||||
|
|
@ -418,15 +449,25 @@ def _subprocess_setup():
|
|||
|
||||
|
||||
SSL_OPTIONS = {
|
||||
'default': ssl.OP_ALL,
|
||||
'tlsv1_1': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
|
||||
ssl.OP_NO_TLSv1,
|
||||
'tlsv1_2': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
|
||||
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1,
|
||||
'tlsv1_3': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
|
||||
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2,
|
||||
"default": ssl.OP_ALL,
|
||||
"tlsv1_1": ssl.PROTOCOL_SSLv23
|
||||
| ssl.OP_NO_SSLv2
|
||||
| ssl.OP_NO_SSLv3
|
||||
| ssl.OP_NO_TLSv1,
|
||||
"tlsv1_2": ssl.PROTOCOL_SSLv23
|
||||
| ssl.OP_NO_SSLv2
|
||||
| ssl.OP_NO_SSLv3
|
||||
| ssl.OP_NO_TLSv1
|
||||
| ssl.OP_NO_TLSv1_1,
|
||||
"tlsv1_3": ssl.PROTOCOL_SSLv23
|
||||
| ssl.OP_NO_SSLv2
|
||||
| ssl.OP_NO_SSLv3
|
||||
| ssl.OP_NO_TLSv1
|
||||
| ssl.OP_NO_TLSv1_1
|
||||
| ssl.OP_NO_TLSv1_2,
|
||||
}
|
||||
|
||||
|
||||
def select_ssl_version(version):
|
||||
"""Returns SSL options for the most secure TSL version available on this
|
||||
Python version"""
|
||||
|
|
@ -439,11 +480,11 @@ def select_ssl_version(version):
|
|||
keys.sort()
|
||||
fallback = keys[-1]
|
||||
logger = logging.getLogger(WebSocketProxy.log_prefix)
|
||||
logger.warn("TLS version %s unsupported. Falling back to %s",
|
||||
version, fallback)
|
||||
logger.warn("TLS version %s unsupported. Falling back to %s", version, fallback)
|
||||
|
||||
return SSL_OPTIONS[fallback]
|
||||
|
||||
|
||||
def websockify_init():
|
||||
# Setup basic logging to stderr.
|
||||
stderr_handler = logging.StreamHandler()
|
||||
|
|
@ -464,105 +505,198 @@ def websockify_init():
|
|||
usage += "\n %prog [options]"
|
||||
usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE"
|
||||
parser = optparse.OptionParser(usage=usage)
|
||||
parser.add_option("--verbose", "-v", action="store_true",
|
||||
help="verbose messages")
|
||||
parser.add_option("--traffic", action="store_true",
|
||||
help="per frame traffic")
|
||||
parser.add_option("--record",
|
||||
help="record sessions to FILE.[session_number]", metavar="FILE")
|
||||
parser.add_option("--daemon", "-D",
|
||||
dest="daemon", action="store_true",
|
||||
help="become a daemon (background process)")
|
||||
parser.add_option("--run-once", action="store_true",
|
||||
help="handle a single WebSocket connection and exit")
|
||||
parser.add_option("--timeout", type=int, default=0,
|
||||
help="after TIMEOUT seconds exit when not connected")
|
||||
parser.add_option("--idle-timeout", type=int, default=0,
|
||||
help="server exits after TIMEOUT seconds if there are no "
|
||||
"active connections")
|
||||
parser.add_option("--cert", default="self.pem",
|
||||
help="SSL certificate file")
|
||||
parser.add_option("--key", default=None,
|
||||
help="SSL key file (if separate from cert)")
|
||||
parser.add_option("--key-password", default=None,
|
||||
help="SSL key password")
|
||||
parser.add_option("--ssl-only", action="store_true",
|
||||
help="disallow non-encrypted client connections")
|
||||
parser.add_option("--ssl-target", action="store_true",
|
||||
help="connect to SSL target as SSL client")
|
||||
parser.add_option("--verify-client", action="store_true",
|
||||
help="require encrypted client to present a valid certificate "
|
||||
"(needs Python 2.7.9 or newer or Python 3.4 or newer)")
|
||||
parser.add_option("--cafile", metavar="FILE",
|
||||
help="file of concatenated certificates of authorities trusted "
|
||||
"for validating clients (only effective with --verify-client). "
|
||||
"If omitted, system default list of CAs is used.")
|
||||
parser.add_option("--ssl-version", type="choice", default="default",
|
||||
choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"], action="store",
|
||||
help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)")
|
||||
parser.add_option("--ssl-ciphers", action="store",
|
||||
help="list of ciphers allowed for connection. For a list of "
|
||||
"supported ciphers run `openssl ciphers`")
|
||||
parser.add_option("--unix-listen",
|
||||
help="listen to unix socket", metavar="FILE", default=None)
|
||||
parser.add_option("--unix-listen-mode", default=None,
|
||||
help="specify mode for unix socket (defaults to 0600)")
|
||||
parser.add_option("--unix-target",
|
||||
help="connect to unix socket target", metavar="FILE")
|
||||
parser.add_option("--inetd",
|
||||
help="inetd mode, receive listening socket from stdin", action="store_true")
|
||||
parser.add_option("--web", default=None, metavar="DIR",
|
||||
help="run webserver on same port. Serve files from DIR.")
|
||||
parser.add_option("--web-auth", action="store_true",
|
||||
help="require authentication to access webserver.")
|
||||
parser.add_option("--wrap-mode", default="exit", metavar="MODE",
|
||||
choices=["exit", "ignore", "respawn"],
|
||||
help="action to take when the wrapped program exits "
|
||||
"or daemonizes: exit (default), ignore, respawn")
|
||||
parser.add_option("--prefer-ipv6", "-6",
|
||||
action="store_true", dest="source_is_ipv6",
|
||||
help="prefer IPv6 when resolving source_addr")
|
||||
parser.add_option("--libserver", action="store_true",
|
||||
help="use Python library SocketServer engine")
|
||||
parser.add_option("--target-config", metavar="FILE",
|
||||
dest="target_cfg",
|
||||
help="Configuration file containing valid targets "
|
||||
"in the form 'token: host:port' or, alternatively, a "
|
||||
"directory containing configuration files of this form "
|
||||
"(DEPRECATED: use `--token-plugin TokenFile --token-source "
|
||||
" path/to/token/file` instead)")
|
||||
parser.add_option("--token-plugin", default=None, metavar="CLASS",
|
||||
help="use a Python class, usually one from websockify.token_plugins, "
|
||||
"such as TokenFile, to process tokens into host:port pairs")
|
||||
parser.add_option("--token-source", default=None, metavar="ARG",
|
||||
help="an argument to be passed to the token plugin "
|
||||
"on instantiation")
|
||||
parser.add_option("--host-token", action="store_true",
|
||||
help="use the host HTTP header as token instead of the "
|
||||
"token URL query parameter")
|
||||
parser.add_option("--auth-plugin", default=None, metavar="CLASS",
|
||||
help="use a Python class, usually one from websockify.auth_plugins, "
|
||||
"such as BasicHTTPAuth, to determine if a connection is allowed")
|
||||
parser.add_option("--auth-source", default=None, metavar="ARG",
|
||||
help="an argument to be passed to the auth plugin "
|
||||
"on instantiation")
|
||||
parser.add_option("--heartbeat", type=int, default=0, metavar="INTERVAL",
|
||||
help="send a ping to the client every INTERVAL seconds")
|
||||
parser.add_option("--log-file", metavar="FILE",
|
||||
dest="log_file",
|
||||
help="File where logs will be saved")
|
||||
parser.add_option("--syslog", default=None, metavar="SERVER",
|
||||
help="Log to syslog server. SERVER can be local socket, "
|
||||
"such as /dev/log, or a UDP host:port pair.")
|
||||
parser.add_option("--legacy-syslog", action="store_true",
|
||||
help="Use the old syslog protocol instead of RFC 5424. "
|
||||
"Use this if the messages produced by websockify seem abnormal.")
|
||||
parser.add_option("--file-only", action="store_true",
|
||||
help="use this to disable directory listings in web server.")
|
||||
parser.add_option("--verbose", "-v", action="store_true", help="verbose messages")
|
||||
parser.add_option("--traffic", action="store_true", help="per frame traffic")
|
||||
parser.add_option(
|
||||
"--record", help="record sessions to FILE.[session_number]", metavar="FILE"
|
||||
)
|
||||
parser.add_option(
|
||||
"--daemon",
|
||||
"-D",
|
||||
dest="daemon",
|
||||
action="store_true",
|
||||
help="become a daemon (background process)",
|
||||
)
|
||||
parser.add_option(
|
||||
"--run-once",
|
||||
action="store_true",
|
||||
help="handle a single WebSocket connection and exit",
|
||||
)
|
||||
parser.add_option(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=0,
|
||||
help="after TIMEOUT seconds exit when not connected",
|
||||
)
|
||||
parser.add_option(
|
||||
"--idle-timeout",
|
||||
type=int,
|
||||
default=0,
|
||||
help="server exits after TIMEOUT seconds if there are no active connections",
|
||||
)
|
||||
parser.add_option("--cert", default="self.pem", help="SSL certificate file")
|
||||
parser.add_option(
|
||||
"--key", default=None, help="SSL key file (if separate from cert)"
|
||||
)
|
||||
parser.add_option("--key-password", default=None, help="SSL key password")
|
||||
parser.add_option(
|
||||
"--ssl-only",
|
||||
action="store_true",
|
||||
help="disallow non-encrypted client connections",
|
||||
)
|
||||
parser.add_option(
|
||||
"--ssl-target", action="store_true", help="connect to SSL target as SSL client"
|
||||
)
|
||||
parser.add_option(
|
||||
"--verify-client",
|
||||
action="store_true",
|
||||
help="require encrypted client to present a valid certificate "
|
||||
"(needs Python 2.7.9 or newer or Python 3.4 or newer)",
|
||||
)
|
||||
parser.add_option(
|
||||
"--cafile",
|
||||
metavar="FILE",
|
||||
help="file of concatenated certificates of authorities trusted "
|
||||
"for validating clients (only effective with --verify-client). "
|
||||
"If omitted, system default list of CAs is used.",
|
||||
)
|
||||
parser.add_option(
|
||||
"--ssl-version",
|
||||
type="choice",
|
||||
default="default",
|
||||
choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"],
|
||||
action="store",
|
||||
help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)",
|
||||
)
|
||||
parser.add_option(
|
||||
"--ssl-ciphers",
|
||||
action="store",
|
||||
help="list of ciphers allowed for connection. For a list of "
|
||||
"supported ciphers run `openssl ciphers`",
|
||||
)
|
||||
parser.add_option(
|
||||
"--unix-listen", help="listen to unix socket", metavar="FILE", default=None
|
||||
)
|
||||
parser.add_option(
|
||||
"--unix-listen-mode",
|
||||
default=None,
|
||||
help="specify mode for unix socket (defaults to 0600)",
|
||||
)
|
||||
parser.add_option(
|
||||
"--unix-target", help="connect to unix socket target", metavar="FILE"
|
||||
)
|
||||
parser.add_option(
|
||||
"--inetd",
|
||||
help="inetd mode, receive listening socket from stdin",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_option(
|
||||
"--web",
|
||||
default=None,
|
||||
metavar="DIR",
|
||||
help="run webserver on same port. Serve files from DIR.",
|
||||
)
|
||||
parser.add_option(
|
||||
"--web-auth",
|
||||
action="store_true",
|
||||
help="require authentication to access webserver.",
|
||||
)
|
||||
parser.add_option(
|
||||
"--wrap-mode",
|
||||
default="exit",
|
||||
metavar="MODE",
|
||||
choices=["exit", "ignore", "respawn"],
|
||||
help="action to take when the wrapped program exits "
|
||||
"or daemonizes: exit (default), ignore, respawn",
|
||||
)
|
||||
parser.add_option(
|
||||
"--prefer-ipv6",
|
||||
"-6",
|
||||
action="store_true",
|
||||
dest="source_is_ipv6",
|
||||
help="prefer IPv6 when resolving source_addr",
|
||||
)
|
||||
parser.add_option(
|
||||
"--libserver",
|
||||
action="store_true",
|
||||
help="use Python library SocketServer engine",
|
||||
)
|
||||
parser.add_option(
|
||||
"--target-config",
|
||||
metavar="FILE",
|
||||
dest="target_cfg",
|
||||
help="Configuration file containing valid targets "
|
||||
"in the form 'token: host:port' or, alternatively, a "
|
||||
"directory containing configuration files of this form "
|
||||
"(DEPRECATED: use `--token-plugin TokenFile --token-source "
|
||||
" path/to/token/file` instead)",
|
||||
)
|
||||
parser.add_option(
|
||||
"--token-plugin",
|
||||
default=None,
|
||||
metavar="CLASS",
|
||||
help="use a Python class, usually one from websockify.token_plugins, "
|
||||
"such as TokenFile, to process tokens into host:port pairs",
|
||||
)
|
||||
parser.add_option(
|
||||
"--token-source",
|
||||
default=None,
|
||||
metavar="ARG",
|
||||
help="an argument to be passed to the token plugin on instantiation",
|
||||
)
|
||||
parser.add_option(
|
||||
"--host-token",
|
||||
action="store_true",
|
||||
help="use the host HTTP header as token instead of the "
|
||||
"token URL query parameter",
|
||||
)
|
||||
parser.add_option(
|
||||
"--auth-plugin",
|
||||
default=None,
|
||||
metavar="CLASS",
|
||||
help="use a Python class, usually one from websockify.auth_plugins, "
|
||||
"such as BasicHTTPAuth, to determine if a connection is allowed",
|
||||
)
|
||||
parser.add_option(
|
||||
"--auth-source",
|
||||
default=None,
|
||||
metavar="ARG",
|
||||
help="an argument to be passed to the auth plugin on instantiation",
|
||||
)
|
||||
parser.add_option(
|
||||
"--heartbeat",
|
||||
type=int,
|
||||
default=0,
|
||||
metavar="INTERVAL",
|
||||
help="send a ping to the client every INTERVAL seconds",
|
||||
)
|
||||
parser.add_option(
|
||||
"--log-file",
|
||||
metavar="FILE",
|
||||
dest="log_file",
|
||||
help="File where logs will be saved",
|
||||
)
|
||||
parser.add_option(
|
||||
"--syslog",
|
||||
default=None,
|
||||
metavar="SERVER",
|
||||
help="Log to syslog server. SERVER can be local socket, "
|
||||
"such as /dev/log, or a UDP host:port pair.",
|
||||
)
|
||||
parser.add_option(
|
||||
"--legacy-syslog",
|
||||
action="store_true",
|
||||
help="Use the old syslog protocol instead of RFC 5424. "
|
||||
"Use this if the messages produced by websockify seem abnormal.",
|
||||
)
|
||||
parser.add_option(
|
||||
"--file-only",
|
||||
action="store_true",
|
||||
help="use this to disable directory listings in web server.",
|
||||
)
|
||||
|
||||
(opts, args) = parser.parse_args()
|
||||
|
||||
|
||||
# Validate options.
|
||||
|
||||
if opts.token_source and not opts.token_plugin:
|
||||
|
|
@ -583,11 +717,9 @@ def websockify_init():
|
|||
if opts.legacy_syslog and not opts.syslog:
|
||||
parser.error("You must use --syslog to use --legacy-syslog")
|
||||
|
||||
|
||||
opts.ssl_options = select_ssl_version(opts.ssl_version)
|
||||
del opts.ssl_version
|
||||
|
||||
|
||||
if opts.log_file:
|
||||
# Setup logging to user-specified file.
|
||||
opts.log_file = os.path.abspath(opts.log_file)
|
||||
|
|
@ -601,9 +733,9 @@ def websockify_init():
|
|||
|
||||
if opts.syslog:
|
||||
# Determine how to connect to syslog...
|
||||
if opts.syslog.count(':'):
|
||||
if opts.syslog.count(":"):
|
||||
# User supplied a host:port pair.
|
||||
syslog_host, syslog_port = opts.syslog.rsplit(':', 1)
|
||||
syslog_host, syslog_port = opts.syslog.rsplit(":", 1)
|
||||
try:
|
||||
syslog_port = int(syslog_port)
|
||||
except ValueError:
|
||||
|
|
@ -622,10 +754,12 @@ def websockify_init():
|
|||
syslog_facility = WebsockifySysLogHandler.LOG_USER
|
||||
|
||||
# Start logging to syslog.
|
||||
syslog_handler = WebsockifySysLogHandler(address=syslog_dest,
|
||||
facility=syslog_facility,
|
||||
ident='websockify',
|
||||
legacy=opts.legacy_syslog)
|
||||
syslog_handler = WebsockifySysLogHandler(
|
||||
address=syslog_dest,
|
||||
facility=syslog_facility,
|
||||
ident="websockify",
|
||||
legacy=opts.legacy_syslog,
|
||||
)
|
||||
syslog_handler.setLevel(logging.DEBUG)
|
||||
syslog_handler.setFormatter(log_formatter)
|
||||
root = logging.getLogger()
|
||||
|
|
@ -638,24 +772,23 @@ def websockify_init():
|
|||
root = logging.getLogger()
|
||||
root.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
# Transform to absolute path as daemon may chdir
|
||||
if opts.target_cfg:
|
||||
opts.target_cfg = os.path.abspath(opts.target_cfg)
|
||||
|
||||
if opts.target_cfg:
|
||||
opts.token_plugin = 'TokenFile'
|
||||
opts.token_plugin = "TokenFile"
|
||||
opts.token_source = opts.target_cfg
|
||||
|
||||
del opts.target_cfg
|
||||
|
||||
if sys.argv.count('--'):
|
||||
if sys.argv.count("--"):
|
||||
opts.wrap_cmd = args[1:]
|
||||
else:
|
||||
opts.wrap_cmd = None
|
||||
|
||||
if not websockifyserver.ssl and opts.ssl_target:
|
||||
parser.error("SSL target requested and Python SSL module not loaded.");
|
||||
parser.error("SSL target requested and Python SSL module not loaded.")
|
||||
|
||||
if opts.ssl_only and not os.path.exists(opts.cert):
|
||||
parser.error("SSL only and %s not found" % opts.cert)
|
||||
|
|
@ -677,11 +810,11 @@ def websockify_init():
|
|||
parser.error("Too few arguments")
|
||||
arg = args.pop(0)
|
||||
# Parse host:port and convert ports to numbers
|
||||
if arg.count(':') > 0:
|
||||
opts.listen_host, opts.listen_port = arg.rsplit(':', 1)
|
||||
opts.listen_host = opts.listen_host.strip('[]')
|
||||
if arg.count(":") > 0:
|
||||
opts.listen_host, opts.listen_port = arg.rsplit(":", 1)
|
||||
opts.listen_host = opts.listen_host.strip("[]")
|
||||
else:
|
||||
opts.listen_host, opts.listen_port = '', arg
|
||||
opts.listen_host, opts.listen_port = "", arg
|
||||
|
||||
try:
|
||||
opts.listen_port = int(opts.listen_port)
|
||||
|
|
@ -697,9 +830,9 @@ def websockify_init():
|
|||
if len(args) < 1:
|
||||
parser.error("Too few arguments")
|
||||
arg = args.pop(0)
|
||||
if arg.count(':') > 0:
|
||||
opts.target_host, opts.target_port = arg.rsplit(':', 1)
|
||||
opts.target_host = opts.target_host.strip('[]')
|
||||
if arg.count(":") > 0:
|
||||
opts.target_host, opts.target_port = arg.rsplit(":", 1)
|
||||
opts.target_host = opts.target_host.strip("[]")
|
||||
else:
|
||||
parser.error("Error parsing target")
|
||||
|
||||
|
|
@ -712,11 +845,10 @@ def websockify_init():
|
|||
parser.error("Too many arguments")
|
||||
|
||||
if opts.token_plugin is not None:
|
||||
if '.' not in opts.token_plugin:
|
||||
opts.token_plugin = (
|
||||
'websockify.token_plugins.%s' % opts.token_plugin)
|
||||
if "." not in opts.token_plugin:
|
||||
opts.token_plugin = "websockify.token_plugins.%s" % opts.token_plugin
|
||||
|
||||
token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit('.', 1)
|
||||
token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit(".", 1)
|
||||
|
||||
__import__(token_plugin_module)
|
||||
token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls)
|
||||
|
|
@ -726,10 +858,10 @@ def websockify_init():
|
|||
del opts.token_source
|
||||
|
||||
if opts.auth_plugin is not None:
|
||||
if '.' not in opts.auth_plugin:
|
||||
opts.auth_plugin = 'websockify.auth_plugins.%s' % opts.auth_plugin
|
||||
if "." not in opts.auth_plugin:
|
||||
opts.auth_plugin = "websockify.auth_plugins.%s" % opts.auth_plugin
|
||||
|
||||
auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit('.', 1)
|
||||
auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit(".", 1)
|
||||
|
||||
__import__(auth_plugin_module)
|
||||
auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls)
|
||||
|
|
@ -759,32 +891,32 @@ class LibProxyServer(ThreadingMixIn, HTTPServer):
|
|||
|
||||
def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs):
|
||||
# Save off proxy specific options
|
||||
self.target_host = kwargs.pop('target_host', None)
|
||||
self.target_port = kwargs.pop('target_port', None)
|
||||
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
|
||||
self.wrap_mode = kwargs.pop('wrap_mode', None)
|
||||
self.unix_target = kwargs.pop('unix_target', None)
|
||||
self.ssl_target = kwargs.pop('ssl_target', None)
|
||||
self.token_plugin = kwargs.pop('token_plugin', None)
|
||||
self.auth_plugin = kwargs.pop('auth_plugin', None)
|
||||
self.heartbeat = kwargs.pop('heartbeat', None)
|
||||
self.target_host = kwargs.pop("target_host", None)
|
||||
self.target_port = kwargs.pop("target_port", None)
|
||||
self.wrap_cmd = kwargs.pop("wrap_cmd", None)
|
||||
self.wrap_mode = kwargs.pop("wrap_mode", None)
|
||||
self.unix_target = kwargs.pop("unix_target", None)
|
||||
self.ssl_target = kwargs.pop("ssl_target", None)
|
||||
self.token_plugin = kwargs.pop("token_plugin", None)
|
||||
self.auth_plugin = kwargs.pop("auth_plugin", None)
|
||||
self.heartbeat = kwargs.pop("heartbeat", None)
|
||||
|
||||
self.token_plugin = None
|
||||
self.auth_plugin = None
|
||||
self.daemon = False
|
||||
|
||||
# Server configuration
|
||||
listen_host = kwargs.pop('listen_host', '')
|
||||
listen_port = kwargs.pop('listen_port', None)
|
||||
web = kwargs.pop('web', '')
|
||||
listen_host = kwargs.pop("listen_host", "")
|
||||
listen_port = kwargs.pop("listen_port", None)
|
||||
web = kwargs.pop("web", "")
|
||||
|
||||
# Configuration affecting base request handler
|
||||
self.only_upgrade = not web
|
||||
self.verbose = kwargs.pop('verbose', False)
|
||||
record = kwargs.pop('record', '')
|
||||
self.only_upgrade = not web
|
||||
self.verbose = kwargs.pop("verbose", False)
|
||||
record = kwargs.pop("record", "")
|
||||
if record:
|
||||
self.record = os.path.abspath(record)
|
||||
self.run_once = kwargs.pop('run_once', False)
|
||||
self.run_once = kwargs.pop("run_once", False)
|
||||
self.handler_id = 0
|
||||
|
||||
for arg in kwargs.keys():
|
||||
|
|
@ -795,12 +927,11 @@ class LibProxyServer(ThreadingMixIn, HTTPServer):
|
|||
|
||||
super().__init__((listen_host, listen_port), RequestHandlerClass)
|
||||
|
||||
|
||||
def process_request(self, request, client_address):
|
||||
"""Override process_request to implement a counter"""
|
||||
self.handler_id += 1
|
||||
super().process_request(request, client_address)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
websockify_init()
|
||||
|
|
|
|||
|
|
@ -1,19 +1,25 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
'''
|
||||
"""
|
||||
Python WebSocket server base
|
||||
Copyright 2011 Joel Martin
|
||||
Copyright 2016-2018 Pierre Ossman
|
||||
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
|
||||
'''
|
||||
"""
|
||||
|
||||
import sys
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
|
||||
from websockify.websocket import WebSocket, WebSocketWantReadError, WebSocketWantWriteError
|
||||
from websockify.websocket import (
|
||||
WebSocket,
|
||||
WebSocketWantReadError,
|
||||
WebSocketWantWriteError,
|
||||
)
|
||||
|
||||
|
||||
class HttpWebSocket(WebSocket):
|
||||
"""Class to glue websocket and http request functionality together"""
|
||||
|
||||
def __init__(self, request_handler):
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -62,8 +68,10 @@ class WebSocketRequestHandlerMixIn:
|
|||
# Checks if it is a websocket request and redirects
|
||||
self.do_GET = self._real_do_GET
|
||||
|
||||
if (self.headers.get('upgrade') and
|
||||
self.headers.get('upgrade').lower() == 'websocket'):
|
||||
if (
|
||||
self.headers.get("upgrade")
|
||||
and self.headers.get("upgrade").lower() == "websocket"
|
||||
):
|
||||
self.handle_upgrade()
|
||||
else:
|
||||
self.do_GET()
|
||||
|
|
@ -93,18 +101,20 @@ class WebSocketRequestHandlerMixIn:
|
|||
|
||||
def handle_websocket(self):
|
||||
"""Handle a WebSocket connection.
|
||||
|
||||
|
||||
This is called when the WebSocket is ready to be used. A
|
||||
sub-class should perform the necessary communication here and
|
||||
return once done.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# Convenient ready made classes
|
||||
|
||||
class WebSocketRequestHandler(WebSocketRequestHandlerMixIn,
|
||||
BaseHTTPRequestHandler):
|
||||
|
||||
class WebSocketRequestHandler(WebSocketRequestHandlerMixIn, BaseHTTPRequestHandler):
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketServer(HTTPServer):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
'''
|
||||
"""
|
||||
Python WebSocket server base with support for "wss://" encryption.
|
||||
Copyright 2011 Joel Martin
|
||||
Copyright 2016 Pierre Ossman
|
||||
|
|
@ -10,35 +10,39 @@ You can make a cert/key with openssl using:
|
|||
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
|
||||
as taken from http://docs.python.org/dev/library/ssl.html#certificates
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
import os, sys, time, errno, signal, socket, select, logging
|
||||
import multiprocessing
|
||||
from http.server import SimpleHTTPRequestHandler
|
||||
|
||||
# Degraded functionality if these imports are missing
|
||||
for mod, msg in [('ssl', 'TLS/SSL/wss is disabled'),
|
||||
('resource', 'daemonizing is disabled')]:
|
||||
for mod, msg in [
|
||||
("ssl", "TLS/SSL/wss is disabled"),
|
||||
("resource", "daemonizing is disabled"),
|
||||
]:
|
||||
try:
|
||||
globals()[mod] = __import__(mod)
|
||||
except ImportError:
|
||||
globals()[mod] = None
|
||||
print("WARNING: no '%s' module, %s" % (mod, msg))
|
||||
|
||||
if sys.platform == 'win32':
|
||||
if sys.platform == "win32":
|
||||
# make sockets pickle-able/inheritable
|
||||
import multiprocessing.reduction
|
||||
|
||||
from websockify.websocket import WebSocketWantReadError, WebSocketWantWriteError
|
||||
from websockify.websocketserver import WebSocketRequestHandlerMixIn
|
||||
|
||||
|
||||
class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass):
|
||||
def select_subprotocol(self, protocols):
|
||||
# Handle old websockify clients that still specify a sub-protocol
|
||||
if 'binary' in protocols:
|
||||
return 'binary'
|
||||
if "binary" in protocols:
|
||||
return "binary"
|
||||
else:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
|
||||
# HTTP handler with WebSocket upgrade support
|
||||
class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHandler):
|
||||
|
|
@ -56,6 +60,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
* run_once: Handle a single request
|
||||
* handler_id: A sequence number for this connection, appended to record filename
|
||||
"""
|
||||
|
||||
server_version = "WebSockify"
|
||||
|
||||
protocol_version = "HTTP/1.1"
|
||||
|
|
@ -73,7 +78,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
self.daemon = getattr(server, "daemon", False)
|
||||
self.record = getattr(server, "record", False)
|
||||
self.run_once = getattr(server, "run_once", False)
|
||||
self.rec = None
|
||||
self.rec = None
|
||||
self.handler_id = getattr(server, "handler_id", False)
|
||||
self.file_only = getattr(server, "file_only", False)
|
||||
self.traffic = getattr(server, "traffic", False)
|
||||
|
|
@ -87,30 +92,33 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
super().__init__(req, addr, server)
|
||||
|
||||
def log_message(self, format, *args):
|
||||
self.logger.info("%s - - [%s] %s" % (self.client_address[0], self.log_date_time_string(), format % args))
|
||||
self.logger.info(
|
||||
"%s - - [%s] %s"
|
||||
% (self.client_address[0], self.log_date_time_string(), format % args)
|
||||
)
|
||||
|
||||
#
|
||||
# WebSocketRequestHandler logging/output functions
|
||||
#
|
||||
|
||||
def print_traffic(self, token="."):
|
||||
""" Show traffic flow mode. """
|
||||
"""Show traffic flow mode."""
|
||||
if self.traffic:
|
||||
sys.stdout.write(token)
|
||||
sys.stdout.flush()
|
||||
|
||||
def msg(self, msg, *args, **kwargs):
|
||||
""" Output message with handler_id prefix. """
|
||||
"""Output message with handler_id prefix."""
|
||||
prefix = "% 3d: " % self.handler_id
|
||||
self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs)
|
||||
|
||||
def vmsg(self, msg, *args, **kwargs):
|
||||
""" Same as msg() but as debug. """
|
||||
"""Same as msg() but as debug."""
|
||||
prefix = "% 3d: " % self.handler_id
|
||||
self.logger.log(logging.DEBUG, "%s%s" % (prefix, msg), *args, **kwargs)
|
||||
|
||||
def warn(self, msg, *args, **kwargs):
|
||||
""" Same as msg() but as warning. """
|
||||
"""Same as msg() but as warning."""
|
||||
prefix = "% 3d: " % self.handler_id
|
||||
self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs)
|
||||
|
||||
|
|
@ -118,19 +126,24 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
# Main WebSocketRequestHandler methods
|
||||
#
|
||||
def send_frames(self, bufs=None):
|
||||
""" Encode and send WebSocket frames. Any frames already
|
||||
"""Encode and send WebSocket frames. Any frames already
|
||||
queued will be sent first. If buf is not set then only queued
|
||||
frames will be sent. Returns True if any frames could not be
|
||||
fully sent, in which case the caller should call again when
|
||||
the socket is ready. """
|
||||
the socket is ready."""
|
||||
|
||||
tdelta = int(time.time()*1000) - self.start_time
|
||||
tdelta = int(time.time() * 1000) - self.start_time
|
||||
|
||||
if bufs:
|
||||
for buf in bufs:
|
||||
if self.rec:
|
||||
# Python 3 compatible conversion
|
||||
bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'")
|
||||
bufstr = (
|
||||
buf.decode("latin1")
|
||||
.encode("unicode_escape")
|
||||
.decode("ascii")
|
||||
.replace("'", "\\'")
|
||||
)
|
||||
self.rec.write("'{{{0}{{{1}',\n".format(tdelta, bufstr))
|
||||
self.send_parts.append(buf)
|
||||
|
||||
|
|
@ -147,7 +160,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
return False
|
||||
|
||||
def recv_frames(self):
|
||||
""" Receive and decode WebSocket frames.
|
||||
"""Receive and decode WebSocket frames.
|
||||
|
||||
Returns:
|
||||
(bufs_list, closed_string)
|
||||
|
|
@ -155,7 +168,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
|
||||
closed = False
|
||||
bufs = []
|
||||
tdelta = int(time.time()*1000) - self.start_time
|
||||
tdelta = int(time.time() * 1000) - self.start_time
|
||||
|
||||
while True:
|
||||
try:
|
||||
|
|
@ -165,15 +178,22 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
break
|
||||
|
||||
if buf is None:
|
||||
closed = {'code': self.request.close_code,
|
||||
'reason': self.request.close_reason}
|
||||
closed = {
|
||||
"code": self.request.close_code,
|
||||
"reason": self.request.close_reason,
|
||||
}
|
||||
return bufs, closed
|
||||
|
||||
self.print_traffic("}")
|
||||
|
||||
if self.rec:
|
||||
# Python 3 compatible conversion
|
||||
bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'")
|
||||
bufstr = (
|
||||
buf.decode("latin1")
|
||||
.encode("unicode_escape")
|
||||
.decode("ascii")
|
||||
.replace("'", "\\'")
|
||||
)
|
||||
self.rec.write("'}}{0}}}{1}',\n".format(tdelta, bufstr))
|
||||
|
||||
bufs.append(buf)
|
||||
|
|
@ -183,16 +203,16 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
|
||||
return bufs, closed
|
||||
|
||||
def send_close(self, code=1000, reason=''):
|
||||
""" Send a WebSocket orderly close frame. """
|
||||
def send_close(self, code=1000, reason=""):
|
||||
"""Send a WebSocket orderly close frame."""
|
||||
self.request.shutdown(socket.SHUT_RDWR, code, reason)
|
||||
|
||||
def send_pong(self, data=b''):
|
||||
""" Send a WebSocket pong frame. """
|
||||
def send_pong(self, data=b""):
|
||||
"""Send a WebSocket pong frame."""
|
||||
self.request.pong(data)
|
||||
|
||||
def send_ping(self, data=b''):
|
||||
""" Send a WebSocket ping frame. """
|
||||
def send_ping(self, data=b""):
|
||||
"""Send a WebSocket ping frame."""
|
||||
self.request.ping(data)
|
||||
|
||||
def handle_upgrade(self):
|
||||
|
|
@ -207,8 +227,8 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
self.server.ws_connection = True
|
||||
# Initialize per client settings
|
||||
self.send_parts = []
|
||||
self.recv_part = None
|
||||
self.start_time = int(time.time()*1000)
|
||||
self.recv_part = None
|
||||
self.start_time = int(time.time() * 1000)
|
||||
|
||||
# client_address is empty with, say, UNIX domain sockets
|
||||
client_addr = ""
|
||||
|
|
@ -224,17 +244,15 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
else:
|
||||
self.stype = "Plain non-SSL (ws://)"
|
||||
|
||||
self.log_message("%s: %s WebSocket connection", client_addr,
|
||||
self.stype)
|
||||
if self.path != '/':
|
||||
self.log_message("%s: %s WebSocket connection", client_addr, self.stype)
|
||||
if self.path != "/":
|
||||
self.log_message("%s: Path: '%s'", client_addr, self.path)
|
||||
|
||||
if self.record:
|
||||
# Record raw frame data as JavaScript array
|
||||
fname = "%s.%s" % (self.record,
|
||||
self.handler_id)
|
||||
fname = "%s.%s" % (self.record, self.handler_id)
|
||||
self.log_message("opening record file: %s", fname)
|
||||
self.rec = open(fname, 'w+')
|
||||
self.rec = open(fname, "w+")
|
||||
self.rec.write("var VNC_frame_data = [\n")
|
||||
|
||||
try:
|
||||
|
|
@ -261,15 +279,17 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
return super().list_directory(path)
|
||||
|
||||
def new_websocket_client(self):
|
||||
""" Do something with a WebSockets client connection. """
|
||||
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
|
||||
"""Do something with a WebSockets client connection."""
|
||||
raise Exception(
|
||||
"WebSocketRequestHandler.new_websocket_client() must be overloaded"
|
||||
)
|
||||
|
||||
def validate_connection(self):
|
||||
""" Ensure that the connection has a valid token, and set the target. """
|
||||
"""Ensure that the connection has a valid token, and set the target."""
|
||||
pass
|
||||
|
||||
def auth_connection(self):
|
||||
""" Ensure that the connection is authorized. """
|
||||
"""Ensure that the connection is authorized."""
|
||||
pass
|
||||
|
||||
def do_HEAD(self):
|
||||
|
|
@ -296,12 +316,12 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
|
|||
else:
|
||||
super().handle()
|
||||
|
||||
def log_request(self, code='-', size='-'):
|
||||
def log_request(self, code="-", size="-"):
|
||||
if self.verbose:
|
||||
super().log_request(code, size)
|
||||
|
||||
|
||||
class WebSockifyServer():
|
||||
class WebSockifyServer:
|
||||
"""
|
||||
WebSockets server class.
|
||||
As an alternative, the standard library SocketServer can be used
|
||||
|
|
@ -317,48 +337,69 @@ class WebSockifyServer():
|
|||
class Terminate(Exception):
|
||||
pass
|
||||
|
||||
def __init__(self, RequestHandlerClass, listen_fd=None,
|
||||
listen_host='', listen_port=None, source_is_ipv6=False,
|
||||
verbose=False, cert='', key='', key_password=None, ssl_only=None,
|
||||
verify_client=False, cafile=None,
|
||||
daemon=False, record='', web='', web_auth=False,
|
||||
file_only=False,
|
||||
run_once=False, timeout=0, idle_timeout=0, traffic=False,
|
||||
tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None,
|
||||
tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0,
|
||||
unix_listen=None, unix_listen_mode=None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
RequestHandlerClass,
|
||||
listen_fd=None,
|
||||
listen_host="",
|
||||
listen_port=None,
|
||||
source_is_ipv6=False,
|
||||
verbose=False,
|
||||
cert="",
|
||||
key="",
|
||||
key_password=None,
|
||||
ssl_only=None,
|
||||
verify_client=False,
|
||||
cafile=None,
|
||||
daemon=False,
|
||||
record="",
|
||||
web="",
|
||||
web_auth=False,
|
||||
file_only=False,
|
||||
run_once=False,
|
||||
timeout=0,
|
||||
idle_timeout=0,
|
||||
traffic=False,
|
||||
tcp_keepalive=True,
|
||||
tcp_keepcnt=None,
|
||||
tcp_keepidle=None,
|
||||
tcp_keepintvl=None,
|
||||
ssl_ciphers=None,
|
||||
ssl_options=0,
|
||||
unix_listen=None,
|
||||
unix_listen_mode=None,
|
||||
):
|
||||
# settings
|
||||
self.RequestHandlerClass = RequestHandlerClass
|
||||
self.verbose = verbose
|
||||
self.listen_fd = listen_fd
|
||||
self.unix_listen = unix_listen
|
||||
self.unix_listen_mode = unix_listen_mode
|
||||
self.listen_host = listen_host
|
||||
self.listen_port = listen_port
|
||||
self.prefer_ipv6 = source_is_ipv6
|
||||
self.ssl_only = ssl_only
|
||||
self.ssl_ciphers = ssl_ciphers
|
||||
self.ssl_options = ssl_options
|
||||
self.verify_client = verify_client
|
||||
self.daemon = daemon
|
||||
self.run_once = run_once
|
||||
self.timeout = timeout
|
||||
self.idle_timeout = idle_timeout
|
||||
self.traffic = traffic
|
||||
self.file_only = file_only
|
||||
self.web_auth = web_auth
|
||||
self.verbose = verbose
|
||||
self.listen_fd = listen_fd
|
||||
self.unix_listen = unix_listen
|
||||
self.unix_listen_mode = unix_listen_mode
|
||||
self.listen_host = listen_host
|
||||
self.listen_port = listen_port
|
||||
self.prefer_ipv6 = source_is_ipv6
|
||||
self.ssl_only = ssl_only
|
||||
self.ssl_ciphers = ssl_ciphers
|
||||
self.ssl_options = ssl_options
|
||||
self.verify_client = verify_client
|
||||
self.daemon = daemon
|
||||
self.run_once = run_once
|
||||
self.timeout = timeout
|
||||
self.idle_timeout = idle_timeout
|
||||
self.traffic = traffic
|
||||
self.file_only = file_only
|
||||
self.web_auth = web_auth
|
||||
|
||||
self.launch_time = time.time()
|
||||
self.ws_connection = False
|
||||
self.handler_id = 1
|
||||
self.terminating = False
|
||||
self.launch_time = time.time()
|
||||
self.ws_connection = False
|
||||
self.handler_id = 1
|
||||
self.terminating = False
|
||||
|
||||
self.logger = self.get_logger()
|
||||
self.tcp_keepalive = tcp_keepalive
|
||||
self.tcp_keepcnt = tcp_keepcnt
|
||||
self.tcp_keepidle = tcp_keepidle
|
||||
self.tcp_keepintvl = tcp_keepintvl
|
||||
self.logger = self.get_logger()
|
||||
self.tcp_keepalive = tcp_keepalive
|
||||
self.tcp_keepcnt = tcp_keepcnt
|
||||
self.tcp_keepidle = tcp_keepidle
|
||||
self.tcp_keepintvl = tcp_keepintvl
|
||||
|
||||
# keyfile path must be None if not specified
|
||||
self.key = None
|
||||
|
|
@ -366,7 +407,7 @@ class WebSockifyServer():
|
|||
|
||||
# Make paths settings absolute
|
||||
self.cert = os.path.abspath(cert)
|
||||
self.web = self.record = self.cafile = ''
|
||||
self.web = self.record = self.cafile = ""
|
||||
if key:
|
||||
self.key = os.path.abspath(key)
|
||||
if web:
|
||||
|
|
@ -393,11 +434,12 @@ class WebSockifyServer():
|
|||
elif self.unix_listen != None:
|
||||
self.msg(" - Listen on unix socket %s", self.unix_listen)
|
||||
else:
|
||||
self.msg(" - Listen on %s:%s",
|
||||
self.listen_host, self.listen_port)
|
||||
self.msg(" - Listen on %s:%s", self.listen_host, self.listen_port)
|
||||
if self.web:
|
||||
if self.file_only:
|
||||
self.msg(" - Web server (no directory listings). Web root: %s", self.web)
|
||||
self.msg(
|
||||
" - Web server (no directory listings). Web root: %s", self.web
|
||||
)
|
||||
else:
|
||||
self.msg(" - Web server. Web root: %s", self.web)
|
||||
if ssl:
|
||||
|
|
@ -420,34 +462,45 @@ class WebSockifyServer():
|
|||
|
||||
@staticmethod
|
||||
def get_logger():
|
||||
return logging.getLogger("%s.%s" % (
|
||||
WebSockifyServer.log_prefix,
|
||||
WebSockifyServer.__class__.__name__))
|
||||
return logging.getLogger(
|
||||
"%s.%s" % (WebSockifyServer.log_prefix, WebSockifyServer.__class__.__name__)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def socket(host, port=None, connect=False, prefer_ipv6=False,
|
||||
unix_socket=None, unix_socket_mode=None, unix_socket_listen=False,
|
||||
use_ssl=False, tcp_keepalive=True, tcp_keepcnt=None,
|
||||
tcp_keepidle=None, tcp_keepintvl=None):
|
||||
""" Resolve a host (and optional port) to an IPv4 or IPv6
|
||||
def socket(
|
||||
host,
|
||||
port=None,
|
||||
connect=False,
|
||||
prefer_ipv6=False,
|
||||
unix_socket=None,
|
||||
unix_socket_mode=None,
|
||||
unix_socket_listen=False,
|
||||
use_ssl=False,
|
||||
tcp_keepalive=True,
|
||||
tcp_keepcnt=None,
|
||||
tcp_keepidle=None,
|
||||
tcp_keepintvl=None,
|
||||
):
|
||||
"""Resolve a host (and optional port) to an IPv4 or IPv6
|
||||
address. Create a socket. Bind to it if listen is set,
|
||||
otherwise connect to it. Return the socket.
|
||||
"""
|
||||
flags = 0
|
||||
if host == '':
|
||||
if host == "":
|
||||
host = None
|
||||
if connect and not (port or unix_socket):
|
||||
raise Exception("Connect mode requires a port")
|
||||
if use_ssl and not ssl:
|
||||
raise Exception("SSL socket requested but Python SSL module not loaded.");
|
||||
raise Exception("SSL socket requested but Python SSL module not loaded.")
|
||||
if not connect and use_ssl:
|
||||
raise Exception("SSL only supported in connect mode (for now)")
|
||||
if not connect:
|
||||
flags = flags | socket.AI_PASSIVE
|
||||
|
||||
if not unix_socket:
|
||||
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM,
|
||||
socket.IPPROTO_TCP, flags)
|
||||
addrs = socket.getaddrinfo(
|
||||
host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, flags
|
||||
)
|
||||
if not addrs:
|
||||
raise Exception("Could not resolve host '%s'" % host)
|
||||
addrs.sort(key=lambda x: x[0])
|
||||
|
|
@ -455,17 +508,14 @@ class WebSockifyServer():
|
|||
addrs.reverse()
|
||||
sock = socket.socket(addrs[0][0], addrs[0][1])
|
||||
|
||||
if tcp_keepalive:
|
||||
if tcp_keepalive:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
if tcp_keepcnt:
|
||||
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT,
|
||||
tcp_keepcnt)
|
||||
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, tcp_keepcnt)
|
||||
if tcp_keepidle:
|
||||
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE,
|
||||
tcp_keepidle)
|
||||
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, tcp_keepidle)
|
||||
if tcp_keepintvl:
|
||||
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL,
|
||||
tcp_keepintvl)
|
||||
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, tcp_keepintvl)
|
||||
|
||||
if connect:
|
||||
sock.connect(addrs[0][4])
|
||||
|
|
@ -497,8 +547,7 @@ class WebSockifyServer():
|
|||
return sock
|
||||
|
||||
@staticmethod
|
||||
def daemonize(keepfd=None, chdir='/'):
|
||||
|
||||
def daemonize(keepfd=None, chdir="/"):
|
||||
if keepfd is None:
|
||||
keepfd = []
|
||||
|
||||
|
|
@ -506,14 +555,16 @@ class WebSockifyServer():
|
|||
if chdir:
|
||||
os.chdir(chdir)
|
||||
else:
|
||||
os.chdir('/')
|
||||
os.chdir("/")
|
||||
os.setgid(os.getgid()) # relinquish elevations
|
||||
os.setuid(os.getuid()) # relinquish elevations
|
||||
|
||||
# Double fork to daemonize
|
||||
if os.fork() > 0: os._exit(0) # Parent exits
|
||||
os.setsid() # Obtain new process group
|
||||
if os.fork() > 0: os._exit(0) # Parent exits
|
||||
if os.fork() > 0:
|
||||
os._exit(0) # Parent exits
|
||||
os.setsid() # Obtain new process group
|
||||
if os.fork() > 0:
|
||||
os._exit(0) # Parent exits
|
||||
|
||||
# Signal handling
|
||||
signal.signal(signal.SIGTERM, signal.SIG_IGN)
|
||||
|
|
@ -521,14 +572,16 @@ class WebSockifyServer():
|
|||
|
||||
# Close open files
|
||||
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
|
||||
if maxfd == resource.RLIM_INFINITY: maxfd = 256
|
||||
if maxfd == resource.RLIM_INFINITY:
|
||||
maxfd = 256
|
||||
for fd in reversed(range(maxfd)):
|
||||
try:
|
||||
if fd not in keepfd:
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
_, exc, _ = sys.exc_info()
|
||||
if exc.errno != errno.EBADF: raise
|
||||
if exc.errno != errno.EBADF:
|
||||
raise
|
||||
|
||||
# Redirect I/O to /dev/null
|
||||
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
|
||||
|
|
@ -557,7 +610,7 @@ class WebSockifyServer():
|
|||
# Peek, but do not read the data so that we have a opportunity
|
||||
# to SSL wrap the socket first
|
||||
handshake = sock.recv(1024, socket.MSG_PEEK)
|
||||
#self.msg("Handshake [%s]" % handshake)
|
||||
# self.msg("Handshake [%s]" % handshake)
|
||||
|
||||
if not handshake:
|
||||
raise self.EClose("")
|
||||
|
|
@ -567,8 +620,7 @@ class WebSockifyServer():
|
|||
if not ssl:
|
||||
raise self.EClose("SSL connection but no 'ssl' module")
|
||||
if not os.path.exists(self.cert):
|
||||
raise self.EClose("SSL connection but '%s' not found"
|
||||
% self.cert)
|
||||
raise self.EClose("SSL connection but '%s' not found" % self.cert)
|
||||
retsock = None
|
||||
try:
|
||||
# create new-style SSL wrapping for extended features
|
||||
|
|
@ -576,16 +628,16 @@ class WebSockifyServer():
|
|||
if self.ssl_ciphers is not None:
|
||||
context.set_ciphers(self.ssl_ciphers)
|
||||
context.options = self.ssl_options
|
||||
context.load_cert_chain(certfile=self.cert, keyfile=self.key, password=self.key_password)
|
||||
context.load_cert_chain(
|
||||
certfile=self.cert, keyfile=self.key, password=self.key_password
|
||||
)
|
||||
if self.verify_client:
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
if self.cafile:
|
||||
context.load_verify_locations(cafile=self.cafile)
|
||||
else:
|
||||
context.set_default_verify_paths()
|
||||
retsock = context.wrap_socket(
|
||||
sock,
|
||||
server_side=True)
|
||||
retsock = context.wrap_socket(sock, server_side=True)
|
||||
except ssl.SSLError:
|
||||
_, x, _ = sys.exc_info()
|
||||
if x.args[0] == ssl.SSL_ERROR_EOF:
|
||||
|
|
@ -618,28 +670,27 @@ class WebSockifyServer():
|
|||
#
|
||||
|
||||
def msg(self, *args, **kwargs):
|
||||
""" Output message as info """
|
||||
"""Output message as info"""
|
||||
self.logger.log(logging.INFO, *args, **kwargs)
|
||||
|
||||
def vmsg(self, *args, **kwargs):
|
||||
""" Same as msg() but as debug. """
|
||||
"""Same as msg() but as debug."""
|
||||
self.logger.log(logging.DEBUG, *args, **kwargs)
|
||||
|
||||
def warn(self, *args, **kwargs):
|
||||
""" Same as msg() but as warning. """
|
||||
"""Same as msg() but as warning."""
|
||||
self.logger.log(logging.WARN, *args, **kwargs)
|
||||
|
||||
|
||||
#
|
||||
# Events that can/should be overridden in sub-classes
|
||||
#
|
||||
def started(self):
|
||||
""" Called after WebSockets startup """
|
||||
"""Called after WebSockets startup"""
|
||||
self.vmsg("WebSockets server started")
|
||||
|
||||
def poll(self):
|
||||
""" Run periodically while waiting for connections. """
|
||||
#self.vmsg("Running poll()")
|
||||
"""Run periodically while waiting for connections."""
|
||||
# self.vmsg("Running poll()")
|
||||
pass
|
||||
|
||||
def terminate(self):
|
||||
|
|
@ -661,7 +712,7 @@ class WebSockifyServer():
|
|||
while result[0]:
|
||||
self.vmsg("Reaped child process %s" % result[0])
|
||||
result = os.waitpid(-1, os.WNOHANG)
|
||||
except (OSError):
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def do_SIGINT(self, sig, stack):
|
||||
|
|
@ -675,7 +726,7 @@ class WebSockifyServer():
|
|||
self.terminate()
|
||||
|
||||
def top_new_client(self, startsock, address):
|
||||
""" Do something with a WebSockets client connection. """
|
||||
"""Do something with a WebSockets client connection."""
|
||||
# handler process
|
||||
client = None
|
||||
try:
|
||||
|
|
@ -693,7 +744,6 @@ class WebSockifyServer():
|
|||
self.msg("handler exception: %s" % str(exc))
|
||||
self.vmsg("exception", exc_info=True)
|
||||
finally:
|
||||
|
||||
if client and client != startsock:
|
||||
# Close the SSL wrapped socket
|
||||
# Original socket closed by caller
|
||||
|
|
@ -721,19 +771,27 @@ class WebSockifyServer():
|
|||
|
||||
try:
|
||||
if self.listen_fd != None:
|
||||
lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM)
|
||||
lsock = socket.fromfd(
|
||||
self.listen_fd, socket.AF_INET, socket.SOCK_STREAM
|
||||
)
|
||||
elif self.unix_listen != None:
|
||||
lsock = self.socket(host=None,
|
||||
unix_socket=self.unix_listen,
|
||||
unix_socket_mode=self.unix_listen_mode,
|
||||
unix_socket_listen=True)
|
||||
lsock = self.socket(
|
||||
host=None,
|
||||
unix_socket=self.unix_listen,
|
||||
unix_socket_mode=self.unix_listen_mode,
|
||||
unix_socket_listen=True,
|
||||
)
|
||||
else:
|
||||
lsock = self.socket(self.listen_host, self.listen_port, False,
|
||||
self.prefer_ipv6,
|
||||
tcp_keepalive=self.tcp_keepalive,
|
||||
tcp_keepcnt=self.tcp_keepcnt,
|
||||
tcp_keepidle=self.tcp_keepidle,
|
||||
tcp_keepintvl=self.tcp_keepintvl)
|
||||
lsock = self.socket(
|
||||
self.listen_host,
|
||||
self.listen_port,
|
||||
False,
|
||||
self.prefer_ipv6,
|
||||
tcp_keepalive=self.tcp_keepalive,
|
||||
tcp_keepcnt=self.tcp_keepcnt,
|
||||
tcp_keepidle=self.tcp_keepidle,
|
||||
tcp_keepintvl=self.tcp_keepintvl,
|
||||
)
|
||||
except OSError as e:
|
||||
self.msg("Openening socket failed: %s", str(e))
|
||||
self.vmsg("exception", exc_info=True)
|
||||
|
|
@ -751,13 +809,13 @@ class WebSockifyServer():
|
|||
signal.SIGINT: signal.getsignal(signal.SIGINT),
|
||||
signal.SIGTERM: signal.getsignal(signal.SIGTERM),
|
||||
}
|
||||
if getattr(signal, 'SIGCHLD', None) is not None:
|
||||
if getattr(signal, "SIGCHLD", None) is not None:
|
||||
original_signals[signal.SIGCHLD] = signal.getsignal(signal.SIGCHLD)
|
||||
signal.signal(signal.SIGINT, self.do_SIGINT)
|
||||
signal.signal(signal.SIGTERM, self.do_SIGTERM)
|
||||
# make sure that _cleanup is called when children die
|
||||
# by calling active_children on SIGCHLD
|
||||
if getattr(signal, 'SIGCHLD', None) is not None:
|
||||
if getattr(signal, "SIGCHLD", None) is not None:
|
||||
signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD)
|
||||
|
||||
last_active_time = self.launch_time
|
||||
|
|
@ -774,8 +832,7 @@ class WebSockifyServer():
|
|||
|
||||
time_elapsed = time.time() - self.launch_time
|
||||
if self.timeout and time_elapsed > self.timeout:
|
||||
self.msg('listener exit due to --timeout %s'
|
||||
% self.timeout)
|
||||
self.msg("listener exit due to --timeout %s" % self.timeout)
|
||||
break
|
||||
|
||||
if self.idle_timeout:
|
||||
|
|
@ -787,8 +844,10 @@ class WebSockifyServer():
|
|||
last_active_time = time.time()
|
||||
|
||||
if idle_time > self.idle_timeout and child_count == 0:
|
||||
self.msg('listener exit due to --idle-timeout %s'
|
||||
% self.idle_timeout)
|
||||
self.msg(
|
||||
"listener exit due to --idle-timeout %s"
|
||||
% self.idle_timeout
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
|
|
@ -799,16 +858,16 @@ class WebSockifyServer():
|
|||
startsock, address = lsock.accept()
|
||||
# Unix Socket will not report address (empty string), but address[0] is logged a bunch
|
||||
if self.unix_listen != None:
|
||||
address = [ self.unix_listen ]
|
||||
address = [self.unix_listen]
|
||||
else:
|
||||
continue
|
||||
except self.Terminate:
|
||||
raise
|
||||
except Exception:
|
||||
_, exc, _ = sys.exc_info()
|
||||
if hasattr(exc, 'errno'):
|
||||
if hasattr(exc, "errno"):
|
||||
err = exc.errno
|
||||
elif hasattr(exc, 'args'):
|
||||
elif hasattr(exc, "args"):
|
||||
err = exc.args[0]
|
||||
else:
|
||||
err = exc[0]
|
||||
|
|
@ -821,15 +880,14 @@ class WebSockifyServer():
|
|||
if self.run_once:
|
||||
# Run in same process if run_once
|
||||
self.top_new_client(startsock, address)
|
||||
if self.ws_connection :
|
||||
self.msg('%s: exiting due to --run-once'
|
||||
% address[0])
|
||||
if self.ws_connection:
|
||||
self.msg("%s: exiting due to --run-once" % address[0])
|
||||
break
|
||||
else:
|
||||
self.vmsg('%s: new handler Process' % address[0])
|
||||
self.vmsg("%s: new handler Process" % address[0])
|
||||
p = multiprocessing.Process(
|
||||
target=self.top_new_client,
|
||||
args=(startsock, address))
|
||||
target=self.top_new_client, args=(startsock, address)
|
||||
)
|
||||
p.start()
|
||||
# child will not return
|
||||
|
||||
|
|
@ -857,12 +915,11 @@ class WebSockifyServer():
|
|||
startsock.close()
|
||||
finally:
|
||||
# Close listen port
|
||||
self.vmsg("Closing socket listening at %s:%s",
|
||||
self.listen_host, self.listen_port)
|
||||
self.vmsg(
|
||||
"Closing socket listening at %s:%s", self.listen_host, self.listen_port
|
||||
)
|
||||
lsock.close()
|
||||
|
||||
# Restore signals
|
||||
for sig, func in original_signals.items():
|
||||
signal.signal(sig, func)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue