Merge branch 'flake8' of github.com:kajinamit/websockify

This commit is contained in:
Pierre Ossman 2025-07-03 16:13:08 +02:00
commit bdf1ebf24e
19 changed files with 448 additions and 349 deletions

19
.github/workflows/lint.yml vendored Normal file
View File

@ -0,0 +1,19 @@
name: Lint
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Update pip and setuptools
run: |
python -m pip install --upgrade pip
python -m pip install setuptools
- name: Install dependencies
run: |
python -m pip install flake8
- name: Lint with flake8
run: |
flake8

View File

@ -1,45 +1,46 @@
from setuptools import setup, find_packages from setuptools import setup
version = '0.13.0' version = '0.13.0'
name = 'websockify' name = 'websockify'
long_description = open("README.md").read() + "\n" + \ long_description = open("README.md").read() + "\n" + \
open("CHANGES.txt").read() + "\n" open("CHANGES.txt").read() + "\n"
setup(name=name, setup(
version=version, name=name,
description="Websockify.", version=version,
long_description=long_description, description="Websockify.",
long_description_content_type="text/markdown", long_description=long_description,
classifiers=[ long_description_content_type="text/markdown",
"Programming Language :: Python", classifiers=[
"Programming Language :: Python :: 3", "Programming Language :: Python",
"Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.11",
], "Programming Language :: Python :: 3.12",
python_requires='>=3.6', ],
keywords='noVNC websockify', python_requires='>=3.6',
license='LGPLv3', keywords='noVNC websockify',
url="https://github.com/novnc/websockify", license='LGPLv3',
author="Joel Martin", url="https://github.com/novnc/websockify",
author_email="github@martintribe.org", author="Joel Martin",
author_email="github@martintribe.org",
packages=['websockify'], packages=['websockify'],
include_package_data=True, include_package_data=True,
install_requires=[ install_requires=[
'numpy', 'requests', 'numpy',
'requests',
'jwcrypto', 'jwcrypto',
'redis', 'redis',
], ],
zip_safe=False, zip_safe=False,
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [
'websockify = websockify.websocketproxy:websockify_init', 'websockify = websockify.websocketproxy:websockify_init',
] ]
}, },
) )

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
# flake8: noqa: E402
''' '''
A WebSocket server that echos back whatever it receives from the client. A WebSocket server that echos back whatever it receives from the client.
Copyright 2010 Joel Martin Copyright 2010 Joel Martin
@ -10,9 +10,17 @@ openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import os, sys, select, optparse, logging import logging
sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) import optparse
from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler import os
import select
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websockifyserver import WebSockifyServer
from websockify.websockifyserver import WebSockifyRequestHandler
class WebSocketEcho(WebSockifyRequestHandler): class WebSocketEcho(WebSockifyRequestHandler):
""" """
@ -27,15 +35,17 @@ class WebSocketEcho(WebSockifyRequestHandler):
cqueue = [] cqueue = []
c_pend = 0 c_pend = 0
cpartial = "" cpartial = "" # noqa: F841
rlist = [self.request] rlist = [self.request]
while True: while True:
wlist = [] wlist = []
if cqueue or c_pend: wlist.append(self.request) if cqueue or c_pend:
wlist.append(self.request)
ins, outs, excepts = select.select(rlist, wlist, [], 1) ins, outs, excepts = select.select(rlist, wlist, [], 1)
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
if self.request in outs: if self.request in outs:
# Send queued target data to the client # Send queued target data to the client
@ -50,20 +60,22 @@ class WebSocketEcho(WebSockifyRequestHandler):
if closed: if closed:
break break
if __name__ == '__main__': if __name__ == '__main__':
parser = optparse.OptionParser(usage="%prog [options] listen_port") parser = optparse.OptionParser(usage="%prog [options] listen_port")
parser.add_option("--verbose", "-v", action="store_true", parser.add_option("--verbose", "-v", action="store_true",
help="verbose messages and per frame traffic") help="verbose messages and per frame traffic")
parser.add_option("--cert", default="self.pem", parser.add_option("--cert", default="self.pem",
help="SSL certificate file") help="SSL certificate file")
parser.add_option("--key", default=None, parser.add_option("--key", default=None,
help="SSL key file (if separate from cert)") help="SSL key file (if separate from cert)")
parser.add_option("--ssl-only", action="store_true", parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted connections") help="disallow non-encrypted connections")
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
try: try:
if len(args) != 1: raise ValueError if len(args) != 1:
raise ValueError
opts.listen_port = int(args[0]) opts.listen_port = int(args[0])
except ValueError: except ValueError:
parser.error("Invalid arguments") parser.error("Invalid arguments")
@ -73,4 +85,3 @@ if __name__ == '__main__':
opts.web = "." opts.web = "."
server = WebSockifyServer(WebSocketEcho, **opts.__dict__) server = WebSockifyServer(WebSocketEcho, **opts.__dict__)
server.start_server() server.start_server()

View File

@ -1,13 +1,16 @@
#!/usr/bin/env python #!/usr/bin/env python
# flake8: noqa: E402
import os
import sys
import optparse import optparse
import os
import select import select
import sys
sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websocket import WebSocket, \
WebSocketWantReadError, WebSocketWantWriteError from websockify.websocket import WebSocket
from websockify.websocket import WebSocketWantReadError
from websockify.websocket import WebSocketWantWriteError
parser = optparse.OptionParser(usage="%prog URL") parser = optparse.OptionParser(usage="%prog URL")
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
@ -22,6 +25,7 @@ print("Connecting to %s..." % URL)
sock.connect(URL) sock.connect(URL)
print("Connected.") print("Connected.")
def send(msg): def send(msg):
while True: while True:
try: try:
@ -30,11 +34,14 @@ def send(msg):
except WebSocketWantReadError: except WebSocketWantReadError:
msg = '' msg = ''
ins, outs, excepts = select.select([sock], [], []) ins, outs, excepts = select.select([sock], [], [])
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
except WebSocketWantWriteError: except WebSocketWantWriteError:
msg = '' msg = ''
ins, outs, excepts = select.select([], [sock], []) ins, outs, excepts = select.select([], [sock], [])
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
def read(): def read():
while True: while True:
@ -42,10 +49,13 @@ def read():
return sock.recvmsg() return sock.recvmsg()
except WebSocketWantReadError: except WebSocketWantReadError:
ins, outs, excepts = select.select([sock], [], []) ins, outs, excepts = select.select([sock], [], [])
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
except WebSocketWantWriteError: except WebSocketWantWriteError:
ins, outs, excepts = select.select([], [sock], []) ins, outs, excepts = select.select([], [sock], [])
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
counter = 1 counter = 1
while True: while True:
@ -56,7 +66,8 @@ while True:
while True: while True:
ins, outs, excepts = select.select([sock], [], [], 1.0) ins, outs, excepts = select.select([sock], [], [], 1.0)
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
if ins == []: if ins == []:
break break

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
# flake8: noqa: E402
''' '''
WebSocket server-side load test program. Sends and receives traffic WebSocket server-side load test program. Sends and receives traffic
@ -6,9 +7,19 @@ that has a random payload (length and content) that is checksummed and
given a sequence number. Any errors are reported and counted. given a sequence number. Any errors are reported and counted.
''' '''
import sys, os, select, random, time, optparse, logging import logging
sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) import optparse
from websockify.websockifyserver import WebSockifyServer, WebSockifyRequestHandler import os
import random
import select
import sys
import time
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from websockify.websockifyserver import WebSockifyRequestHandler
from websockify.websockifyserver import WebSockifyServer
class WebSocketLoadServer(WebSockifyServer): class WebSocketLoadServer(WebSockifyServer):
@ -26,7 +37,7 @@ class WebSocketLoad(WebSockifyRequestHandler):
max_packet_size = 10000 max_packet_size = 10000
def new_websocket_client(self): def new_websocket_client(self):
print "Prepopulating random array" print("Prepopulating random array")
self.rand_array = [] self.rand_array = []
for i in range(0, self.max_packet_size): for i in range(0, self.max_packet_size):
self.rand_array.append(random.randint(0, 9)) self.rand_array.append(random.randint(0, 9))
@ -37,19 +48,18 @@ class WebSocketLoad(WebSockifyRequestHandler):
self.responder(self.request) self.responder(self.request)
print "accumulated errors:", self.errors print("accumulated errors:", self.errors)
self.errors = 0 self.errors = 0
def responder(self, client): def responder(self, client):
c_pend = 0 c_pend = 0
cqueue = []
cpartial = ""
socks = [client] socks = [client]
last_send = time.time() * 1000 last_send = time.time() * 1000
while True: while True:
ins, outs, excepts = select.select(socks, socks, socks, 1) ins, outs, excepts = select.select(socks, socks, socks, 1)
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
if client in ins: if client in ins:
frames, closed = self.recv_frames() frames, closed = self.recv_frames()
@ -57,7 +67,7 @@ class WebSocketLoad(WebSockifyRequestHandler):
err = self.check(frames) err = self.check(frames)
if err: if err:
self.errors = self.errors + 1 self.errors = self.errors + 1
print err print(err)
if closed: if closed:
break break
@ -73,19 +83,14 @@ class WebSocketLoad(WebSockifyRequestHandler):
def generate(self): def generate(self):
length = random.randint(10, self.max_packet_size) length = random.randint(10, self.max_packet_size)
numlist = self.rand_array[self.max_packet_size-length:] numlist = self.rand_array[self.max_packet_size - length:]
# Error in length
#numlist.append(5)
chksum = sum(numlist) chksum = sum(numlist)
# Error in checksum nums = "".join([str(n) for n in numlist])
#numlist[0] = 5
nums = "".join( [str(n) for n in numlist] )
data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums) data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums)
self.send_cnt += 1 self.send_cnt += 1
return data return data
def check(self, frames): def check(self, frames):
err = "" err = ""
@ -102,11 +107,11 @@ class WebSocketLoad(WebSockifyRequestHandler):
try: try:
cnt, length, chksum, nums = data[1:-1].split(':') cnt, length, chksum, nums = data[1:-1].split(':')
cnt = int(cnt) cnt = int(cnt)
length = int(length) length = int(length)
chksum = int(chksum) chksum = int(chksum)
except ValueError: except ValueError:
print "\n<BOF>" + repr(data) + "<EOF>" print("\n<BOF>" + repr(data) + "<EOF>")
err += "Invalid data format\n" err += "Invalid data format\n"
continue continue
@ -138,20 +143,22 @@ class WebSocketLoad(WebSockifyRequestHandler):
if __name__ == '__main__': if __name__ == '__main__':
parser = optparse.OptionParser(usage="%prog [options] listen_port") parser = optparse.OptionParser(usage="%prog [options] listen_port")
parser.add_option("--verbose", "-v", action="store_true", parser.add_option("--verbose", "-v", action="store_true",
help="verbose messages and per frame traffic") help="verbose messages and per frame traffic")
parser.add_option("--cert", default="self.pem", parser.add_option("--cert", default="self.pem",
help="SSL certificate file") help="SSL certificate file")
parser.add_option("--key", default=None, parser.add_option("--key", default=None,
help="SSL key file (if separate from cert)") help="SSL key file (if separate from cert)")
parser.add_option("--ssl-only", action="store_true", parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted connections") help="disallow non-encrypted connections")
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
try: try:
if len(args) != 1: raise ValueError if len(args) != 1:
raise ValueError
opts.listen_port = int(args[0]) opts.listen_port = int(args[0])
if len(args) not in [1,2]: raise ValueError if len(args) not in [1, 2]:
raise ValueError
opts.listen_port = int(args[0]) opts.listen_port = int(args[0])
if len(args) == 2: if len(args) == 2:
opts.delay = int(args[1]) opts.delay = int(args[1])
@ -165,4 +172,3 @@ if __name__ == '__main__':
opts.web = "." opts.web = "."
server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__) server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__)
server.start_server() server.start_server()

