From 623aec262bcd1f72840e20b55fafd490334f6d64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A1n=20Jockusch?= Date: Thu, 2 May 2019 07:45:36 +0200 Subject: [PATCH] Add a plugin callback for connect and disconnect events --- websockify/conn_plugins.py | 36 +++++++++++++++++++ websockify/websocketproxy.py | 68 +++++++++++++++++++++++++++++++++--- 2 files changed, 99 insertions(+), 5 deletions(-) create mode 100644 websockify/conn_plugins.py diff --git a/websockify/conn_plugins.py b/websockify/conn_plugins.py new file mode 100644 index 0000000..afbd321 --- /dev/null +++ b/websockify/conn_plugins.py @@ -0,0 +1,36 @@ +class BasePlugin(object): + def __init__(self, src=None): + self.source = src + + def connect( + self, host, port, use_ssl, unix_socket, + sockname, query): + """When a socket connection begins, this call receives as much + information as possible. Especially sockname is important, because + this allows "disconnect" to tell the connections apart.""" + pass + + def disconnect( + self, host, port, use_ssl, unix_socket, + sockname, query): + """This function is called with the exact same parameters as + "connect", but when the connection closes.""" + pass + + +class DebugPlugin(BasePlugin): + """Prints out event information for connections and disconnections.""" + + def connect(self, host, port, use_ssl, unix_socket, sockname, query): + print([ + 'conn_plugin.DebugPlugin connect', + host, port, use_ssl, unix_socket, + sockname, query, + ]) + + def disconnect(self, host, port, use_ssl, unix_socket, sockname, query): + print([ + 'conn_plugin.DebugPlugin disconnect', + host, port, use_ssl, unix_socket, + sockname, query, + ]) diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index d3c7130..66727a7 100644 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -46,15 +46,15 @@ Traffic Legend: < - Client send <. - Client send partial """ - + def send_auth_error(self, ex): self.send_response(ex.code, ex.msg) self.send_header('Content-Type', 'text/html') for name, val in ex.headers.items(): self.send_header(name, val) - + self.end_headers() - + def validate_connection(self): if not self.server.token_plugin: return @@ -83,7 +83,7 @@ Traffic Legend: except (TypeError, AttributeError, KeyError): # not a SSL connection or client presented no certificate with valid data pass - + try: self.server.auth_plugin.authenticate( headers=self.headers, target_host=self.server.target_host, @@ -129,6 +129,22 @@ Traffic Legend: self.print_traffic(self.traffic_legend) + # Here we call the connection hook "connect", if one is defined. + # These local variables are used by the connection tracking: + sockname, query = None, None + if self.server.conn_plugin: + # Store the local socket connection data + sockname = tsock.getsockname() + query = parse_qs(urlparse(self.path).query) + self.server.conn_plugin.connect( + host=self.server.target_host, + port=self.server.target_port, + use_ssl=self.server.ssl_target, + unix_socket=self.server.unix_target, + sockname=sockname, + query=query, + ) + # Start proxying try: self.do_proxy(tsock) @@ -136,6 +152,19 @@ Traffic Legend: if tsock: tsock.shutdown(socket.SHUT_RDWR) tsock.close() + + # After disconnecting, we call the "disconnect" hook of the + # connection plugin, if it exists. + if self.server.conn_plugin: + self.server.conn_plugin.disconnect( + host=self.server.target_host, + port=self.server.target_port, + use_ssl=self.server.ssl_target, + unix_socket=self.server.unix_target, + sockname=sockname, + query=query, + ) + if self.verbose: self.log_message("%s:%s: Closed target", self.server.target_host, self.server.target_port) @@ -285,6 +314,7 @@ class WebSocketProxy(websockifyserver.WebSockifyServer): self.token_plugin = kwargs.pop('token_plugin', None) self.host_token = kwargs.pop('host_token', None) self.auth_plugin = kwargs.pop('auth_plugin', None) + self.conn_plugin = kwargs.pop('conn_plugin', None) # Last 3 timestamps command was run self.wrap_times = [0, 0, 0] @@ -353,6 +383,10 @@ class WebSocketProxy(websockifyserver.WebSockifyServer): msg = " - proxying from %s to %s" % ( src_string, dst_string) + if self.conn_plugin: + msg = " - Tracking connections with plugin %s" % ( + type(self.conn_plugin).__name__) + if self.ssl_target: msg += " (using SSL)" @@ -433,7 +467,7 @@ def select_ssl_version(version): # It so happens that version names sorted lexicographically form a list # from the least to the most secure keys = list(SSL_OPTIONS.keys()) - keys.sort() + keys.sort() fallback = keys[-1] logger = logging.getLogger(WebSocketProxy.log_prefix) logger.warn("TLS version %s unsupported. Falling back to %s", @@ -536,6 +570,12 @@ def websockify_init(): parser.add_option("--auth-source", default=None, metavar="ARG", help="an argument to be passed to the auth plugin " "on instantiation") + parser.add_option("--conn-plugin", default=None, metavar="CLASS", + help="use a Python class to implement hooks on " + "connection events") + parser.add_option("--conn-source", default=None, metavar="ARG", + help="an argument to be passed to the conn plugin " + "on instantiation") parser.add_option("--heartbeat", type=int, default=0, metavar="INTERVAL", help="send a ping to the client every INTERVAL seconds") parser.add_option("--log-file", metavar="FILE", @@ -568,6 +608,9 @@ def websockify_init(): if opts.web_auth and not opts.web: parser.error("You must use --web to use --web-auth") + if opts.conn_source and not opts.conn_plugin: + parser.error("You must use --conn-plugin to use --conn-source") + if opts.legacy_syslog and not opts.syslog: parser.error("You must use --syslog to use --legacy-syslog") @@ -713,6 +756,21 @@ def websockify_init(): del opts.auth_source + if opts.conn_plugin is not None: + if '.' not in opts.conn_plugin: + opts.conn_plugin = ( + 'websockify.conn_plugins.%s' % opts.conn_plugin) + + conn_plugin_module, conn_plugin_cls = opts.conn_plugin.rsplit('.', 1) + + __import__(conn_plugin_module) + conn_plugin_cls = getattr(sys.modules[conn_plugin_module], conn_plugin_cls) + + opts.conn_plugin = conn_plugin_cls(opts.conn_source) + + del opts.conn_source + + # Create and start the WebSockets proxy libserver = opts.libserver del opts.libserver