Add a plugin callback for connect and disconnect events
This commit is contained in:
parent
6d48b1507e
commit
623aec262b
|
|
@ -0,0 +1,36 @@
|
||||||
|
class BasePlugin(object):
|
||||||
|
def __init__(self, src=None):
|
||||||
|
self.source = src
|
||||||
|
|
||||||
|
def connect(
|
||||||
|
self, host, port, use_ssl, unix_socket,
|
||||||
|
sockname, query):
|
||||||
|
"""When a socket connection begins, this call receives as much
|
||||||
|
information as possible. Especially sockname is important, because
|
||||||
|
this allows "disconnect" to tell the connections apart."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def disconnect(
|
||||||
|
self, host, port, use_ssl, unix_socket,
|
||||||
|
sockname, query):
|
||||||
|
"""This function is called with the exact same parameters as
|
||||||
|
"connect", but when the connection closes."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DebugPlugin(BasePlugin):
|
||||||
|
"""Prints out event information for connections and disconnections."""
|
||||||
|
|
||||||
|
def connect(self, host, port, use_ssl, unix_socket, sockname, query):
|
||||||
|
print([
|
||||||
|
'conn_plugin.DebugPlugin connect',
|
||||||
|
host, port, use_ssl, unix_socket,
|
||||||
|
sockname, query,
|
||||||
|
])
|
||||||
|
|
||||||
|
def disconnect(self, host, port, use_ssl, unix_socket, sockname, query):
|
||||||
|
print([
|
||||||
|
'conn_plugin.DebugPlugin disconnect',
|
||||||
|
host, port, use_ssl, unix_socket,
|
||||||
|
sockname, query,
|
||||||
|
])
|
||||||
|
|
@ -129,6 +129,22 @@ Traffic Legend:
|
||||||
|
|
||||||
self.print_traffic(self.traffic_legend)
|
self.print_traffic(self.traffic_legend)
|
||||||
|
|
||||||
|
# Here we call the connection hook "connect", if one is defined.
|
||||||
|
# These local variables are used by the connection tracking:
|
||||||
|
sockname, query = None, None
|
||||||
|
if self.server.conn_plugin:
|
||||||
|
# Store the local socket connection data
|
||||||
|
sockname = tsock.getsockname()
|
||||||
|
query = parse_qs(urlparse(self.path).query)
|
||||||
|
self.server.conn_plugin.connect(
|
||||||
|
host=self.server.target_host,
|
||||||
|
port=self.server.target_port,
|
||||||
|
use_ssl=self.server.ssl_target,
|
||||||
|
unix_socket=self.server.unix_target,
|
||||||
|
sockname=sockname,
|
||||||
|
query=query,
|
||||||
|
)
|
||||||
|
|
||||||
# Start proxying
|
# Start proxying
|
||||||
try:
|
try:
|
||||||
self.do_proxy(tsock)
|
self.do_proxy(tsock)
|
||||||
|
|
@ -136,6 +152,19 @@ Traffic Legend:
|
||||||
if tsock:
|
if tsock:
|
||||||
tsock.shutdown(socket.SHUT_RDWR)
|
tsock.shutdown(socket.SHUT_RDWR)
|
||||||
tsock.close()
|
tsock.close()
|
||||||
|
|
||||||
|
# After disconnecting, we call the "disconnect" hook of the
|
||||||
|
# connection plugin, if it exists.
|
||||||
|
if self.server.conn_plugin:
|
||||||
|
self.server.conn_plugin.disconnect(
|
||||||
|
host=self.server.target_host,
|
||||||
|
port=self.server.target_port,
|
||||||
|
use_ssl=self.server.ssl_target,
|
||||||
|
unix_socket=self.server.unix_target,
|
||||||
|
sockname=sockname,
|
||||||
|
query=query,
|
||||||
|
)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
self.log_message("%s:%s: Closed target",
|
self.log_message("%s:%s: Closed target",
|
||||||
self.server.target_host, self.server.target_port)
|
self.server.target_host, self.server.target_port)
|
||||||
|
|
@ -285,6 +314,7 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
|
||||||
self.token_plugin = kwargs.pop('token_plugin', None)
|
self.token_plugin = kwargs.pop('token_plugin', None)
|
||||||
self.host_token = kwargs.pop('host_token', None)
|
self.host_token = kwargs.pop('host_token', None)
|
||||||
self.auth_plugin = kwargs.pop('auth_plugin', None)
|
self.auth_plugin = kwargs.pop('auth_plugin', None)
|
||||||
|
self.conn_plugin = kwargs.pop('conn_plugin', None)
|
||||||
|
|
||||||
# Last 3 timestamps command was run
|
# Last 3 timestamps command was run
|
||||||
self.wrap_times = [0, 0, 0]
|
self.wrap_times = [0, 0, 0]
|
||||||
|
|
@ -353,6 +383,10 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
|
||||||
msg = " - proxying from %s to %s" % (
|
msg = " - proxying from %s to %s" % (
|
||||||
src_string, dst_string)
|
src_string, dst_string)
|
||||||
|
|
||||||
|
if self.conn_plugin:
|
||||||
|
msg = " - Tracking connections with plugin %s" % (
|
||||||
|
type(self.conn_plugin).__name__)
|
||||||
|
|
||||||
if self.ssl_target:
|
if self.ssl_target:
|
||||||
msg += " (using SSL)"
|
msg += " (using SSL)"
|
||||||
|
|
||||||
|
|
@ -536,6 +570,12 @@ def websockify_init():
|
||||||
parser.add_option("--auth-source", default=None, metavar="ARG",
|
parser.add_option("--auth-source", default=None, metavar="ARG",
|
||||||
help="an argument to be passed to the auth plugin "
|
help="an argument to be passed to the auth plugin "
|
||||||
"on instantiation")
|
"on instantiation")
|
||||||
|
parser.add_option("--conn-plugin", default=None, metavar="CLASS",
|
||||||
|
help="use a Python class to implement hooks on "
|
||||||
|
"connection events")
|
||||||
|
parser.add_option("--conn-source", default=None, metavar="ARG",
|
||||||
|
help="an argument to be passed to the conn plugin "
|
||||||
|
"on instantiation")
|
||||||
parser.add_option("--heartbeat", type=int, default=0, metavar="INTERVAL",
|
parser.add_option("--heartbeat", type=int, default=0, metavar="INTERVAL",
|
||||||
help="send a ping to the client every INTERVAL seconds")
|
help="send a ping to the client every INTERVAL seconds")
|
||||||
parser.add_option("--log-file", metavar="FILE",
|
parser.add_option("--log-file", metavar="FILE",
|
||||||
|
|
@ -568,6 +608,9 @@ def websockify_init():
|
||||||
if opts.web_auth and not opts.web:
|
if opts.web_auth and not opts.web:
|
||||||
parser.error("You must use --web to use --web-auth")
|
parser.error("You must use --web to use --web-auth")
|
||||||
|
|
||||||
|
if opts.conn_source and not opts.conn_plugin:
|
||||||
|
parser.error("You must use --conn-plugin to use --conn-source")
|
||||||
|
|
||||||
if opts.legacy_syslog and not opts.syslog:
|
if opts.legacy_syslog and not opts.syslog:
|
||||||
parser.error("You must use --syslog to use --legacy-syslog")
|
parser.error("You must use --syslog to use --legacy-syslog")
|
||||||
|
|
||||||
|
|
@ -713,6 +756,21 @@ def websockify_init():
|
||||||
|
|
||||||
del opts.auth_source
|
del opts.auth_source
|
||||||
|
|
||||||
|
if opts.conn_plugin is not None:
|
||||||
|
if '.' not in opts.conn_plugin:
|
||||||
|
opts.conn_plugin = (
|
||||||
|
'websockify.conn_plugins.%s' % opts.conn_plugin)
|
||||||
|
|
||||||
|
conn_plugin_module, conn_plugin_cls = opts.conn_plugin.rsplit('.', 1)
|
||||||
|
|
||||||
|
__import__(conn_plugin_module)
|
||||||
|
conn_plugin_cls = getattr(sys.modules[conn_plugin_module], conn_plugin_cls)
|
||||||
|
|
||||||
|
opts.conn_plugin = conn_plugin_cls(opts.conn_source)
|
||||||
|
|
||||||
|
del opts.conn_source
|
||||||
|
|
||||||
|
|
||||||
# Create and start the WebSockets proxy
|
# Create and start the WebSockets proxy
|
||||||
libserver = opts.libserver
|
libserver = opts.libserver
|
||||||
del opts.libserver
|
del opts.libserver
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue