Merge pull request #111 from astrand/master

Refactor to use standard SocketServer RequestHandler design.
This commit is contained in:
astrand 2013-12-19 01:31:14 -08:00
commit 38db12c2b0
4 changed files with 317 additions and 208 deletions

View File

@ -12,9 +12,9 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
import os, sys, select, optparse import os, sys, select, optparse
sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
from websockify.websocket import WebSocketServer from websockify.websocket import WebSocketServer, WebSocketRequestHandler
class WebSocketEcho(WebSocketServer): class WebSocketEcho(WebSocketRequestHandler):
""" """
WebSockets server that echos back whatever is received from the WebSockets server that echos back whatever is received from the
client. """ client. """
@ -49,7 +49,6 @@ class WebSocketEcho(WebSocketServer):
if closed: if closed:
self.send_close() self.send_close()
raise self.EClose(closed)
if __name__ == '__main__': if __name__ == '__main__':
parser = optparse.OptionParser(usage="%prog [options] listen_port") parser = optparse.OptionParser(usage="%prog [options] listen_port")
@ -70,6 +69,6 @@ if __name__ == '__main__':
parser.error("Invalid arguments") parser.error("Invalid arguments")
opts.web = "." opts.web = "."
server = WebSocketEcho(**opts.__dict__) server = WebSocketServer(WebSocketEcho, **opts.__dict__)
server.start_server() server.start_server()

View File

@ -8,28 +8,30 @@ given a sequence number. Any errors are reported and counted.
import sys, os, select, random, time, optparse import sys, os, select, random, time, optparse
sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
from websockify.websocket import WebSocketServer from websockify.websocket import WebSocketServer, WebSocketRequestHandler
class WebSocketLoad(WebSocketServer): class WebSocketLoadServer(WebSocketServer):
buffer_size = 65536
max_packet_size = 10000
recv_cnt = 0 recv_cnt = 0
send_cnt = 0 send_cnt = 0
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.errors = 0
self.delay = kwargs.pop('delay') self.delay = kwargs.pop('delay')
WebSocketServer.__init__(self, *args, **kwargs)
class WebSocketLoad(WebSocketRequestHandler):
max_packet_size = 10000
def new_websocket_client(self):
print "Prepopulating random array" print "Prepopulating random array"
self.rand_array = [] self.rand_array = []
for i in range(0, self.max_packet_size): for i in range(0, self.max_packet_size):
self.rand_array.append(random.randint(0, 9)) self.rand_array.append(random.randint(0, 9))
WebSocketServer.__init__(self, *args, **kwargs) self.errors = 0
def new_websocket_client(self):
self.send_cnt = 0 self.send_cnt = 0
self.recv_cnt = 0 self.recv_cnt = 0
@ -61,14 +63,13 @@ class WebSocketLoad(WebSocketServer):
if closed: if closed:
self.send_close() self.send_close()
raise self.EClose(closed)
now = time.time() * 1000 now = time.time() * 1000
if client in outs: if client in outs:
if c_pend: if c_pend:
last_send = now last_send = now
c_pend = self.send_frames() c_pend = self.send_frames()
elif now > (last_send + self.delay): elif now > (last_send + self.server.delay):
last_send = now last_send = now
c_pend = self.send_frames([self.generate()]) c_pend = self.send_frames([self.generate()])
@ -162,6 +163,6 @@ if __name__ == '__main__':
parser.error("Invalid arguments") parser.error("Invalid arguments")
opts.web = "." opts.web = "."
server = WebSocketLoad(**opts.__dict__) server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__)
server.start_server() server.start_server()

View File