View File

@ -7,8 +7,14 @@ import unittest
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from jwcrypto import jwt, jwk from jwcrypto import jwt, jwk
try:
import redis
except ImportError:
redis = None
from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis
class ParseSourceArgumentsTestCase(unittest.TestCase): class ParseSourceArgumentsTestCase(unittest.TestCase):
def test_parameterized(self): def test_parameterized(self):
params = [ params = [
@ -31,6 +37,7 @@ class ParseSourceArgumentsTestCase(unittest.TestCase):
for src, args in params: for src, args in params:
self.assertEqual(args, parse_source_args(src)) self.assertEqual(args, parse_source_args(src))
class ReadOnlyTokenFileTestCase(unittest.TestCase): class ReadOnlyTokenFileTestCase(unittest.TestCase):
def test_empty(self): def test_empty(self):
mock_source_file = MagicMock() mock_source_file = MagicMock()
@ -143,7 +150,7 @@ class JWSTokenTestCase(unittest.TestCase):
key = jwk.JWK() key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read() private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key) key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 }) jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200})
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
mock_time.return_value = 150 mock_time.return_value = 150
@ -160,7 +167,7 @@ class JWSTokenTestCase(unittest.TestCase):
key = jwk.JWK() key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read() private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key) key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 }) jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200})
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
mock_time.return_value = 50 mock_time.return_value = 50
@ -175,7 +182,7 @@ class JWSTokenTestCase(unittest.TestCase):
key = jwk.JWK() key = jwk.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read() private_key = open("./tests/fixtures/private.pem", "rb").read()
key.import_from_pem(private_key) key.import_from_pem(private_key)
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200 }) jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port", 'nbf': 100, 'exp': 200})
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
mock_time.return_value = 250 mock_time.return_value = 250
@ -188,7 +195,7 @@ class JWSTokenTestCase(unittest.TestCase):
secret = open("./tests/fixtures/symmetric.key").read() secret = open("./tests/fixtures/symmetric.key").read()
key = jwk.JWK() key = jwk.JWK()
key.import_key(kty="oct",k=secret) key.import_key(kty="oct", k=secret)
jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"}) jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"})
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
@ -203,7 +210,7 @@ class JWSTokenTestCase(unittest.TestCase):
secret = open("./tests/fixtures/symmetric.key").read() secret = open("./tests/fixtures/symmetric.key").read()
key = jwk.JWK() key = jwk.JWK()
key.import_key(kty="oct",k=secret) key.import_key(kty="oct", k=secret)
jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"}) jwt_token = jwt.JWT({"alg": "HS256"}, {'host': "remote_host", 'port': "remote_port"})
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
@ -223,7 +230,7 @@ class JWSTokenTestCase(unittest.TestCase):
jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"}) jwt_token = jwt.JWT({"alg": "RS256"}, {'host': "remote_host", 'port': "remote_port"})
jwt_token.make_signed_token(private_key) jwt_token.make_signed_token(private_key)
jwe_token = jwt.JWT(header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"}, jwe_token = jwt.JWT(header={"alg": "RSA-OAEP", "enc": "A256CBC-HS512"},
claims=jwt_token.serialize()) claims=jwt_token.serialize())
jwe_token.make_encrypted_token(public_key) jwe_token.make_encrypted_token(public_key)
result = plugin.lookup(jwt_token.serialize()) result = plugin.lookup(jwt_token.serialize())
@ -232,11 +239,10 @@ class JWSTokenTestCase(unittest.TestCase):
self.assertEqual(result[0], "remote_host") self.assertEqual(result[0], "remote_host")
self.assertEqual(result[1], "remote_port") self.assertEqual(result[1], "remote_port")
class TokenRedisTestCase(unittest.TestCase): class TokenRedisTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
try: if redis is None:
import redis
except ImportError:
patcher = patch.dict(sys.modules, {'redis': MagicMock()}) patcher = patch.dict(sys.modules, {'redis': MagicMock()})
patcher.start() patcher.start()
self.addCleanup(patcher.stop) self.addCleanup(patcher.stop)

View File

