Add traffic plugins

This commit is contained in:
Gerhard Schaden 2023-02-26 14:17:12 +01:00
parent a1346552fb
commit 20b8c798c1
5 changed files with 40 additions and 2 deletions

View File

@ -110,6 +110,10 @@ These are not necessary for the basic operation.
options, where CLASS is usually one from token_plugins.py and ARG is
the plugin's configuration.
* Traffic plugins: an instance per connection to intercept and modify traffic
This functionality is activated with the `--traffic-plugin CLASS`, where
CLASS is an implementation of the traffic_plugin BasePlugin.
### Other implementations of websockify
The primary implementation of websockify is in python. There are

View File

@ -95,7 +95,7 @@ class JSONTokenApi(BaseTokenAPI):
def process_result(self, resp):
resp_json = resp.json()
return (resp_json['host'], resp_json['port'])
return (resp_json['host'], resp_json['port'], resp_json)
class JWTTokenApi(BasePlugin):

View File

@ -0,0 +1,8 @@
class BasePlugin():
def __init__(self, handler, tsock):
self.handler = handler
self.tsock = tsock
def from_client(self, s):
return s
def from_target(self, s):
return s

View File

@ -48,13 +48,15 @@ Traffic Legend:
if not self.server.token_plugin:
return
host, port = self.get_target(self.server.token_plugin)
host, port, *rest = self.get_target(self.server.token_plugin)
if host == 'unix_socket':
self.server.unix_target = port
else:
self.server.target_host = host
self.server.target_port = port
if len(rest) > 0: # rest is an array of all remaining parts
self.target_attribtues = rest[0]
def auth_connection(self):
if not self.server.auth_plugin:
@ -122,6 +124,8 @@ Traffic Legend:
# Start proxying
try:
if self.traffic_plugin: # replace class with instance
self.traffic_plugin = self.traffic_plugin(self, tsock)
self.do_proxy(tsock)
finally:
if tsock:
@ -246,6 +250,9 @@ Traffic Legend:
if target in outs:
# Send queued client data to the target
dat = tqueue.pop(0)
if self.traffic_plugin:
dat = self.traffic_plugin.from_client(dat)
if not dat: continue
sent = target.send(dat)
if sent == len(dat):
self.print_traffic(">")
@ -258,6 +265,9 @@ Traffic Legend:
if target in ins:
# Receive target data, encode it and queue for client
buf = target.recv(self.buffer_size)
if self.traffic_plugin:
buf = self.traffic_plugin.from_target(buf)
if not buf: continue
if len(buf) == 0:
# Target socket closed, flushing queues and closing client-side websocket
@ -298,6 +308,7 @@ class WebSocketProxy(websockifyserver.WebSockifyServer):
self.token_plugin = kwargs.pop('token_plugin', None)
self.host_token = kwargs.pop('host_token', None)
self.auth_plugin = kwargs.pop('auth_plugin', None)
self.traffic_plugin = kwargs.pop('traffic_plugin', None)
# Last 3 timestamps command was run
self.wrap_times = [0, 0, 0]
@ -537,6 +548,9 @@ def websockify_init():
parser.add_option("--auth-plugin", default=None, metavar="CLASS",
help="use a Python class, usually one from websockify.auth_plugins, "
"such as BasicHTTPAuth, to determine if a connection is allowed")
parser.add_option("--traffic-plugin", default=None, metavar="CLASS",
help="use a Python class, usually one from websockify.traffic_plugins, "
"to modify the traffic")
parser.add_option("--auth-source", default=None, metavar="ARG",
help="an argument to be passed to the auth plugin "
"on instantiation")
@ -732,6 +746,17 @@ def websockify_init():
del opts.auth_source
if opts.traffic_plugin is not None:
if '.' not in opts.traffic_plugin:
opts.traffic_plugin = 'websockify.traffic_plugins.%s' % opts.traffic_plugin
traffic_plugin_module, traffic_plugin_cls = opts.traffic_plugin.rsplit('.', 1)
__import__(traffic_plugin_module)
traffic_plugin_cls = getattr(sys.modules[traffic_plugin_module], traffic_plugin_cls)
opts.traffic_plugin = traffic_plugin_cls
# Create and start the WebSockets proxy
libserver = opts.libserver
del opts.libserver

View File

@ -84,6 +84,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa
if self.logger is None:
self.logger = WebSockifyServer.get_logger()
self.traffic_plugin = getattr(server, "traffic_plugin", None)
super().__init__(req, addr, server)
def log_message(self, format, *args):