diff --git a/websockify/websocket.py b/websockify/websocket.py index 5a18301..68dc53e 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -64,26 +64,32 @@ if multiprocessing and sys.platform == 'win32': import multiprocessing.reduction -class WebSocketServer(object): - """ - WebSockets server class. - Must be sub-classed with new_client method definition. - """ - +# HTTP handler with WebSocket upgrade support +class WebSocketRequestHandler(SimpleHTTPRequestHandler): buffer_size = 65536 GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - server_handshake_hybi = """HTTP/1.1 101 Switching Protocols\r -Upgrade: websocket\r -Connection: Upgrade\r -Sec-WebSocket-Accept: %s\r -""" + server_version = "WebSockify" + + protocol_version = "HTTP/1.1" # An exception while the WebSocket client was connected class CClose(Exception): pass + def __init__(self, req, addr, server): + # Retrieve a few configuration variables from the server + self.only_upgrade = getattr(server, "only_upgrade", False) + self.verbose = getattr(server, "verbose", False) + self.daemon = getattr(server, "daemon", False) + self.record = getattr(server, "record", False) + self.run_once = getattr(server, "run_once", False) + self.rec = None + self.handler_id = getattr(server, "handler_id", False) + + SimpleHTTPRequestHandler.__init__(self, req, addr, server) + @staticmethod def unmask(buf, hlen, plen): pstart = hlen + 4 @@ -204,7 +210,7 @@ Sec-WebSocket-Accept: %s\r # Process 1 frame if f['masked']: # unmask payload - f['payload'] = WebSocketServer.unmask(buf, f['hlen'], + f['payload'] = WebSocketRequestHandler.unmask(buf, f['hlen'], f['length']) else: print("Unmasked frame: %s" % repr(buf)) @@ -227,9 +233,18 @@ Sec-WebSocket-Accept: %s\r return f # - # Main WebSocketServer methods + # WebSocketRequestHandler logging/output functions # + def traffic(self, token="."): + """ Show traffic flow in verbose mode. """ + if self.verbose and not self.daemon: + sys.stdout.write(token) + sys.stdout.flush() + + # + # Main WebSocketRequestHandler methods + # def send_frames(self, bufs=None): """ Encode and send WebSocket frames. Any frames already queued will be sent first. If buf is not set then only queued @@ -311,7 +326,7 @@ Sec-WebSocket-Accept: %s\r start = frame['hlen'] end = frame['hlen'] + frame['length'] if frame['masked']: - recbuf = WebSocketServer.unmask(buf, frame['hlen'], + recbuf = WebSocketRequestHandler.unmask(buf, frame['hlen'], frame['length']) else: recbuf = buf[frame['hlen']:frame['hlen'] + @@ -336,9 +351,8 @@ Sec-WebSocket-Accept: %s\r buf, h, t = self.encode_hybi(msg, opcode=0x08, base64=False) self.request.send(buf) - def do_websocket_handshake(self, headers, path): - h = self.headers = headers - self.path = path + def do_websocket_handshake(self): + h = self.headers prot = 'WebSocket-Protocol' protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',') @@ -353,7 +367,8 @@ Sec-WebSocket-Accept: %s\r if ver in ['7', '8', '13']: self.version = "hybi-%02d" % int(ver) else: - raise self.EClose('Unsupported protocol version %s' % ver) + self.send_error(400, "Unsupported protocol version %s" % ver) + return False key = h['Sec-WebSocket-Key'] @@ -363,26 +378,130 @@ Sec-WebSocket-Accept: %s\r elif 'base64' in protocols: self.base64 = True else: - raise self.EClose("Client must support 'binary' or 'base64' protocol") + self.send_error(400, "Client must support 'binary' or 'base64' protocol") + return False # Generate the hash value for the accept header accept = b64encode(sha1(s2b(key + self.GUID)).digest()) - response = self.server_handshake_hybi % b2s(accept) + self.send_response(101, "Switching Protocols") + self.send_header("Upgrade", "websocket") + self.send_header("Connection", "Upgrade") + self.send_header("Sec-WebSocket-Accept", b2s(accept)) if self.base64: - response += "Sec-WebSocket-Protocol: base64\r\n" + self.send_header("Sec-WebSocket-Protocol", "base64") else: - response += "Sec-WebSocket-Protocol: binary\r\n" - response += "\r\n" + self.send_header("Sec-WebSocket-Protocol", "binary") + self.end_headers() + return True + else: + self.send_error(400, "Missing Sec-WebSocket-Version header. Hixie protocols not supported.") + return False + + def handle_websocket(self): + """Upgrade a connection to Websocket, if requested. If this succeeds, + new_websocket_client() will be called. Otherwise, False is returned. + """ + if (self.headers.get('upgrade') and + self.headers.get('upgrade').lower() == 'websocket'): + + if not self.do_websocket_handshake(): + return False + + # Indicate to server that a Websocket upgrade was done + self.server.ws_connection = True + # Initialize per client settings + self.send_parts = [] + self.recv_part = None + self.start_time = int(time.time()*1000) + + # client_address is empty with, say, UNIX domain sockets + client_addr = "" + is_ssl = False + try: + client_addr = self.client_address[0] + is_ssl = self.client_address[2] + except IndexError: + pass + + if is_ssl: + self.stype = "SSL/TLS (wss://)" + else: + self.stype = "Plain non-SSL (ws://)" + + self.log_message("%s: %s WebSocket connection" % (client_addr, + self.stype)) + self.log_message("%s: Version %s, base64: '%s'" % (client_addr, + self.version, self.base64)) + if self.path != '/': + self.log_message("%s: Path: '%s'" % (client_addr, self.path)) + + if self.record: + # Record raw frame data as JavaScript array + fname = "%s.%s" % (self.record, + self.handler_id) + self.log_message("opening record file: %s" % fname) + self.rec = open(fname, 'w+') + encoding = "binary" + if self.base64: encoding = "base64" + self.rec.write("var VNC_frame_encoding = '%s';\n" + % encoding) + self.rec.write("var VNC_frame_data = [\n") + + try: + self.new_websocket_client() + except self.CClose: + # Close the client + _, exc, _ = sys.exc_info() + self.send_close(exc.args[0], exc.args[1]) + return True else: - raise self.EClose("Missing Sec-WebSocket-Version header. Hixie protocols not supported.") + return False - return response - - def new_client(self): + def do_GET(self): + """Handle GET request. Calls handle_websocket(). If unsuccessful, + and web server is enabled, SimpleHTTPRequestHandler.do_GET will be called.""" + if not self.handle_websocket(): + if self.only_upgrade: + self.send_error(405, "Method Not Allowed") + else: + SimpleHTTPRequestHandler.do_GET(self) + + def new_websocket_client(self): """ Do something with a WebSockets client connection. """ - raise("WebSocketServer.new_client() must be overloaded") + raise("WebSocketRequestHandler.new_websocket_client() must be overloaded") + + def do_HEAD(self): + if self.only_upgrade: + self.send_error(405, "Method Not Allowed") + else: + SimpleHTTPRequestHandler.do_HEAD(self) + + def finish(self): + if self.rec: + self.rec.write("'EOF'];\n") + self.rec.close() + + def handle(self): + # When using run_once, we have a single process, so + # we cannot loop in BaseHTTPRequestHandler.handle; we + # must return and handle new connections + if self.run_once: + self.handle_one_request() + else: + SimpleHTTPRequestHandler.handle(self) + + def log_request(self, code='-', size='-'): + if self.verbose: + SimpleHTTPRequestHandler.log_request(self, code, size) + + +class WebSocketServer(object): + """ + WebSockets server class. + Must be sub-classed with new_client method definition. + """ policy_response = """\n""" @@ -390,12 +509,14 @@ Sec-WebSocket-Accept: %s\r class EClose(Exception): pass - def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False, + def __init__(self, RequestHandlerClass, listen_host='', + listen_port=None, source_is_ipv6=False, verbose=False, cert='', key='', ssl_only=None, daemon=False, record='', web='', run_once=False, timeout=0, idle_timeout=0): # settings + self.RequestHandlerClass = RequestHandlerClass self.verbose = verbose self.listen_host = listen_host self.listen_port = listen_port @@ -422,6 +543,7 @@ Sec-WebSocket-Accept: %s\r if self.web: os.chdir(self.web) + self.only_upgrade = not self.web # Sanity checks if not ssl and self.ssl_only: @@ -548,7 +670,6 @@ Sec-WebSocket-Accept: %s\r - Send a WebSockets handshake server response. - Return the socket for this WebSocket client. """ - stype = "" ready = select.select([sock], [], [], 3)[0] @@ -592,42 +713,19 @@ Sec-WebSocket-Accept: %s\r else: raise - self.scheme = "wss" - stype = "SSL/TLS (wss://)" - elif self.ssl_only: raise self.EClose("non-SSL connection received but disallowed") else: retsock = sock - self.scheme = "ws" - stype = "Plain non-SSL (ws://)" - wsh = WSRequestHandler(retsock, address, not self.web) - if wsh.last_code == 101: - # Continue on to handle WebSocket upgrade - pass - elif wsh.last_code == 405: - raise self.EClose("Normal web request received but disallowed") - elif wsh.last_code < 200 or wsh.last_code >= 300: - raise self.EClose(wsh.last_message) - elif self.verbose: - raise self.EClose(wsh.last_message) - else: - raise self.EClose("") + # If the address is like (host, port), we are extending it + # with a flag indicating SSL. Not many other options + # available... + if len(address) == 2: + address = (address[0], address[1], (retsock != sock)) - response = self.do_websocket_handshake(wsh.headers, wsh.path) - - self.msg("%s: %s WebSocket connection" % (address[0], stype)) - self.msg("%s: Version %s, base64: '%s'" % (address[0], - self.version, self.base64)) - if self.path != '/': - self.msg("%s: Path: '%s'" % (address[0], self.path)) - - - # Send server WebSockets handshake response - #self.msg("sending response [%s]" % response) - retsock.send(s2b(response)) + self.RequestHandlerClass(retsock, address, self) # Return the WebSockets socket which may be SSL wrapped return retsock @@ -636,13 +734,6 @@ Sec-WebSocket-Accept: %s\r # # WebSocketServer logging/output functions # - - def traffic(self, token="."): - """ Show traffic flow in verbose mode. """ - if self.verbose and not self.daemon: - sys.stdout.write(token) - sys.stdout.flush() - def msg(self, msg): """ Output message with handler_id prefix. """ if not self.daemon: @@ -683,37 +774,11 @@ Sec-WebSocket-Accept: %s\r def top_new_client(self, startsock, address): """ Do something with a WebSockets client connection. """ - # Initialize per client settings - self.send_parts = [] - self.recv_part = None - self.base64 = False - self.rec = None - self.start_time = int(time.time()*1000) - # handler process + client = None try: try: - self.request = self.do_handshake(startsock, address) - - if self.record: - # Record raw frame data as JavaScript array - fname = "%s.%s" % (self.record, - self.handler_id) - self.msg("opening record file: %s" % fname) - self.rec = open(fname, 'w+') - encoding = "binary" - if self.base64: encoding = "base64" - self.rec.write("var VNC_frame_encoding = '%s';\n" - % encoding) - self.rec.write("var VNC_frame_data = [\n") - - self.ws_connection = True - self.new_client() - except self.CClose: - # Close the client - _, exc, _ = sys.exc_info() - if self.request: - self.send_close(exc.args[0], exc.args[1]) + client = self.do_handshake(startsock, address) except self.EClose: _, exc, _ = sys.exc_info() # Connection was not a WebSockets connection @@ -725,14 +790,11 @@ Sec-WebSocket-Accept: %s\r if self.verbose: self.msg(traceback.format_exc()) finally: - if self.rec: - self.rec.write("'EOF'];\n") - self.rec.close() - if self.request and self.request != startsock: + if client and client != startsock: # Close the SSL wrapped socket # Original socket closed by caller - self.request.close() + client.close() def start_server(self): """ @@ -758,7 +820,6 @@ Sec-WebSocket-Accept: %s\r while True: try: try: - self.request = None startsock = None pid = err = 0 child_count = 0 @@ -855,33 +916,3 @@ Sec-WebSocket-Accept: %s\r self.vmsg("Closing socket listening at %s:%s" % (self.listen_host, self.listen_port)) lsock.close() - - -# HTTP handler with WebSocket upgrade support -class WSRequestHandler(SimpleHTTPRequestHandler): - def __init__(self, req, addr, only_upgrade=False): - self.only_upgrade = only_upgrade # only allow upgrades - SimpleHTTPRequestHandler.__init__(self, req, addr, object()) - - def do_GET(self): - if (self.headers.get('upgrade') and - self.headers.get('upgrade').lower() == 'websocket'): - - # Just indicate that an WebSocket upgrade is needed - self.last_code = 101 - self.last_message = "101 Switching Protocols" - elif self.only_upgrade: - # Normal web request responses are disabled - self.last_code = 405 - self.last_message = "405 Method Not Allowed" - else: - SimpleHTTPRequestHandler.do_GET(self) - - def send_response(self, code, message=None): - # Save the status code - self.last_code = code - SimpleHTTPRequestHandler.send_response(self, code, message) - - def log_message(self, f, *args): - # Save instead of printing - self.last_message = f % args diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index 928dde9..f49a497 100755 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -12,6 +12,8 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ''' import signal, socket, optparse, time, os, sys, subprocess +import SocketServer, BaseHTTPServer +from SimpleHTTPServer import SimpleHTTPRequestHandler from select import select import websocket try: @@ -20,7 +22,7 @@ except: from cgi import parse_qs from urlparse import urlparse -class WebSocketProxy(websocket.WebSocketServer): +class CustomProxyServer(websocket.WebSocketServer): """ Proxy traffic to and from a WebSockets client to a normal TCP socket server target. All traffic to/from the client is base64 @@ -140,6 +142,8 @@ class WebSocketProxy(websocket.WebSocketServer): # process. # +class ProxyRequestHandler(websocket.WebSocketRequestHandler): + # # Routines below this point are connection handler routines and # will be run in a separate forked process for each connection. @@ -150,38 +154,38 @@ Traffic Legend: } - Client receive }. - Client receive partial { - Target receive - + > - Target send >. - Target send partial < - Client send <. - Client send partial """ - def new_client(self): + def new_websocket_client(self): """ Called after a new WebSocket connection has been established. """ # Checks if we receive a token, and look # for a valid target for it then - if self.target_cfg: - (self.target_host, self.target_port) = self.get_target(self.target_cfg, self.path) + if self.server.target_cfg: + (self.server.target_host, self.server.target_port) = self.get_target(self.server.target_cfg, self.path) # Connect to the target - if self.wrap_cmd: - msg = "connecting to command: '%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port) - elif self.unix_target: - msg = "connecting to unix socket: %s" % self.unix_target + if self.server.wrap_cmd: + msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port) + elif self.server.unix_target: + msg = "connecting to unix socket: %s" % self.server.unix_target else: msg = "connecting to: %s:%s" % ( - self.target_host, self.target_port) + self.server.target_host, self.server.target_port) - if self.ssl_target: + if self.server.ssl_target: msg += " (using SSL)" - self.msg(msg) + self.log_message(msg) - tsock = websocket.WebSocketServer.socket(self.target_host, - self.target_port, - connect=True, use_ssl=self.ssl_target, unix_socket=self.unix_target) + tsock = websocket.WebSocketServer.socket(self.server.target_host, + self.server.target_port, + connect=True, use_ssl=self.server.ssl_target, unix_socket=self.server.unix_target) if self.verbose and not self.daemon: print(self.traffic_legend) @@ -193,8 +197,9 @@ Traffic Legend: if tsock: tsock.shutdown(socket.SHUT_RDWR) tsock.close() - self.vmsg("%s:%s: Closed target" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Closed target" %( + self.server.target_host, self.server.target_port)) raise def get_target(self, target_cfg, path): @@ -266,8 +271,9 @@ Traffic Legend: if closed: # TODO: What about blocking on client socket? - self.vmsg("%s:%s: Client closed connection" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Client closed connection" %( + self.server.target_host, self.server.target_port)) raise self.CClose(closed['code'], closed['reason']) @@ -287,8 +293,9 @@ Traffic Legend: # Receive target data, encode it and queue for client buf = target.recv(self.buffer_size) if len(buf) == 0: - self.vmsg("%s:%s: Target closed connection" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Target closed connection" %( + self.server.target_host, self.server.target_port)) raise self.CClose(1000, "Target closed") cqueue.append(buf) @@ -346,6 +353,8 @@ def websockify_init(): help="Configuration file containing valid targets " "in the form 'token: host:port' or, alternatively, a " "directory containing configuration files of this form") + parser.add_option("--libserver", action="store_true", + help="use Python library SocketServer engine") (opts, args) = parser.parse_args() # Sanity checks @@ -387,8 +396,65 @@ def websockify_init(): except: parser.error("Error parsing target port") # Create and start the WebSockets proxy - server = WebSocketProxy(**opts.__dict__) - server.start_server() + libserver = opts.libserver + del opts.libserver + if libserver: + # Use standard Python SocketServer framework + httpd = LibProxyServer(ProxyRequestHandler, **opts.__dict__) + httpd.serve_forever() + else: + # Use internal service framework + server = CustomProxyServer(ProxyRequestHandler, **opts.__dict__) + server.start_server() + + +class LibProxyServer(SocketServer.ForkingMixIn, BaseHTTPServer.HTTPServer): + """ + Just like CustomProxyServer, but uses standard Python SocketServer + framework. + """ + + def __init__(self, RequestHandlerClass, **kwargs): + # Save off proxy specific options + self.target_host = kwargs.pop('target_host', None) + self.target_port = kwargs.pop('target_port', None) + self.wrap_cmd = kwargs.pop('wrap_cmd', None) + self.wrap_mode = kwargs.pop('wrap_mode', None) + self.unix_target = kwargs.pop('unix_target', None) + self.ssl_target = kwargs.pop('ssl_target', None) + self.target_cfg = kwargs.pop('target_cfg', None) + self.daemon = False + self.target_cfg = None + + # Server configuration + listen_host = kwargs.pop('listen_host', '') + listen_port = kwargs.pop('listen_port', None) + web = kwargs.pop('web', '') + + # Configuration affecting base request handler + self.only_upgrade = not web + self.verbose = kwargs.pop('verbose', False) + record = kwargs.pop('record', '') + if record: + self.record = os.path.abspath(record) + self.run_once = kwargs.pop('run_once', False) + self.handler_id = 0 + + for arg in kwargs.keys(): + print("warning: option %s ignored when using --libserver" % arg) + + if web: + os.chdir(web) + + BaseHTTPServer.HTTPServer.__init__(self, (listen_host, listen_port), + RequestHandlerClass) + + + def process_request(self, request, client_address): + """Override process_request to implement a counter""" + self.handler_id += 1 + SocketServer.ForkingMixIn.process_request(self, request, client_address) + if __name__ == '__main__': websockify_init()