@ -18,6 +18,7 @@
import unittest import unittest
from websockify import websocket from websockify import websocket
class FakeSocket: class FakeSocket:
def __init__(self): def __init__(self):
self.data = b'' self.data = b''
@ -26,6 +27,7 @@ class FakeSocket:
self.data += buf self.data += buf
return len(buf) return len(buf)
class AcceptTestCase(unittest.TestCase): class AcceptTestCase(unittest.TestCase):
def test_success(self): def test_success(self):
ws = websocket.WebSocket() ws = websocket.WebSocket()
@ -81,7 +83,7 @@ class AcceptTestCase(unittest.TestCase):
ws.accept(sock, {'upgrade': 'websocket', ws.accept(sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13', 'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'}) 'Sec-WebSocket-Protocol': 'foobar,gazonk'})
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ') self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
self.assertTrue(b'\r\nSec-WebSocket-Protocol: gazonk\r\n' in sock.data) self.assertTrue(b'\r\nSec-WebSocket-Protocol: gazonk\r\n' in sock.data)
@ -101,9 +103,9 @@ class AcceptTestCase(unittest.TestCase):
sock, {'upgrade': 'websocket', sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13', 'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'}) 'Sec-WebSocket-Protocol': 'foobar,gazonk'})
def test_protocol(self): def test_unsupported_protocol(self):
class ProtoSocket(websocket.WebSocket): class ProtoSocket(websocket.WebSocket):
def select_subprotocol(self, protocol): def select_subprotocol(self, protocol):
return 'oddball' return 'oddball'
@ -114,7 +116,8 @@ class AcceptTestCase(unittest.TestCase):
sock, {'upgrade': 'websocket', sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13', 'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'}) 'Sec-WebSocket-Protocol': 'foobar,gazonk'})
class PingPongTest(unittest.TestCase): class PingPongTest(unittest.TestCase):
def setUp(self): def setUp(self):
@ -142,6 +145,7 @@ class PingPongTest(unittest.TestCase):
self.ws.pong(b'foo') self.ws.pong(b'foo')
self.assertEqual(self.sock.data, b'\x8a\x03foo') self.assertEqual(self.sock.data, b'\x8a\x03foo')
class HyBiEncodeDecodeTestCase(unittest.TestCase): class HyBiEncodeDecodeTestCase(unittest.TestCase):
def test_decode_hybi_text(self): def test_decode_hybi_text(self):
buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58' buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58'

View File

@ -16,8 +16,6 @@
""" Unit tests for websocketproxy """ """ Unit tests for websocketproxy """
import sys
import unittest
import unittest import unittest
import socket import socket
from io import StringIO from io import StringIO
@ -58,6 +56,7 @@ class FakeServer:
self.ssl_target = None self.ssl_target = None
self.unix_target = None self.unix_target = None
class ProxyRequestHandlerTestCase(unittest.TestCase): class ProxyRequestHandlerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
@ -128,4 +127,3 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
self.handler.server.target_host = "someotherhost" self.handler.server.target_host = "someotherhost"
self.handler.auth_connection() self.handler.auth_connection()

View File

@ -66,4 +66,3 @@ class HttpWebSocketTest(unittest.TestCase):
# Then # Then
req_obj.end_headers.assert_called_once_with() req_obj.end_headers.assert_called_once_with()

View File

@ -17,18 +17,12 @@
""" Unit tests for websockifyserver """ """ Unit tests for websockifyserver """
import errno import errno
import os import os
import logging
import select
import shutil
import socket import socket
import ssl import ssl
from unittest.mock import patch, MagicMock, ANY from unittest.mock import patch
import sys import sys
import tempfile import tempfile
import unittest import unittest
import socket
import signal
from http.server import BaseHTTPRequestHandler
from io import StringIO from io import StringIO
from io import BytesIO from io import BytesIO
@ -237,6 +231,7 @@ class WebSockifyServerTestCase(unittest.TestCase):
def test_do_handshake_no_ssl(self): def test_do_handshake_no_ssl(self):
class FakeHandler: class FakeHandler:
CALLED = False CALLED = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
type(self).CALLED = True type(self).CALLED = True
@ -292,12 +287,16 @@ class WebSockifyServerTestCase(unittest.TestCase):
def __init__(self, purpose): def __init__(self, purpose):
self.verify_mode = None self.verify_mode = None
self.options = 0 self.options = 0
def load_cert_chain(self, certfile, keyfile, password): def load_cert_chain(self, certfile, keyfile, password):
pass pass
def set_default_verify_paths(self): def set_default_verify_paths(self):
pass pass
def load_verify_locations(self, cafile): def load_verify_locations(self, cafile):
pass pass
def wrap_socket(self, *args, **kwargs): def wrap_socket(self, *args, **kwargs):
raise ssl.SSLError(ssl.SSL_ERROR_EOF) raise ssl.SSLError(ssl.SSL_ERROR_EOF)
@ -314,7 +313,7 @@ class WebSockifyServerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
server = self._get_server(handler_class=FakeHandler, daemon=True, server = self._get_server(handler_class=FakeHandler, daemon=True,
idle_timeout=1, ssl_ciphers=test_ciphers) idle_timeout=1, ssl_ciphers=test_ciphers)
sock = FakeSocket(b"\x16some ssl data") sock = FakeSocket(b"\x16some ssl data")
@ -323,17 +322,23 @@ class WebSockifyServerTestCase(unittest.TestCase):
class fake_create_default_context(): class fake_create_default_context():
CIPHERS = '' CIPHERS = ''
def __init__(self, purpose): def __init__(self, purpose):
self.verify_mode = None self.verify_mode = None
self.options = 0 self.options = 0
def load_cert_chain(self, certfile, keyfile, password): def load_cert_chain(self, certfile, keyfile, password):
pass pass
def set_default_verify_paths(self): def set_default_verify_paths(self):
pass pass
def load_verify_locations(self, cafile): def load_verify_locations(self, cafile):
pass pass
def wrap_socket(self, *args, **kwargs): def wrap_socket(self, *args, **kwargs):
pass pass
def set_ciphers(self, ciphers_to_set): def set_ciphers(self, ciphers_to_set):
fake_create_default_context.CIPHERS = ciphers_to_set fake_create_default_context.CIPHERS = ciphers_to_set
@ -349,7 +354,7 @@ class WebSockifyServerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
server = self._get_server(handler_class=FakeHandler, daemon=True, server = self._get_server(handler_class=FakeHandler, daemon=True,
idle_timeout=1, ssl_options=test_options) idle_timeout=1, ssl_options=test_options)
sock = FakeSocket(b"\x16some ssl data") sock = FakeSocket(b"\x16some ssl data")
@ -358,19 +363,26 @@ class WebSockifyServerTestCase(unittest.TestCase):
class fake_create_default_context: class fake_create_default_context:
OPTIONS = 0 OPTIONS = 0
def __init__(self, purpose): def __init__(self, purpose):
self.verify_mode = None self.verify_mode = None
self._options = 0 self._options = 0
def load_cert_chain(self, certfile, keyfile, password): def load_cert_chain(self, certfile, keyfile, password):
pass pass
def set_default_verify_paths(self): def set_default_verify_paths(self):
pass pass
def load_verify_locations(self, cafile): def load_verify_locations(self, cafile):
pass pass
def wrap_socket(self, *args, **kwargs): def wrap_socket(self, *args, **kwargs):
pass pass
def get_options(self): def get_options(self):
return self._options return self._options
def set_options(self, val): def set_options(self, val):
fake_create_default_context.OPTIONS = val fake_create_default_context.OPTIONS = val
options = property(get_options, set_options) options = property(get_options, set_options)
@ -386,7 +398,6 @@ class WebSockifyServerTestCase(unittest.TestCase):
def test_start_server_error(self): def test_start_server_error(self):
server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1) server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
sock = server.socket('localhost')
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
raise Exception("fake error") raise Exception("fake error")
@ -398,7 +409,6 @@ class WebSockifyServerTestCase(unittest.TestCase):
def test_start_server_keyboardinterrupt(self): def test_start_server_keyboardinterrupt(self):
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
sock = server.socket('localhost')
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
raise KeyboardInterrupt raise KeyboardInterrupt
@ -410,7 +420,6 @@ class WebSockifyServerTestCase(unittest.TestCase):
def test_start_server_systemexit(self): def test_start_server_systemexit(self):
server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
sock = server.socket('localhost')
def fake_select(rlist, wlist, xlist, timeout=None): def fake_select(rlist, wlist, xlist, timeout=None):
sys.exit() sys.exit()

13
tox.ini
View File

@ -12,6 +12,13 @@ deps = -r{toxinidir}/test-requirements.txt
# At some point we should enable this since tox expects it to exist but # At some point we should enable this since tox expects it to exist but
# the code will need pep8ising first. # the code will need pep8ising first.
#[testenv:pep8] [testenv:pep8]
#commands = flake8 commands = flake8
#dep = flake8 deps = flake8
[flake8]
max-line-length = 160
# E129 visually indented line with same indent as next logical line
# W503 line break before binary operator
# W504 line break after binary operator
ignore = E129,W503,W504

View File

@ -1,2 +1,2 @@
from websockify.websocket import * from websockify.websocket import * # noqa: F401,F403
from websockify.websocketproxy import * from websockify.websocketproxy import * # noqa: F401,F403

View File

@ -76,6 +76,7 @@ class BasicHTTPAuth():
raise AuthenticationError(response_code=401, raise AuthenticationError(response_code=401,
response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'}) response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'})
class ExpectOrigin(): class ExpectOrigin():
def __init__(self, src=None): def __init__(self, src=None):
if src is None: if src is None:
@ -88,6 +89,7 @@ class ExpectOrigin():
if origin is None or origin not in self.source: if origin is None or origin not in self.source:
raise InvalidOriginError(expected=self.source, actual=origin) raise InvalidOriginError(expected=self.source, actual=origin)
class ClientCertCNAuth(): class ClientCertCNAuth():
"""Verifies client by SSL certificate. Specify src as whitespace separated list of common names.""" """Verifies client by SSL certificate. Specify src as whitespace separated list of common names."""

View File

@ -1,4 +1,7 @@
import logging.handlers as handlers, socket, os, time import logging.handlers as handlers
import os
import socket
import time
class WebsockifySysLogHandler(handlers.SysLogHandler): class WebsockifySysLogHandler(handlers.SysLogHandler):
@ -13,14 +16,11 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
_legacy = False _legacy = False
_timestamp_fmt = '%Y-%m-%dT%H:%M:%SZ' _timestamp_fmt = '%Y-%m-%dT%H:%M:%SZ'
_max_hostname = 255 _max_hostname = 255
_max_ident = 24 #safer for old daemons _max_ident = 24 # safer for old daemons
_send_length = False _send_length = False
_tail = '\n' _tail = '\n'
ident = None ident = None
def __init__(self, address=('localhost', handlers.SYSLOG_UDP_PORT), def __init__(self, address=('localhost', handlers.SYSLOG_UDP_PORT),
facility=handlers.SysLogHandler.LOG_USER, facility=handlers.SysLogHandler.LOG_USER,
socktype=None, ident=None, legacy=False): socktype=None, ident=None, legacy=False):
@ -46,7 +46,6 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
super().__init__(address, facility, socktype) super().__init__(address, facility, socktype)
def emit(self, record): def emit(self, record):
""" """
Emit a record. Emit a record.
@ -58,13 +57,13 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
try: try:
# Gather info. # Gather info.
text = self.format(record).replace(self._tail, ' ') text = self.format(record).replace(self._tail, ' ')
if not text: # nothing to log if not text: # nothing to log
return return
pri = self.encodePriority(self.facility, pri = self.encodePriority(self.facility,
self.mapPriority(record.levelname)) self.mapPriority(record.levelname))
timestamp = time.strftime(self._timestamp_fmt, time.gmtime()); timestamp = time.strftime(self._timestamp_fmt, time.gmtime())
hostname = socket.gethostname()[:self._max_hostname] hostname = socket.gethostname()[:self._max_hostname]
@ -73,7 +72,7 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
else: else:
ident = '' ident = ''
pid = os.getpid() # shouldn't need truncation pid = os.getpid() # shouldn't need truncation
# Format the header. # Format the header.
head = { head = {
@ -112,7 +111,5 @@ class WebsockifySysLogHandler(handlers.SysLogHandler):
else: else:
self.socket.sendall(msg) self.socket.sendall(msg)
except (KeyboardInterrupt, SystemExit): except Exception:
raise
except:
self.handleError(record) self.handleError(record)

View File

@ -5,6 +5,11 @@ import re
import json import json
from pathlib import Path from pathlib import Path
try:
import redis
except ImportError:
redis = None
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_SOURCE_SPLIT_REGEX = re.compile( _SOURCE_SPLIT_REGEX = re.compile(
@ -84,6 +89,7 @@ class TokenFile(ReadOnlyTokenFile):
return super().lookup(token) return super().lookup(token)
class TokenFileName(BasePlugin): class TokenFileName(BasePlugin):
# source is a directory # source is a directory
# token is filename # token is filename
@ -113,8 +119,8 @@ class BaseTokenAPI(BasePlugin):
def process_result(self, resp): def process_result(self, resp):
host, port = resp.text.split(':') host, port = resp.text.split(':')
port = port.encode('ascii','ignore') port = port.encode('ascii', 'ignore')
return [ host, port ] return [host, port]
def lookup(self, token): def lookup(self, token):
import requests import requests
@ -156,10 +162,10 @@ class JWTTokenApi(BasePlugin):
try: try:
key.import_from_pem(key_data) key.import_from_pem(key_data)
except: except Exception:
try: try:
key.import_key(k=key_data.decode('utf-8'),kty='oct') key.import_key(k=key_data.decode('utf-8'), kty='oct')
except: except Exception:
logger.error('Failed to correctly parse key data!') logger.error('Failed to correctly parse key data!')
return None return None
@ -255,9 +261,7 @@ class TokenRedis(BasePlugin):
pip install redis pip install redis
""" """
def __init__(self, src): def __init__(self, src):
try: if redis is None:
import redis
except ImportError:
logger.error("Unable to load redis module") logger.error("Unable to load redis module")
sys.exit() sys.exit()
# Default values # Default values
@ -305,7 +309,7 @@ class TokenRedis(BasePlugin):
self._namespace += ":" self._namespace += ":"
logger.info("TokenRedis backend initialized (%s:%s)" % logger.info("TokenRedis backend initialized (%s:%s)" %
(self._server, self._port)) (self._server, self._port))
except ValueError: except ValueError:
logger.error("The provided --token-source='%s' is not in the " logger.error("The provided --token-source='%s' is not in the "
"expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]" % "expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]" %
@ -313,9 +317,7 @@ class TokenRedis(BasePlugin):
sys.exit() sys.exit()
def lookup(self, token): def lookup(self, token):
try: if redis is None:
import redis
except ImportError:
logger.error("package redis not found, are you sure you've installed them correctly?") logger.error("package redis not found, are you sure you've installed them correctly?")
sys.exit() sys.exit()
@ -372,7 +374,7 @@ class UnixDomainSocketDirectory(BasePlugin):
if not stat.S_ISSOCK(uds_path.stat().st_mode): if not stat.S_ISSOCK(uds_path.stat().st_mode):
return None return None
return [ 'unix_socket', uds_path ] return ['unix_socket', uds_path]
except Exception as e: except Exception as e:
logger.error("Error finding unix domain socket: %s" % str(e)) logger.error("Error finding unix domain socket: %s" % str(e))
return None return None

View File

@ -31,11 +31,15 @@ except ImportError:
warnings.warn("no 'numpy' module, HyBi protocol will be slower") warnings.warn("no 'numpy' module, HyBi protocol will be slower")
numpy = None numpy = None
class WebSocketWantReadError(ssl.SSLWantReadError): class WebSocketWantReadError(ssl.SSLWantReadError):
pass pass
class WebSocketWantWriteError(ssl.SSLWantWriteError): class WebSocketWantWriteError(ssl.SSLWantWriteError):
pass pass
class WebSocket: class WebSocket:
"""WebSocket protocol socket like class. """WebSocket protocol socket like class.
@ -118,7 +122,7 @@ class WebSocket:
connect() must retain the same arguments. connect() must retain the same arguments.
""" """
self.client = True; self.client = True
uri = urlparse(uri) uri = urlparse(uri)
@ -206,7 +210,7 @@ class WebSocket:
accept = headers.get('Sec-WebSocket-Accept') accept = headers.get('Sec-WebSocket-Accept')
if accept is None: if accept is None:
raise Exception("Missing Sec-WebSocket-Accept header"); raise Exception("Missing Sec-WebSocket-Accept header")
expected = sha1((self._key + self.GUID).encode("ascii")).digest() expected = sha1((self._key + self.GUID).encode("ascii")).digest()
expected = b64encode(expected).decode("ascii") expected = b64encode(expected).decode("ascii")
@ -214,7 +218,7 @@ class WebSocket:
del self._key del self._key
if accept != expected: if accept != expected:
raise Exception("Invalid Sec-WebSocket-Accept header"); raise Exception("Invalid Sec-WebSocket-Accept header")
self.protocol = headers.get('Sec-WebSocket-Protocol') self.protocol = headers.get('Sec-WebSocket-Protocol')
if len(protocols) == 0: if len(protocols) == 0:
@ -258,7 +262,7 @@ class WebSocket:
ver = headers.get('Sec-WebSocket-Version') ver = headers.get('Sec-WebSocket-Version')
if ver is None: if ver is None:
raise Exception("Missing Sec-WebSocket-Version header"); raise Exception("Missing Sec-WebSocket-Version header")
# HyBi-07 report version 7 # HyBi-07 report version 7
# HyBi-08 - HyBi-12 report version 8 # HyBi-08 - HyBi-12 report version 8
@ -270,7 +274,7 @@ class WebSocket:
key = headers.get('Sec-WebSocket-Key') key = headers.get('Sec-WebSocket-Key')
if key is None: if key is None:
raise Exception("Missing Sec-WebSocket-Key header"); raise Exception("Missing Sec-WebSocket-Key header")
# Generate the hash value for the accept header # Generate the hash value for the accept header
accept = sha1((key + self.GUID).encode("ascii")).digest() accept = sha1((key + self.GUID).encode("ascii")).digest()
@ -753,25 +757,22 @@ class WebSocket:
# Unmask a frame # Unmask a frame
if numpy: if numpy:
plen = len(buf) plen = len(buf)
pstart = 0
pend = plen
b = c = b'' b = c = b''
if plen >= 4: if plen >= 4:
dtype=numpy.dtype('<u4') dtype = numpy.dtype('<u4')
if sys.byteorder == 'big': if sys.byteorder == 'big':
dtype = dtype.newbyteorder('>') dtype = dtype.newbyteorder('>')
mask = numpy.frombuffer(mask, dtype, count=1) mask = numpy.frombuffer(mask, dtype, count=1)
data = numpy.frombuffer(buf, dtype, count=int(plen / 4)) data = numpy.frombuffer(buf, dtype, count=int(plen / 4))
#b = numpy.bitwise_xor(data, mask).data
b = numpy.bitwise_xor(data, mask).tobytes() b = numpy.bitwise_xor(data, mask).tobytes()
if plen % 4: if plen % 4:
dtype=numpy.dtype('B') dtype = numpy.dtype('B')
if sys.byteorder == 'big': if sys.byteorder == 'big':
dtype = dtype.newbyteorder('>') dtype = dtype.newbyteorder('>')
mask = numpy.frombuffer(mask, dtype, count=(plen % 4)) mask = numpy.frombuffer(mask, dtype, count=(plen % 4))
data = numpy.frombuffer(buf, dtype, data = numpy.frombuffer(buf, dtype,
offset=plen - (plen % 4), count=(plen % 4)) offset=plen - (plen % 4), count=(plen % 4))
c = numpy.bitwise_xor(data, mask).tobytes() c = numpy.bitwise_xor(data, mask).tobytes()
return b + c return b + c
else: else:
@ -825,11 +826,11 @@ class WebSocket:
'payload' : decoded_buffer} 'payload' : decoded_buffer}
""" """
f = {'fin' : 0, f = {'fin': 0,
'opcode' : 0, 'opcode': 0,
'masked' : False, 'masked': False,
'length' : 0, 'length': 0,
'payload' : None} 'payload': None}
blen = len(buf) blen = len(buf)
hlen = 2 hlen = 2
@ -867,10 +868,9 @@ class WebSocket:
if f['masked']: if f['masked']:
# unmask payload # unmask payload
mask_key = buf[hlen-4:hlen] mask_key = buf[hlen - 4:hlen]
f['payload'] = self._unmask(buf[hlen:(hlen+length)], mask_key) f['payload'] = self._unmask(buf[hlen:(hlen + length)], mask_key)
else: else:
f['payload'] = buf[hlen:(hlen+length)] f['payload'] = buf[hlen:(hlen + length)]
return f return f

View File

@ -11,14 +11,26 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl, stat import errno
from socketserver import ThreadingMixIn
from http.server import HTTPServer from http.server import HTTPServer
import logging
import optparse
import os
import select import select
import signal
import socket
from socketserver import ThreadingMixIn
import ssl
import stat
import subprocess
import sys
import time
from urllib.parse import parse_qs
from urllib.parse import urlparse
from websockify import websockifyserver from websockify import websockifyserver
from websockify import auth_plugins as auth from websockify import auth_plugins as auth
from urllib.parse import parse_qs, urlparse
class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler): class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler):
@ -62,7 +74,7 @@ Traffic Legend:
# clear out any existing SSL_ headers that the client might # clear out any existing SSL_ headers that the client might
# have maliciously set # have maliciously set
ssl_headers = [ h for h in self.headers if h.startswith('SSL_') ] ssl_headers = [h for h in self.headers if h.startswith('SSL_')]
for h in ssl_headers: for h in ssl_headers:
del self.headers[h] del self.headers[h]
@ -101,7 +113,7 @@ Traffic Legend:
msg = "connecting to unix socket: %s" % self.server.unix_target msg = "connecting to unix socket: %s" % self.server.unix_target
else: else:
msg = "connecting to: %s:%s" % ( msg = "connecting to: %s:%s" % (
self.server.target_host, self.server.target_port) self.server.target_host, self.server.target_port)
if self.server.ssl_target: if self.server.ssl_target:
msg += " (using SSL)" msg += " (using SSL)"
@ -109,10 +121,10 @@ Traffic Legend:
try: try:
tsock = websockifyserver.WebSockifyServer.socket(self.server.target_host, tsock = websockifyserver.WebSockifyServer.socket(self.server.target_host,
self.server.target_port, self.server.target_port,
connect=True, connect=True,
use_ssl=self.server.ssl_target, use_ssl=self.server.ssl_target,
unix_socket=self.server.unix_target) unix_socket=self.server.unix_target)
except Exception as e: except Exception as e:
self.log_message("Failed to connect to %s:%s: %s", self.log_message("Failed to connect to %s:%s: %s",
self.server.target_host, self.server.target_port, e) self.server.target_host, self.server.target_port, e)
@ -135,7 +147,7 @@ Traffic Legend:
tsock.close() tsock.close()
if self.verbose: if self.verbose:
self.log_message("%s:%s: Closed target", self.log_message("%s:%s: Closed target",
self.server.target_host, self.server.target_port) self.server.target_host, self.server.target_port)
def get_target(self, target_plugin): def get_target(self, target_plugin):
""" """
@ -159,7 +171,7 @@ Traffic Legend:
else: else:
# Extract the token parameter from url # Extract the token parameter from url
args = parse_qs(urlparse(self.path)[4]) # 4 is the query from url args = parse_qs(urlparse(self.path)[4]) # 4 is the query from url
if 'token' in args and len(args['token']): if 'token' in args and len(args['token']):
token = args['token'][0].rstrip('\n') token = args['token'][0].rstrip('\n')
@ -200,8 +212,10 @@ Traffic Legend:
self.heartbeat = now + self.server.heartbeat self.heartbeat = now + self.server.heartbeat
self.send_ping() self.send_ping()
if tqueue: wlist.append(target) if tqueue:
if cqueue or c_pend: wlist.append(self.request) wlist.append(target)
if cqueue or c_pend:
wlist.append(self.request)
try: try:
ins, outs, excepts = select.select(rlist, wlist, [], 1) ins, outs, excepts = select.select(rlist, wlist, [], 1)
except OSError: except OSError:
@ -216,7 +230,8 @@ Traffic Legend:
else: else:
continue continue
if excepts: raise Exception("Socket exception") if excepts:
raise Exception("Socket exception")
if self.request in outs: if self.request in outs:
# Send queued target data to the client # Send queued target data to the client
@ -245,10 +260,9 @@ Traffic Legend:
# TODO: What about blocking on client socket? # TODO: What about blocking on client socket?
if self.verbose: if self.verbose:
self.log_message("%s:%s: Client closed connection", self.log_message("%s:%s: Client closed connection",
self.server.target_host, self.server.target_port) self.server.target_host, self.server.target_port)
raise self.CClose(closed['code'], closed['reason']) raise self.CClose(closed['code'], closed['reason'])
if target in outs: if target in outs:
# Send queued client data to the target # Send queued client data to the target
dat = tqueue.pop(0) dat = tqueue.pop(0)
@ -260,7 +274,6 @@ Traffic Legend:
tqueue.insert(0, dat[sent:]) tqueue.insert(0, dat[sent:])
self.print_traffic(".>") self.print_traffic(".>")
if target in ins: if target in ins:
# Receive target data, encode it and queue for client # Receive target data, encode it and queue for client
buf = target.recv(self.buffer_size) buf = target.recv(self.buffer_size)
@ -270,19 +283,20 @@ Traffic Legend:
# Send queued target data to the client # Send queued target data to the client
if len(cqueue) != 0: if len(cqueue) != 0:
c_pend = True c_pend = True
while(c_pend): while (c_pend):
c_pend = self.send_frames(cqueue) c_pend = self.send_frames(cqueue)
cqueue = [] cqueue = []
if self.verbose: if self.verbose:
self.log_message("%s:%s: Target closed connection", self.log_message("%s:%s: Target closed connection",
self.server.target_host, self.server.target_port) self.server.target_host, self.server.target_port)
raise self.CClose(1000, "Target closed") raise self.CClose(1000, "Target closed")
cqueue.append(buf) cqueue.append(buf)
self.print_traffic("{") self.print_traffic("{")
class WebSocketProxy(websockifyserver.WebSockifyServer): class WebSocketProxy(websockifyserver.WebSockifyServer):
""" """
Proxy traffic to and from a WebSockets client to a normal TCP Proxy traffic to and from a WebSockets client to a normal TCP
@ -293,20 +307,20 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs): def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs):
# Save off proxy specific options # Save off proxy specific options
self.target_host = kwargs.pop('target_host', None) self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None) self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None) self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None) self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None) self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None) self.ssl_target = kwargs.pop('ssl_target', None)
self.heartbeat = kwargs.pop('heartbeat', None) self.heartbeat = kwargs.pop('heartbeat', None)
self.token_plugin = kwargs.pop('token_plugin', None) self.token_plugin = kwargs.pop('token_plugin', None)
self.host_token = kwargs.pop('host_token', None) self.host_token = kwargs.pop('host_token', None)
self.auth_plugin = kwargs.pop('auth_plugin', None) self.auth_plugin = kwargs.pop('auth_plugin', None)
# Last 3 timestamps command was run # Last 3 timestamps command was run
self.wrap_times = [0, 0, 0] self.wrap_times = [0, 0, 0]
if self.wrap_cmd: if self.wrap_cmd:
wsdir = os.path.dirname(sys.argv[0]) wsdir = os.path.dirname(sys.argv[0])
@ -334,7 +348,7 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
sock.close() sock.close()
# Insert rebinder at the head of the (possibly empty) LD_PRELOAD pathlist # Insert rebinder at the head of the (possibly empty) LD_PRELOAD pathlist
ld_preloads = filter(None, [ self.rebinder, os.environ.get("LD_PRELOAD", None) ]) ld_preloads = filter(None, [self.rebinder, os.environ.get("LD_PRELOAD", None)])
os.environ.update({ os.environ.update({
"LD_PRELOAD": os.pathsep.join(ld_preloads), "LD_PRELOAD": os.pathsep.join(ld_preloads),
@ -348,7 +362,7 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
self.wrap_times.append(time.time()) self.wrap_times.append(time.time())
self.wrap_times.pop(0) self.wrap_times.pop(0)
self.cmd = subprocess.Popen( self.cmd = subprocess.Popen(
self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup) self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup)
self.spawn_message = True self.spawn_message = True
def started(self): def started(self):
@ -364,7 +378,7 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
else: else:
dst_string = "%s:%s" % (self.target_host, self.target_port) dst_string = "%s:%s" % (self.target_host, self.target_port)
if self.listen_fd != None: if self.listen_fd is not None:
src_string = "inetd" src_string = "inetd"
else: else:
src_string = "%s:%s" % (self.listen_host, self.listen_port) src_string = "%s:%s" % (self.listen_host, self.listen_port)
@ -389,11 +403,11 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
if self.wrap_cmd and self.cmd: if self.wrap_cmd and self.cmd:
ret = self.cmd.poll() ret = self.cmd.poll()
if ret != None: if ret is not None:
self.vmsg("Wrapped command exited (or daemon). Returned %s" % ret) self.vmsg("Wrapped command exited (or daemon). Returned %s" % ret)
self.cmd = None self.cmd = None
if self.wrap_cmd and self.cmd == None: if self.wrap_cmd and self.cmd is None:
# Response to wrapped command being gone # Response to wrapped command being gone
if self.wrap_mode == "ignore": if self.wrap_mode == "ignore":
pass pass
@ -401,7 +415,7 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
sys.exit(ret) sys.exit(ret)
elif self.wrap_mode == "respawn": elif self.wrap_mode == "respawn":
now = time.time() now = time.time()
avg = sum(self.wrap_times)/len(self.wrap_times) avg = sum(self.wrap_times) / len(self.wrap_times)
if (now - avg) < 10: if (now - avg) < 10:
# 3 times in the last 10 seconds # 3 times in the last 10 seconds
if self.spawn_message: if self.spawn_message:
@ -427,6 +441,7 @@ SSL_OPTIONS = {
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2, ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2,
} }
def select_ssl_version(version): def select_ssl_version(version):
"""Returns SSL options for the most secure TSL version available on this """Returns SSL options for the most secure TSL version available on this
Python version""" Python version"""
@ -444,6 +459,7 @@ def select_ssl_version(version):
return SSL_OPTIONS[fallback] return SSL_OPTIONS[fallback]
def websockify_init(): def websockify_init():
# Setup basic logging to stderr. # Setup basic logging to stderr.
stderr_handler = logging.StreamHandler() stderr_handler = logging.StreamHandler()
@ -465,72 +481,72 @@ def websockify_init():
usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE" usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE"
parser = optparse.OptionParser(usage=usage) parser = optparse.OptionParser(usage=usage)
parser.add_option("--verbose", "-v", action="store_true", parser.add_option("--verbose", "-v", action="store_true",
help="verbose messages") help="verbose messages")
parser.add_option("--traffic", action="store_true", parser.add_option("--traffic", action="store_true",
help="per frame traffic") help="per frame traffic")
parser.add_option("--record", parser.add_option("--record",
help="record sessions to FILE.[session_number]", metavar="FILE") help="record sessions to FILE.[session_number]", metavar="FILE")
parser.add_option("--daemon", "-D", parser.add_option("--daemon", "-D",
dest="daemon", action="store_true", dest="daemon", action="store_true",
help="become a daemon (background process)") help="become a daemon (background process)")
parser.add_option("--run-once", action="store_true", parser.add_option("--run-once", action="store_true",
help="handle a single WebSocket connection and exit") help="handle a single WebSocket connection and exit")
parser.add_option("--timeout", type=int, default=0, parser.add_option("--timeout", type=int, default=0,
help="after TIMEOUT seconds exit when not connected") help="after TIMEOUT seconds exit when not connected")
parser.add_option("--idle-timeout", type=int, default=0, parser.add_option("--idle-timeout", type=int, default=0,
help="server exits after TIMEOUT seconds if there are no " help="server exits after TIMEOUT seconds if there are no "
"active connections") "active connections")
parser.add_option("--cert", default="self.pem", parser.add_option("--cert", default="self.pem",
help="SSL certificate file") help="SSL certificate file")
parser.add_option("--key", default=None, parser.add_option("--key", default=None,
help="SSL key file (if separate from cert)") help="SSL key file (if separate from cert)")
parser.add_option("--key-password", default=None, parser.add_option("--key-password", default=None,
help="SSL key password") help="SSL key password")
parser.add_option("--ssl-only", action="store_true", parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted client connections") help="disallow non-encrypted client connections")
parser.add_option("--ssl-target", action="store_true", parser.add_option("--ssl-target", action="store_true",
help="connect to SSL target as SSL client") help="connect to SSL target as SSL client")
parser.add_option("--verify-client", action="store_true", parser.add_option("--verify-client", action="store_true",
help="require encrypted client to present a valid certificate " help="require encrypted client to present a valid certificate "
"(needs Python 2.7.9 or newer or Python 3.4 or newer)") "(needs Python 2.7.9 or newer or Python 3.4 or newer)")
parser.add_option("--cafile", metavar="FILE", parser.add_option("--cafile", metavar="FILE",
help="file of concatenated certificates of authorities trusted " help="file of concatenated certificates of authorities trusted "
"for validating clients (only effective with --verify-client). " "for validating clients (only effective with --verify-client). "
"If omitted, system default list of CAs is used.") "If omitted, system default list of CAs is used.")
parser.add_option("--ssl-version", type="choice", default="default", parser.add_option("--ssl-version", type="choice", default="default",
choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"], action="store", choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"], action="store",
help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)") help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)")
parser.add_option("--ssl-ciphers", action="store", parser.add_option("--ssl-ciphers", action="store",
help="list of ciphers allowed for connection. For a list of " help="list of ciphers allowed for connection. For a list of "
"supported ciphers run `openssl ciphers`") "supported ciphers run `openssl ciphers`")
parser.add_option("--unix-listen", parser.add_option("--unix-listen",
help="listen to unix socket", metavar="FILE", default=None) help="listen to unix socket", metavar="FILE", default=None)
parser.add_option("--unix-listen-mode", default=None, parser.add_option("--unix-listen-mode", default=None,
help="specify mode for unix socket (defaults to 0600)") help="specify mode for unix socket (defaults to 0600)")
parser.add_option("--unix-target", parser.add_option("--unix-target",
help="connect to unix socket target", metavar="FILE") help="connect to unix socket target", metavar="FILE")
parser.add_option("--inetd", parser.add_option("--inetd",
help="inetd mode, receive listening socket from stdin", action="store_true") help="inetd mode, receive listening socket from stdin", action="store_true")
parser.add_option("--web", default=None, metavar="DIR", parser.add_option("--web", default=None, metavar="DIR",
help="run webserver on same port. Serve files from DIR.") help="run webserver on same port. Serve files from DIR.")
parser.add_option("--web-auth", action="store_true", parser.add_option("--web-auth", action="store_true",
help="require authentication to access webserver.") help="require authentication to access webserver.")
parser.add_option("--wrap-mode", default="exit", metavar="MODE", parser.add_option("--wrap-mode", default="exit", metavar="MODE",
choices=["exit", "ignore", "respawn"], choices=["exit", "ignore", "respawn"],
help="action to take when the wrapped program exits " help="action to take when the wrapped program exits "
"or daemonizes: exit (default), ignore, respawn") "or daemonizes: exit (default), ignore, respawn")
parser.add_option("--prefer-ipv6", "-6", parser.add_option("--prefer-ipv6", "-6",
action="store_true", dest="source_is_ipv6", action="store_true", dest="source_is_ipv6",
help="prefer IPv6 when resolving source_addr") help="prefer IPv6 when resolving source_addr")
parser.add_option("--libserver", action="store_true", parser.add_option("--libserver", action="store_true",
help="use Python library SocketServer engine") help="use Python library SocketServer engine")
parser.add_option("--target-config", metavar="FILE", parser.add_option("--target-config", metavar="FILE",
dest="target_cfg", dest="target_cfg",
help="Configuration file containing valid targets " help="Configuration file containing valid targets "
"in the form 'token: host:port' or, alternatively, a " "in the form 'token: host:port' or, alternatively, a "
"directory containing configuration files of this form " "directory containing configuration files of this form "
"(DEPRECATED: use `--token-plugin TokenFile --token-source " "(DEPRECATED: use `--token-plugin TokenFile --token-source "
" path/to/token/file` instead)") " path/to/token/file` instead)")
parser.add_option("--token-plugin", default=None, metavar="CLASS", parser.add_option("--token-plugin", default=None, metavar="CLASS",
help="use a Python class, usually one from websockify.token_plugins, " help="use a Python class, usually one from websockify.token_plugins, "
"such as TokenFile, to process tokens into host:port pairs") "such as TokenFile, to process tokens into host:port pairs")
@ -547,13 +563,13 @@ def websockify_init():
help="an argument to be passed to the auth plugin " help="an argument to be passed to the auth plugin "
"on instantiation") "on instantiation")
parser.add_option("--heartbeat", type=int, default=0, metavar="INTERVAL", parser.add_option("--heartbeat", type=int, default=0, metavar="INTERVAL",
help="send a ping to the client every INTERVAL seconds") help="send a ping to the client every INTERVAL seconds")
parser.add_option("--log-file", metavar="FILE", parser.add_option("--log-file", metavar="FILE",
dest="log_file", dest="log_file",
help="File where logs will be saved") help="File where logs will be saved")
parser.add_option("--syslog", default=None, metavar="SERVER", parser.add_option("--syslog", default=None, metavar="SERVER",
help="Log to syslog server. SERVER can be local socket, " help="Log to syslog server. SERVER can be local socket, "
"such as /dev/log, or a UDP host:port pair.") "such as /dev/log, or a UDP host:port pair.")
parser.add_option("--legacy-syslog", action="store_true", parser.add_option("--legacy-syslog", action="store_true",
help="Use the old syslog protocol instead of RFC 5424. " help="Use the old syslog protocol instead of RFC 5424. "
"Use this if the messages produced by websockify seem abnormal.") "Use this if the messages produced by websockify seem abnormal.")
@ -562,9 +578,7 @@ def websockify_init():
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
# Validate options. # Validate options.
if opts.token_source and not opts.token_plugin: if opts.token_source and not opts.token_plugin:
parser.error("You must use --token-plugin to use --token-source") parser.error("You must use --token-plugin to use --token-source")
@ -583,11 +597,9 @@ def websockify_init():
if opts.legacy_syslog and not opts.syslog: if opts.legacy_syslog and not opts.syslog:
parser.error("You must use --syslog to use --legacy-syslog") parser.error("You must use --syslog to use --legacy-syslog")
opts.ssl_options = select_ssl_version(opts.ssl_version) opts.ssl_options = select_ssl_version(opts.ssl_version)
del opts.ssl_version del opts.ssl_version
if opts.log_file: if opts.log_file:
# Setup logging to user-specified file. # Setup logging to user-specified file.
opts.log_file = os.path.abspath(opts.log_file) opts.log_file = os.path.abspath(opts.log_file)
@ -638,7 +650,6 @@ def websockify_init():
root = logging.getLogger() root = logging.getLogger()
root.setLevel(logging.DEBUG) root.setLevel(logging.DEBUG)
# Transform to absolute path as daemon may chdir # Transform to absolute path as daemon may chdir
if opts.target_cfg: if opts.target_cfg:
opts.target_cfg = os.path.abspath(opts.target_cfg) opts.target_cfg = os.path.abspath(opts.target_cfg)
@ -655,7 +666,7 @@ def websockify_init():
opts.wrap_cmd = None opts.wrap_cmd = None
if not websockifyserver.ssl and opts.ssl_target: if not websockifyserver.ssl and opts.ssl_target:
parser.error("SSL target requested and Python SSL module not loaded."); parser.error("SSL target requested and Python SSL module not loaded.")
if opts.ssl_only and not os.path.exists(opts.cert): if opts.ssl_only and not os.path.exists(opts.cert):
parser.error("SSL only and %s not found" % opts.cert) parser.error("SSL only and %s not found" % opts.cert)
@ -708,7 +719,7 @@ def websockify_init():
except ValueError: except ValueError:
parser.error("Error parsing target port") parser.error("Error parsing target port")
if len(args) > 0 and opts.wrap_cmd == None: if len(args) > 0 and opts.wrap_cmd is None:
parser.error("Too many arguments") parser.error("Too many arguments")
if opts.token_plugin is not None: if opts.token_plugin is not None:
@ -759,32 +770,32 @@ class LibProxyServer(ThreadingMixIn, HTTPServer):
def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs): def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs):
# Save off proxy specific options # Save off proxy specific options
self.target_host = kwargs.pop('target_host', None) self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None) self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None) self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None) self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None) self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None) self.ssl_target = kwargs.pop('ssl_target', None)
self.token_plugin = kwargs.pop('token_plugin', None) self.token_plugin = kwargs.pop('token_plugin', None)
self.auth_plugin = kwargs.pop('auth_plugin', None) self.auth_plugin = kwargs.pop('auth_plugin', None)
self.heartbeat = kwargs.pop('heartbeat', None) self.heartbeat = kwargs.pop('heartbeat', None)
self.token_plugin = None self.token_plugin = None
self.auth_plugin = None self.auth_plugin = None
self.daemon = False self.daemon = False
# Server configuration # Server configuration
listen_host = kwargs.pop('listen_host', '') listen_host = kwargs.pop('listen_host', '')
listen_port = kwargs.pop('listen_port', None) listen_port = kwargs.pop('listen_port', None)
web = kwargs.pop('web', '') web = kwargs.pop('web', '')
# Configuration affecting base request handler # Configuration affecting base request handler
self.only_upgrade = not web self.only_upgrade = not web
self.verbose = kwargs.pop('verbose', False) self.verbose = kwargs.pop('verbose', False)
record = kwargs.pop('record', '') record = kwargs.pop('record', '')
if record: if record:
self.record = os.path.abspath(record) self.record = os.path.abspath(record)
self.run_once = kwargs.pop('run_once', False) self.run_once = kwargs.pop('run_once', False)
self.handler_id = 0 self.handler_id = 0
for arg in kwargs.keys(): for arg in kwargs.keys():
@ -795,7 +806,6 @@ class LibProxyServer(ThreadingMixIn, HTTPServer):
super().__init__((listen_host, listen_port), RequestHandlerClass) super().__init__((listen_host, listen_port), RequestHandlerClass)
def process_request(self, request, client_address): def process_request(self, request, client_address):
"""Override process_request to implement a counter""" """Override process_request to implement a counter"""
self.handler_id += 1 self.handler_id += 1