@ -65,30 +65,51 @@ if multiprocessing and sys.platform == 'win32':
import multiprocessing.reduction import multiprocessing.reduction
class WebSocketServer(object): # HTTP handler with WebSocket upgrade support
class WebSocketRequestHandler(SimpleHTTPRequestHandler):
""" """
WebSockets server class. WebSocket Request Handler Class, derived from SimpleHTTPRequestHandler.
Must be sub-classed with new_websocket_client method definition. Must be sub-classed with new_websocket_client method definition.
""" The request handler can be configured by setting optional
attributes on the server object:
log_prefix = "websocket" * only_upgrade: If true, SimpleHTTPRequestHandler will not be enabled,
only websocket is allowed.
* verbose: If true, verbose logging is activated.
* daemon: Running as daemon, do not write to console etc
* record: Record raw frame data as JavaScript array into specified filename
* run_once: Handle a single request
* handler_id: A sequence number for this connection, appended to record filename
"""
buffer_size = 65536 buffer_size = 65536
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
server_handshake_hybi = """HTTP/1.1 101 Switching Protocols\r server_version = "WebSockify"
Upgrade: websocket\r
Connection: Upgrade\r protocol_version = "HTTP/1.1"
Sec-WebSocket-Accept: %s\r
"""
# An exception while the WebSocket client was connected # An exception while the WebSocket client was connected
class CClose(Exception): class CClose(Exception):
pass pass
class Terminate(Exception): def __init__(self, req, addr, server):
pass # Retrieve a few configuration variables from the server
self.only_upgrade = getattr(server, "only_upgrade", False)
self.verbose = getattr(server, "verbose", False)
self.daemon = getattr(server, "daemon", False)
self.record = getattr(server, "record", False)
self.run_once = getattr(server, "run_once", False)
self.rec = None
self.handler_id = getattr(server, "handler_id", False)
self.file_only = getattr(server, "file_only", False)
self.traffic = getattr(server, "traffic", False)
self.logger = getattr(server, "logger", None)
if self.logger is None:
self.logger = WebSocketServer.get_logger()
SimpleHTTPRequestHandler.__init__(self, req, addr, server)
@staticmethod @staticmethod
def unmask(buf, hlen, plen): def unmask(buf, hlen, plen):
@ -213,7 +234,7 @@ Sec-WebSocket-Accept: %s\r
# Process 1 frame # Process 1 frame
if f['masked']: if f['masked']:
# unmask payload # unmask payload
f['payload'] = WebSocketServer.unmask(buf, f['hlen'], f['payload'] = WebSocketRequestHandler.unmask(buf, f['hlen'],
f['length']) f['length'])
else: else:
logger.debug("Unmasked frame: %s" % repr(buf)) logger.debug("Unmasked frame: %s" % repr(buf))
@ -237,9 +258,33 @@ Sec-WebSocket-Accept: %s\r
# #
# Main WebSocketServer methods # WebSocketRequestHandler logging/output functions
# #
def print_traffic(self, token="."):
""" Show traffic flow mode. """
if self.traffic:
sys.stdout.write(token)
sys.stdout.flush()
def msg(self, msg, *args, **kwargs):
""" Output message with handler_id prefix. """
prefix = "% 3d: " % self.handler_id
self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs)
def vmsg(self, msg, *args, **kwargs):
""" Same as msg() but as debug. """
prefix = "% 3d: " % self.handler_id
self.logger.log(logging.DEBUG, "%s%s" % (prefix, msg), *args, **kwargs)
def warn(self, msg, *args, **kwargs):
""" Same as msg() but as warning. """
prefix = "% 3d: " % self.handler_id
self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs)
#
# Main WebSocketRequestHandler methods
#
def send_frames(self, bufs=None): def send_frames(self, bufs=None):
""" Encode and send WebSocket frames. Any frames already """ Encode and send WebSocket frames. Any frames already
queued will be sent first. If buf is not set then only queued queued will be sent first. If buf is not set then only queued
@ -322,7 +367,7 @@ Sec-WebSocket-Accept: %s\r
start = frame['hlen'] start = frame['hlen']
end = frame['hlen'] + frame['length'] end = frame['hlen'] + frame['length']
if frame['masked']: if frame['masked']:
recbuf = WebSocketServer.unmask(buf, frame['hlen'], recbuf = WebSocketRequestHandler.unmask(buf, frame['hlen'],
frame['length']) frame['length'])
else: else:
recbuf = buf[frame['hlen']:frame['hlen'] + recbuf = buf[frame['hlen']:frame['hlen'] +
@ -347,9 +392,8 @@ Sec-WebSocket-Accept: %s\r
buf, h, t = self.encode_hybi(msg, opcode=0x08, base64=False) buf, h, t = self.encode_hybi(msg, opcode=0x08, base64=False)
self.request.send(buf) self.request.send(buf)
def do_websocket_handshake(self, headers, path): def do_websocket_handshake(self):
h = self.headers = headers h = self.headers
self.path = path
prot = 'WebSocket-Protocol' prot = 'WebSocket-Protocol'
protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',') protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',')
@ -364,7 +408,8 @@ Sec-WebSocket-Accept: %s\r
if ver in ['7', '8', '13']: if ver in ['7', '8', '13']:
self.version = "hybi-%02d" % int(ver) self.version = "hybi-%02d" % int(ver)
else: else:
raise self.EClose('Unsupported protocol version %s' % ver) self.send_error(400, "Unsupported protocol version %s" % ver)
return False
key = h['Sec-WebSocket-Key'] key = h['Sec-WebSocket-Key']
@ -374,42 +419,158 @@ Sec-WebSocket-Accept: %s\r
elif 'base64' in protocols: elif 'base64' in protocols:
self.base64 = True self.base64 = True
else: else:
raise self.EClose("Client must support 'binary' or 'base64' protocol") self.send_error(400, "Client must support 'binary' or 'base64' protocol")
return False
# Generate the hash value for the accept header # Generate the hash value for the accept header
accept = b64encode(sha1(s2b(key + self.GUID)).digest()) accept = b64encode(sha1(s2b(key + self.GUID)).digest())
response = self.server_handshake_hybi % b2s(accept) self.send_response(101, "Switching Protocols")
self.send_header("Upgrade", "websocket")
self.send_header("Connection", "Upgrade")
self.send_header("Sec-WebSocket-Accept", b2s(accept))
if self.base64: if self.base64:
response += "Sec-WebSocket-Protocol: base64\r\n" self.send_header("Sec-WebSocket-Protocol", "base64")
else: else:
response += "Sec-WebSocket-Protocol: binary\r\n" self.send_header("Sec-WebSocket-Protocol", "binary")
response += "\r\n" self.end_headers()
return True
else: else:
raise self.EClose("Missing Sec-WebSocket-Version header. Hixie protocols not supported.") self.send_error(400, "Missing Sec-WebSocket-Version header. Hixie protocols not supported.")
return response return False
def handle_websocket(self):
"""Upgrade a connection to Websocket, if requested. If this succeeds,
new_websocket_client() will be called. Otherwise, False is returned.
"""
if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):
if not self.do_websocket_handshake():
return False
# Indicate to server that a Websocket upgrade was done
self.server.ws_connection = True
# Initialize per client settings
self.send_parts = []
self.recv_part = None
self.start_time = int(time.time()*1000)
# client_address is empty with, say, UNIX domain sockets
client_addr = ""
is_ssl = False
try:
client_addr = self.client_address[0]
is_ssl = self.client_address[2]
except IndexError:
pass
if is_ssl:
self.stype = "SSL/TLS (wss://)"
else:
self.stype = "Plain non-SSL (ws://)"
self.log_message("%s: %s WebSocket connection" % (client_addr,
self.stype))
self.log_message("%s: Version %s, base64: '%s'" % (client_addr,
self.version, self.base64))
if self.path != '/':
self.log_message("%s: Path: '%s'" % (client_addr, self.path))
if self.record:
# Record raw frame data as JavaScript array
fname = "%s.%s" % (self.record,
self.handler_id)
self.log_message("opening record file: %s" % fname)
self.rec = open(fname, 'w+')
encoding = "binary"
if self.base64: encoding = "base64"
self.rec.write("var VNC_frame_encoding = '%s';\n"
% encoding)
self.rec.write("var VNC_frame_data = [\n")
try:
self.new_websocket_client()
except self.CClose:
# Close the client
_, exc, _ = sys.exc_info()
self.send_close(exc.args[0], exc.args[1])
return True
else:
return False
def do_GET(self):
"""Handle GET request. Calls handle_websocket(). If unsuccessful,
and web server is enabled, SimpleHTTPRequestHandler.do_GET will be called."""
if not self.handle_websocket():
if self.only_upgrade:
self.send_error(405, "Method Not Allowed")
else:
SimpleHTTPRequestHandler.do_GET(self)
def list_directory(self, path):
if self.file_only:
self.send_error(404, "No such file")
else:
return SimpleHTTPRequestHandler.list_directory(self, path)
def new_websocket_client(self): def new_websocket_client(self):
""" Do something with a WebSockets client connection. """ """ Do something with a WebSockets client connection. """
raise("WebSocketServer.new_websocket_client() must be overloaded") raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
def do_HEAD(self):
if self.only_upgrade:
self.send_error(405, "Method Not Allowed")
else:
SimpleHTTPRequestHandler.do_HEAD(self)
def finish(self):
if self.rec:
self.rec.write("'EOF'];\n")
self.rec.close()
def handle(self):
# When using run_once, we have a single process, so
# we cannot loop in BaseHTTPRequestHandler.handle; we
# must return and handle new connections
if self.run_once:
self.handle_one_request()
else:
SimpleHTTPRequestHandler.handle(self)
def log_request(self, code='-', size='-'):
if self.verbose:
SimpleHTTPRequestHandler.log_request(self, code, size)
class WebSocketServer(object):
"""
WebSockets server class.
As an alternative, the standard library SocketServer can be used
"""
policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n""" policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n"""
log_prefix = "websocket"
# An exception before the WebSocket connection was established # An exception before the WebSocket connection was established
class EClose(Exception): class EClose(Exception):
pass pass
def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False, class Terminate(Exception):
pass
def __init__(self, RequestHandlerClass, listen_host='',
listen_port=None, source_is_ipv6=False,
verbose=False, cert='', key='', ssl_only=None, verbose=False, cert='', key='', ssl_only=None,
daemon=False, record='', web='', daemon=False, record='', web='',
file_only=False, no_parent=False, file_only=False,
run_once=False, timeout=0, idle_timeout=0, traffic=False, run_once=False, timeout=0, idle_timeout=0, traffic=False,
tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None,
tcp_keepintvl=None): tcp_keepintvl=None):
# settings # settings
self.RequestHandlerClass = RequestHandlerClass
self.verbose = verbose self.verbose = verbose
self.listen_host = listen_host self.listen_host = listen_host
self.listen_port = listen_port self.listen_port = listen_port
@ -423,12 +584,8 @@ Sec-WebSocket-Accept: %s\r
self.launch_time = time.time() self.launch_time = time.time()
self.ws_connection = False self.ws_connection = False
self.i_am_client = False
self.handler_id = 1 self.handler_id = 1
self.file_only = file_only
self.no_parent = no_parent
self.logger = self.get_logger() self.logger = self.get_logger()
self.tcp_keepalive = tcp_keepalive self.tcp_keepalive = tcp_keepalive
self.tcp_keepcnt = tcp_keepcnt self.tcp_keepcnt = tcp_keepcnt
@ -447,6 +604,7 @@ Sec-WebSocket-Accept: %s\r
if self.web: if self.web:
os.chdir(self.web) os.chdir(self.web)
self.only_upgrade = not self.web
# Sanity checks # Sanity checks
if not ssl and self.ssl_only: if not ssl and self.ssl_only:
@ -593,7 +751,6 @@ Sec-WebSocket-Accept: %s\r
- Send a WebSockets handshake server response. - Send a WebSockets handshake server response.
- Return the socket for this WebSocket client. - Return the socket for this WebSocket client.
""" """
stype = ""
ready = select.select([sock], [], [], 3)[0] ready = select.select([sock], [], [], 3)[0]
@ -637,43 +794,19 @@ Sec-WebSocket-Accept: %s\r
else: else:
raise raise
self.scheme = "wss"
stype = "SSL/TLS (wss://)"
elif self.ssl_only: elif self.ssl_only:
raise self.EClose("non-SSL connection received but disallowed") raise self.EClose("non-SSL connection received but disallowed")
else: else:
retsock = sock retsock = sock
self.scheme = "ws"
stype = "Plain non-SSL (ws://)"
wsh = WSRequestHandler(retsock, address, not self.web, # If the address is like (host, port), we are extending it
self.file_only, self.no_parent) # with a flag indicating SSL. Not many other options
if wsh.last_code == 101: # available...
# Continue on to handle WebSocket upgrade if len(address) == 2:
pass address = (address[0], address[1], (retsock != sock))
elif wsh.last_code == 405:
raise self.EClose("Normal web request received but disallowed")
elif wsh.last_code < 200 or wsh.last_code >= 300:
raise self.EClose(wsh.last_message)
elif self.verbose:
raise self.EClose(wsh.last_message)
else:
raise self.EClose("")
response = self.do_websocket_handshake(wsh.headers, wsh.path) self.RequestHandlerClass(retsock, address, self)
self.msg("%s: %s WebSocket connection" % (address[0], stype))
self.msg("%s: Version %s, base64: '%s'" % (address[0],
self.version, self.base64))
if self.path != '/':
self.msg("%s: Path: '%s'" % (address[0], self.path))
# Send server WebSockets handshake response
#self.msg("sending response [%s]" % response)
retsock.send(s2b(response))
# Return the WebSockets socket which may be SSL wrapped # Return the WebSockets socket which may be SSL wrapped
return retsock return retsock
@ -681,32 +814,18 @@ Sec-WebSocket-Accept: %s\r
# #
# WebSocketServer logging/output functions # WebSocketServer logging/output functions
# #
def print_traffic(self, token="."):
""" Show traffic flow mode. """
if self.traffic:
sys.stdout.write(token)
sys.stdout.flush()
def log(self, lvl, msg, *args, **kwargs):
""" Wrapper around python logging """
prefix = ""
if self.i_am_client:
prefix = "% 3d: " % self.handler_id
self.logger.log(lvl, "%s%s" % (prefix, msg),
*args, **kwargs)
def msg(self, *args, **kwargs): def msg(self, *args, **kwargs):
""" Output message with handler_id prefix. """ """ Output message as info """
self.log(logging.INFO, *args, **kwargs) self.logger.log(logging.INFO, *args, **kwargs)
def vmsg(self, *args, **kwargs): def vmsg(self, *args, **kwargs):
""" Same as msg() but as debug. """ """ Same as msg() but as debug. """
self.log(logging.DEBUG, *args, **kwargs) self.logger.log(logging.DEBUG, *args, **kwargs)
def warn(self, *args, **kwargs): def warn(self, *args, **kwargs):
""" Same as msg() but as warning. """ """ Same as msg() but as warning. """
self.log(logging.WARN, *args, **kwargs) self.logger.log(logging.WARN, *args, **kwargs)
# #
@ -748,38 +867,11 @@ Sec-WebSocket-Accept: %s\r
def top_new_client(self, startsock, address): def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """ """ Do something with a WebSockets client connection. """
# Initialize per client settings
self.i_am_client = True
self.send_parts = []
self.recv_part = None
self.base64 = False
self.rec = None
self.start_time = int(time.time()*1000)
# handler process # handler process
client = None
try: try:
try: try:
self.request = self.do_handshake(startsock, address) client = self.do_handshake(startsock, address)
if self.record:
# Record raw frame data as JavaScript array
fname = "%s.%s" % (self.record,
self.handler_id)
self.msg("opening record file: %s" % fname)
self.rec = open(fname, 'w+')
encoding = "binary"
if self.base64: encoding = "base64"
self.rec.write("var VNC_frame_encoding = '%s';\n"
% encoding)
self.rec.write("var VNC_frame_data = [\n")
self.ws_connection = True
self.new_websocket_client()
except self.CClose:
# Close the client
_, exc, _ = sys.exc_info()
if self.request:
self.send_close(exc.args[0], exc.args[1])
except self.EClose: except self.EClose:
_, exc, _ = sys.exc_info() _, exc, _ = sys.exc_info()
# Connection was not a WebSockets connection # Connection was not a WebSockets connection
@ -792,14 +884,11 @@ Sec-WebSocket-Accept: %s\r
self.msg("handler exception: %s" % str(exc)) self.msg("handler exception: %s" % str(exc))
self.vmsg("exception", exc_info=True) self.vmsg("exception", exc_info=True)
finally: finally:
if self.rec:
self.rec.write("'EOF'];\n")
self.rec.close()
if self.request and self.request != startsock: if client and client != startsock:
# Close the SSL wrapped socket # Close the SSL wrapped socket
# Original socket closed by caller # Original socket closed by caller
self.request.close() client.close()
def start_server(self): def start_server(self):
""" """
@ -841,7 +930,6 @@ Sec-WebSocket-Accept: %s\r
while True: while True:
try: try:
try: try:
self.request = None
startsock = None startsock = None
pid = err = 0 pid = err = 0
child_count = 0 child_count = 0
@ -940,40 +1028,3 @@ Sec-WebSocket-Accept: %s\r
signal.signal(sig, func) signal.signal(sig, func)
# HTTP handler with WebSocket upgrade support
class WSRequestHandler(SimpleHTTPRequestHandler):
def __init__(self, req, addr, only_upgrade=False, file_only=False,
no_parent=False):
self.only_upgrade = only_upgrade # only allow upgrades
self.webroot = os.path.realpath(".")
self.file_only = file_only
self.no_parent = no_parent
SimpleHTTPRequestHandler.__init__(self, req, addr, object())
def do_GET(self):
abspath = os.path.realpath("." + (self.path.split('?')[0]))
if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):
# Just indicate that an WebSocket upgrade is needed
self.last_code = 101
self.last_message = "101 Switching Protocols"
elif self.only_upgrade:
# Normal web request responses are disabled
self.last_code = 405
self.last_message = "405 Method Not Allowed"
elif self.file_only and not os.path.isfile(abspath):
self.send_response(404, "No such file")
elif self.no_parent and not abspath.startswith(self.webroot):
self.send_response(403, "Hidden resources")
else:
SimpleHTTPRequestHandler.do_GET(self)
def send_response(self, code, message=None):
# Save the status code
self.last_code = code
SimpleHTTPRequestHandler.send_response(self, code, message)
def log_message(self, f, *args):
# Save instead of printing
self.last_message = f % args

View File

@ -12,6 +12,10 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import signal, socket, optparse, time, os, sys, subprocess, logging import signal, socket, optparse, time, os, sys, subprocess, logging
try: from socketserver import ForkingMixIn
except: from SocketServer import ForkingMixIn
try: from http.server import HTTPServer
except: from BaseHTTPServer import HTTPServer
from select import select from select import select
import websocket import websocket
try: try:
@ -20,15 +24,7 @@ except:
from cgi import parse_qs from cgi import parse_qs
from urlparse import urlparse from urlparse import urlparse
class WebSocketProxy(websocket.WebSocketServer): class ProxyRequestHandler(websocket.WebSocketRequestHandler):
"""
Proxy traffic to and from a WebSockets client to a normal TCP
socket server target. All traffic to/from the client is base64
encoded/decoded to allow binary data to be sent/received to/from
the target.
"""
buffer_size = 65536
traffic_legend = """ traffic_legend = """
Traffic Legend: Traffic Legend:
@ -42,35 +38,31 @@ Traffic Legend:
<. - Client send partial <. - Client send partial
""" """
#
# Routines below this point are connection handler routines and
# will be run in a separate forked process for each connection.
#
def new_websocket_client(self): def new_websocket_client(self):
""" """
Called after a new WebSocket connection has been established. Called after a new WebSocket connection has been established.
""" """
# Checks if we receive a token, and look # Checks if we receive a token, and look
# for a valid target for it then # for a valid target for it then
if self.target_cfg: if self.server.target_cfg:
(self.target_host, self.target_port) = self.get_target(self.target_cfg, self.path) (self.server.target_host, self.server.target_port) = self.get_target(self.server.target_cfg, self.path)
# Connect to the target # Connect to the target
if self.wrap_cmd: if self.server.wrap_cmd:
msg = "connecting to command: '%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port) msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port)
elif self.unix_target: elif self.server.unix_target:
msg = "connecting to unix socket: %s" % self.unix_target msg = "connecting to unix socket: %s" % self.server.unix_target
else: else:
msg = "connecting to: %s:%s" % ( msg = "connecting to: %s:%s" % (
self.target_host, self.target_port) self.server.target_host, self.server.target_port)
if self.ssl_target: if self.server.ssl_target:
msg += " (using SSL)" msg += " (using SSL)"
self.msg(msg) self.log_message(msg)
tsock = self.socket(self.target_host, self.target_port, tsock = websocket.WebSocketServer.socket(self.server.target_host,
connect=True, use_ssl=self.ssl_target, unix_socket=self.unix_target) self.server.target_port,
connect=True, use_ssl=self.server.ssl_target, unix_socket=self.server.unix_target)
self.print_traffic(self.traffic_legend) self.print_traffic(self.traffic_legend)
@ -81,8 +73,9 @@ Traffic Legend:
if tsock: if tsock:
tsock.shutdown(socket.SHUT_RDWR) tsock.shutdown(socket.SHUT_RDWR)
tsock.close() tsock.close()
self.vmsg("%s:%s: Closed target" %( if self.verbose:
self.target_host, self.target_port)) self.log_message("%s:%s: Closed target" %(
self.server.target_host, self.server.target_port))
raise raise
def get_target(self, target_cfg, path): def get_target(self, target_cfg, path):
@ -154,8 +147,9 @@ Traffic Legend:
if closed: if closed:
# TODO: What about blocking on client socket? # TODO: What about blocking on client socket?
self.vmsg("%s:%s: Client closed connection" %( if self.verbose:
self.target_host, self.target_port)) self.log_message("%s:%s: Client closed connection" %(
self.server.target_host, self.server.target_port))
raise self.CClose(closed['code'], closed['reason']) raise self.CClose(closed['code'], closed['reason'])
@ -175,20 +169,25 @@ Traffic Legend:
# Receive target data, encode it and queue for client # Receive target data, encode it and queue for client
buf = target.recv(self.buffer_size) buf = target.recv(self.buffer_size)
if len(buf) == 0: if len(buf) == 0:
self.vmsg("%s:%s: Target closed connection" %( if self.verbose:
self.target_host, self.target_port)) self.log_message("%s:%s: Target closed connection" %(
self.server.target_host, self.server.target_port))
raise self.CClose(1000, "Target closed") raise self.CClose(1000, "Target closed")
cqueue.append(buf) cqueue.append(buf)
self.print_traffic("{") self.print_traffic("{")
class WebSocketProxy(websocket.WebSocketServer):
"""
Proxy traffic to and from a WebSockets client to a normal TCP
socket server target. All traffic to/from the client is base64
encoded/decoded to allow binary data to be sent/received to/from
the target.
"""
# buffer_size = 65536
# Routines below this point are run in the master listener
# process.
#
def __init__(self, *args, **kwargs): def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs):
# Save off proxy specific options # Save off proxy specific options
self.target_host = kwargs.pop('target_host', None) self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None) self.target_port = kwargs.pop('target_port', None)
@ -229,7 +228,7 @@ Traffic Legend:
"REBIND_OLD_PORT": str(kwargs['listen_port']), "REBIND_OLD_PORT": str(kwargs['listen_port']),
"REBIND_NEW_PORT": str(self.target_port)}) "REBIND_NEW_PORT": str(self.target_port)})
websocket.WebSocketServer.__init__(self, *args, **kwargs) websocket.WebSocketServer.__init__(self, RequestHandlerClass, *args, **kwargs)
def run_wrap_cmd(self): def run_wrap_cmd(self):
self.msg("Starting '%s'", " ".join(self.wrap_cmd)) self.msg("Starting '%s'", " ".join(self.wrap_cmd))
@ -358,6 +357,8 @@ def websockify_init():
help="Configuration file containing valid targets " help="Configuration file containing valid targets "
"in the form 'token: host:port' or, alternatively, a " "in the form 'token: host:port' or, alternatively, a "
"directory containing configuration files of this form") "directory containing configuration files of this form")
parser.add_option("--libserver", action="store_true",
help="use Python library SocketServer engine")
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
if opts.verbose: if opts.verbose:
@ -406,8 +407,65 @@ def websockify_init():
opts.target_cfg = os.path.abspath(opts.target_cfg) opts.target_cfg = os.path.abspath(opts.target_cfg)
# Create and start the WebSockets proxy # Create and start the WebSockets proxy
server = WebSocketProxy(**opts.__dict__) libserver = opts.libserver
server.start_server() del opts.libserver
if libserver:
# Use standard Python SocketServer framework
server = LibProxyServer(**opts.__dict__)
server.serve_forever()
else:
# Use internal service framework
server = WebSocketProxy(**opts.__dict__)
server.start_server()
class LibProxyServer(ForkingMixIn, HTTPServer):
"""
Just like WebSocketProxy, but uses standard Python SocketServer
framework.
"""
def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs):
# Save off proxy specific options
self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None)
self.target_cfg = kwargs.pop('target_cfg', None)
self.daemon = False
self.target_cfg = None
# Server configuration
listen_host = kwargs.pop('listen_host', '')
listen_port = kwargs.pop('listen_port', None)
web = kwargs.pop('web', '')
# Configuration affecting base request handler
self.only_upgrade = not web
self.verbose = kwargs.pop('verbose', False)
record = kwargs.pop('record', '')
if record:
self.record = os.path.abspath(record)
self.run_once = kwargs.pop('run_once', False)
self.handler_id = 0
for arg in kwargs.keys():
print("warning: option %s ignored when using --libserver" % arg)
if web:
os.chdir(web)
HTTPServer.__init__(self, (listen_host, listen_port),
RequestHandlerClass)
def process_request(self, request, client_address):
"""Override process_request to implement a counter"""
self.handler_id += 1
ForkingMixIn.process_request(self, request, client_address)
if __name__ == '__main__': if __name__ == '__main__':
websockify_init() websockify_init()