View File

@ -10,7 +10,8 @@ Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
import sys import sys
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
from websockify.websocket import WebSocket, WebSocketWantReadError, WebSocketWantWriteError from websockify.websocket import WebSocket
class HttpWebSocket(WebSocket): class HttpWebSocket(WebSocket):
"""Class to glue websocket and http request functionality together""" """Class to glue websocket and http request functionality together"""
@ -93,18 +94,21 @@ class WebSocketRequestHandlerMixIn:
def handle_websocket(self): def handle_websocket(self):
"""Handle a WebSocket connection. """Handle a WebSocket connection.
This is called when the WebSocket is ready to be used. A This is called when the WebSocket is ready to be used. A
sub-class should perform the necessary communication here and sub-class should perform the necessary communication here and
return once done. return once done.
""" """
pass pass
# Convenient ready made classes # Convenient ready made classes
class WebSocketRequestHandler(WebSocketRequestHandlerMixIn, class WebSocketRequestHandler(WebSocketRequestHandlerMixIn,
BaseHTTPRequestHandler): BaseHTTPRequestHandler):
pass pass
class WebSocketServer(HTTPServer): class WebSocketServer(HTTPServer):
pass pass

View File

@ -12,18 +12,29 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import os, sys, time, errno, signal, socket, select, logging import errno
import multiprocessing
from http.server import SimpleHTTPRequestHandler from http.server import SimpleHTTPRequestHandler
import logging
import multiprocessing
import os
import select
import signal
import socket
import sys
import time
# Degraded functionality if these imports are missing # Degraded functionality if these imports are missing
for mod, msg in [('ssl', 'TLS/SSL/wss is disabled'), try:
('resource', 'daemonizing is disabled')]: import ssl
try: except ImportError:
globals()[mod] = __import__(mod) ssl = None
except ImportError: print("WARNING: no 'ssl' module, TLS/SSL/wss is disabled")
globals()[mod] = None
print("WARNING: no '%s' module, %s" % (mod, msg)) try:
import resource
except ImportError:
resource = None
print("WARNING: no 'resource' module, daemonizing is disabled")
if sys.platform == 'win32': if sys.platform == 'win32':
# make sockets pickle-able/inheritable # make sockets pickle-able/inheritable
@ -32,6 +43,7 @@ if sys.platform == 'win32':
from websockify.websocket import WebSocketWantReadError, WebSocketWantWriteError from websockify.websocket import WebSocketWantReadError, WebSocketWantWriteError
from websockify.websocketserver import WebSocketRequestHandlerMixIn from websockify.websocketserver import WebSocketRequestHandlerMixIn
class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass): class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass):
def select_subprotocol(self, protocols): def select_subprotocol(self, protocols):
# Handle old websockify clients that still specify a sub-protocol # Handle old websockify clients that still specify a sub-protocol
@ -40,6 +52,7 @@ class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass):
else: else:
return '' return ''
# HTTP handler with WebSocket upgrade support # HTTP handler with WebSocket upgrade support
class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHandler): class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHandler):
""" """
@ -73,7 +86,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
self.daemon = getattr(server, "daemon", False) self.daemon = getattr(server, "daemon", False)
self.record = getattr(server, "record", False) self.record = getattr(server, "record", False)
self.run_once = getattr(server, "run_once", False) self.run_once = getattr(server, "run_once", False)
self.rec = None self.rec = None
self.handler_id = getattr(server, "handler_id", False) self.handler_id = getattr(server, "handler_id", False)
self.file_only = getattr(server, "file_only", False) self.file_only = getattr(server, "file_only", False)
self.traffic = getattr(server, "traffic", False) self.traffic = getattr(server, "traffic", False)
@ -124,7 +137,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
fully sent, in which case the caller should call again when fully sent, in which case the caller should call again when
the socket is ready. """ the socket is ready. """
tdelta = int(time.time()*1000) - self.start_time tdelta = int(time.time() * 1000) - self.start_time
if bufs: if bufs:
for buf in bufs: for buf in bufs:
@ -155,7 +168,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
closed = False closed = False
bufs = [] bufs = []
tdelta = int(time.time()*1000) - self.start_time tdelta = int(time.time() * 1000) - self.start_time
while True: while True:
try: try:
@ -207,8 +220,8 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
self.server.ws_connection = True self.server.ws_connection = True
# Initialize per client settings # Initialize per client settings
self.send_parts = [] self.send_parts = []
self.recv_part = None self.recv_part = None
self.start_time = int(time.time()*1000) self.start_time = int(time.time() * 1000)
# client_address is empty with, say, UNIX domain sockets # client_address is empty with, say, UNIX domain sockets
client_addr = "" client_addr = ""
@ -333,47 +346,47 @@ class WebSockifyServer():
pass pass
def __init__(self, RequestHandlerClass, listen_fd=None, def __init__(self, RequestHandlerClass, listen_fd=None,
listen_host='', listen_port=None, source_is_ipv6=False, listen_host='', listen_port=None, source_is_ipv6=False,
verbose=False, cert='', key='', key_password=None, ssl_only=None, verbose=False, cert='', key='', key_password=None, ssl_only=None,
verify_client=False, cafile=None, verify_client=False, cafile=None,
daemon=False, record='', web='', web_auth=False, daemon=False, record='', web='', web_auth=False,
file_only=False, file_only=False,
run_once=False, timeout=0, idle_timeout=0, traffic=False, run_once=False, timeout=0, idle_timeout=0, traffic=False,
tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None,
tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0, tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0,
unix_listen=None, unix_listen_mode=None): unix_listen=None, unix_listen_mode=None):
# settings # settings
self.RequestHandlerClass = RequestHandlerClass self.RequestHandlerClass = RequestHandlerClass
self.verbose = verbose self.verbose = verbose
self.listen_fd = listen_fd self.listen_fd = listen_fd
self.unix_listen = unix_listen self.unix_listen = unix_listen
self.unix_listen_mode = unix_listen_mode self.unix_listen_mode = unix_listen_mode
self.listen_host = listen_host self.listen_host = listen_host
self.listen_port = listen_port self.listen_port = listen_port
self.prefer_ipv6 = source_is_ipv6 self.prefer_ipv6 = source_is_ipv6
self.ssl_only = ssl_only self.ssl_only = ssl_only
self.ssl_ciphers = ssl_ciphers self.ssl_ciphers = ssl_ciphers
self.ssl_options = ssl_options self.ssl_options = ssl_options
self.verify_client = verify_client self.verify_client = verify_client
self.daemon = daemon self.daemon = daemon
self.run_once = run_once self.run_once = run_once
self.timeout = timeout self.timeout = timeout
self.idle_timeout = idle_timeout self.idle_timeout = idle_timeout
self.traffic = traffic self.traffic = traffic
self.file_only = file_only self.file_only = file_only
self.web_auth = web_auth self.web_auth = web_auth
self.launch_time = time.time() self.launch_time = time.time()
self.ws_connection = False self.ws_connection = False
self.handler_id = 1 self.handler_id = 1
self.terminating = False self.terminating = False
self.logger = self.get_logger() self.logger = self.get_logger()
self.tcp_keepalive = tcp_keepalive self.tcp_keepalive = tcp_keepalive
self.tcp_keepcnt = tcp_keepcnt self.tcp_keepcnt = tcp_keepcnt
self.tcp_keepidle = tcp_keepidle self.tcp_keepidle = tcp_keepidle
self.tcp_keepintvl = tcp_keepintvl self.tcp_keepintvl = tcp_keepintvl
# keyfile path must be None if not specified # keyfile path must be None if not specified
self.key = None self.key = None
@ -403,13 +416,13 @@ class WebSockifyServer():
# Show configuration # Show configuration
self.msg("WebSocket server settings:") self.msg("WebSocket server settings:")
if self.listen_fd != None: if self.listen_fd is not None:
self.msg(" - Listen for inetd connections") self.msg(" - Listen for inetd connections")
elif self.unix_listen != None: elif self.unix_listen is not None:
self.msg(" - Listen on unix socket %s", self.unix_listen) self.msg(" - Listen on unix socket %s", self.unix_listen)
else: else:
self.msg(" - Listen on %s:%s", self.msg(" - Listen on %s:%s",
self.listen_host, self.listen_port) self.listen_host, self.listen_port)
if self.web: if self.web:
if self.file_only: if self.file_only:
self.msg(" - Web server (no directory listings). Web root: %s", self.web) self.msg(" - Web server (no directory listings). Web root: %s", self.web)
@ -442,7 +455,7 @@ class WebSockifyServer():
@staticmethod @staticmethod
def socket(host, port=None, connect=False, prefer_ipv6=False, def socket(host, port=None, connect=False, prefer_ipv6=False,
unix_socket=None, unix_socket_mode=None, unix_socket_listen=False, unix_socket=None, unix_socket_mode=None, unix_socket_listen=False,
use_ssl=False, tcp_keepalive=True, tcp_keepcnt=None, use_ssl=False, tcp_keepalive=True, tcp_keepcnt=None,
tcp_keepidle=None, tcp_keepintvl=None): tcp_keepidle=None, tcp_keepintvl=None):
""" Resolve a host (and optional port) to an IPv4 or IPv6 """ Resolve a host (and optional port) to an IPv4 or IPv6
address. Create a socket. Bind to it if listen is set, address. Create a socket. Bind to it if listen is set,
@ -454,7 +467,7 @@ class WebSockifyServer():
if connect and not (port or unix_socket): if connect and not (port or unix_socket):
raise Exception("Connect mode requires a port") raise Exception("Connect mode requires a port")
if use_ssl and not ssl: if use_ssl and not ssl:
raise Exception("SSL socket requested but Python SSL module not loaded."); raise Exception("SSL socket requested but Python SSL module not loaded.")
if not connect and use_ssl: if not connect and use_ssl:
raise Exception("SSL only supported in connect mode (for now)") raise Exception("SSL only supported in connect mode (for now)")
if not connect: if not connect:
@ -462,7 +475,7 @@ class WebSockifyServer():
if not unix_socket: if not unix_socket:
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM,
socket.IPPROTO_TCP, flags) socket.IPPROTO_TCP, flags)
if not addrs: if not addrs:
raise Exception("Could not resolve host '%s'" % host) raise Exception("Could not resolve host '%s'" % host)
addrs.sort(key=lambda x: x[0]) addrs.sort(key=lambda x: x[0])
@ -470,7 +483,7 @@ class WebSockifyServer():
addrs.reverse() addrs.reverse()
sock = socket.socket(addrs[0][0], addrs[0][1]) sock = socket.socket(addrs[0][0], addrs[0][1])
if tcp_keepalive: if tcp_keepalive:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if tcp_keepcnt: if tcp_keepcnt:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT,
@ -513,7 +526,7 @@ class WebSockifyServer():
@staticmethod @staticmethod
def daemonize(keepfd=None, chdir='/'): def daemonize(keepfd=None, chdir='/'):
if keepfd is None: if keepfd is None:
keepfd = [] keepfd = []
@ -526,9 +539,11 @@ class WebSockifyServer():
os.setuid(os.getuid()) # relinquish elevations os.setuid(os.getuid()) # relinquish elevations
# Double fork to daemonize # Double fork to daemonize
if os.fork() > 0: os._exit(0) # Parent exits if os.fork() > 0: # Parent exits
os.setsid() # Obtain new process group os._exit(0)
if os.fork() > 0: os._exit(0) # Parent exits os.setsid() # Obtain new process group
if os.fork() > 0: # Parent exits
os._exit(0)
# Signal handling # Signal handling
signal.signal(signal.SIGTERM, signal.SIG_IGN) signal.signal(signal.SIGTERM, signal.SIG_IGN)
@ -536,14 +551,16 @@ class WebSockifyServer():
# Close open files # Close open files
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
if maxfd == resource.RLIM_INFINITY: maxfd = 256 if maxfd == resource.RLIM_INFINITY:
maxfd = 256
for fd in reversed(range(maxfd)): for fd in reversed(range(maxfd)):
try: try:
if fd not in keepfd: if fd not in keepfd:
os.close(fd) os.close(fd)
except OSError: except OSError:
_, exc, _ = sys.exc_info() _, exc, _ = sys.exc_info()
if exc.errno != errno.EBADF: raise if exc.errno != errno.EBADF:
raise
# Redirect I/O to /dev/null # Redirect I/O to /dev/null
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
@ -572,7 +589,6 @@ class WebSockifyServer():
# Peek, but do not read the data so that we have a opportunity # Peek, but do not read the data so that we have a opportunity
# to SSL wrap the socket first # to SSL wrap the socket first
handshake = sock.recv(1024, socket.MSG_PEEK) handshake = sock.recv(1024, socket.MSG_PEEK)
#self.msg("Handshake [%s]" % handshake)
if not handshake: if not handshake:
raise self.EClose("") raise self.EClose("")
@ -599,8 +615,8 @@ class WebSockifyServer():
else: else:
context.set_default_verify_paths() context.set_default_verify_paths()
retsock = context.wrap_socket( retsock = context.wrap_socket(
sock, sock,
server_side=True) server_side=True)
except ssl.SSLError: except ssl.SSLError:
_, x, _ = sys.exc_info() _, x, _ = sys.exc_info()
if x.args[0] == ssl.SSL_ERROR_EOF: if x.args[0] == ssl.SSL_ERROR_EOF:
@ -644,17 +660,16 @@ class WebSockifyServer():
""" Same as msg() but as warning. """ """ Same as msg() but as warning. """
self.logger.log(logging.WARNING, *args, **kwargs) self.logger.log(logging.WARNING, *args, **kwargs)
# #
# Events that can/should be overridden in sub-classes # Events that can/should be overridden in sub-classes
# #
def started(self): def started(self):
""" Called after WebSockets startup """ """ Called after WebSockets startup """
self.vmsg("WebSockets server started") self.vmsg("WebSockets server started")
def poll(self): def poll(self):
""" Run periodically while waiting for connections. """ """ Run periodically while waiting for connections. """
#self.vmsg("Running poll()")
pass pass
def terminate(self): def terminate(self):
@ -735,9 +750,9 @@ class WebSockifyServer():
""" """
try: try:
if self.listen_fd != None: if self.listen_fd is not None:
lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM) lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM)
elif self.unix_listen != None: elif self.unix_listen is not None:
lsock = self.socket(host=None, lsock = self.socket(host=None,
unix_socket=self.unix_listen, unix_socket=self.unix_listen,
unix_socket_mode=self.unix_listen_mode, unix_socket_mode=self.unix_listen_mode,
@ -781,7 +796,7 @@ class WebSockifyServer():
try: try:
try: try:
startsock = None startsock = None
pid = err = 0 err = 0
child_count = 0 child_count = 0
# Collect zombie child processes # Collect zombie child processes
@ -790,7 +805,7 @@ class WebSockifyServer():
time_elapsed = time.time() - self.launch_time time_elapsed = time.time() - self.launch_time
if self.timeout and time_elapsed > self.timeout: if self.timeout and time_elapsed > self.timeout:
self.msg('listener exit due to --timeout %s' self.msg('listener exit due to --timeout %s'
% self.timeout) % self.timeout)
break break
if self.idle_timeout: if self.idle_timeout:
@ -803,7 +818,7 @@ class WebSockifyServer():
if idle_time > self.idle_timeout and child_count == 0: if idle_time > self.idle_timeout and child_count == 0:
self.msg('listener exit due to --idle-timeout %s' self.msg('listener exit due to --idle-timeout %s'
% self.idle_timeout) % self.idle_timeout)
break break
try: try:
@ -813,8 +828,8 @@ class WebSockifyServer():
if lsock in ready: if lsock in ready:
startsock, address = lsock.accept() startsock, address = lsock.accept()
# Unix Socket will not report address (empty string), but address[0] is logged a bunch # Unix Socket will not report address (empty string), but address[0] is logged a bunch
if self.unix_listen != None: if self.unix_listen is not None:
address = [ self.unix_listen ] address = [self.unix_listen]
else: else:
continue continue
except self.Terminate: except self.Terminate:
@ -836,15 +851,15 @@ class WebSockifyServer():
if self.run_once: if self.run_once:
# Run in same process if run_once # Run in same process if run_once
self.top_new_client(startsock, address) self.top_new_client(startsock, address)
if self.ws_connection : if self.ws_connection:
self.msg('%s: exiting due to --run-once' self.msg('%s: exiting due to --run-once'
% address[0]) % address[0])
break break
else: else:
self.vmsg('%s: new handler Process' % address[0]) self.vmsg('%s: new handler Process' % address[0])
p = multiprocessing.Process( p = multiprocessing.Process(
target=self.top_new_client, target=self.top_new_client,
args=(startsock, address)) args=(startsock, address))
p.start() p.start()
# child will not return # child will not return
@ -879,5 +894,3 @@ class WebSockifyServer():
# Restore signals # Restore signals
for sig, func in original_signals.items(): for sig, func in original_signals.items():
signal.signal(sig, func) signal.signal(sig, func)