This commit is contained in:
Alon Bar-Lev 2013-10-14 06:07:23 -07:00
commit a0b16a3a6e
2 changed files with 172 additions and 126 deletions

View File

@ -16,7 +16,7 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import os, sys, time, errno, signal, socket, traceback, select import os, sys, time, errno, signal, socket, select, logging
import array, struct import array, struct
from base64 import b64encode, b64decode from base64 import b64encode, b64decode
@ -70,6 +70,8 @@ class WebSocketServer(object):
Must be sub-classed with new_client method definition. Must be sub-classed with new_client method definition.
""" """
log_prefix = "novnc"
buffer_size = 65536 buffer_size = 65536
@ -98,6 +100,9 @@ Sec-WebSocket-Accept: %s\r
class CClose(Exception): class CClose(Exception):
pass pass
class Terminate(Exception):
pass
def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False, def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False,
verbose=False, cert='', key='', ssl_only=None, verbose=False, cert='', key='', ssl_only=None,
daemon=False, record='', web='', daemon=False, record='', web='',
@ -118,6 +123,8 @@ Sec-WebSocket-Accept: %s\r
self.ws_connection = False self.ws_connection = False
self.handler_id = 1 self.handler_id = 1
self.logger = self.get_logger()
# Make paths settings absolute # Make paths settings absolute
self.cert = os.path.abspath(cert) self.cert = os.path.abspath(cert)
self.key = self.web = self.record = '' self.key = self.web = self.record = ''
@ -138,30 +145,35 @@ Sec-WebSocket-Accept: %s\r
raise Exception("Module 'resource' required to daemonize") raise Exception("Module 'resource' required to daemonize")
# Show configuration # Show configuration
print("WebSocket server settings:") self.logger.info("WebSocket server settings:")
print(" - Listen on %s:%s" % ( self.logger.info(" - Listen on %s:%s",
self.listen_host, self.listen_port)) self.listen_host, self.listen_port)
print(" - Flash security policy server") self.logger.info(" - Flash security policy server")
if self.web: if self.web:
print(" - Web server. Web root: %s" % self.web) self.logger.info(" - Web server. Web root: %s", self.web)
if ssl: if ssl:
if os.path.exists(self.cert): if os.path.exists(self.cert):
print(" - SSL/TLS support") self.logger.info(" - SSL/TLS support")
if self.ssl_only: if self.ssl_only:
print(" - Deny non-SSL/TLS connections") self.logger.info(" - Deny non-SSL/TLS connections")
else: else:
print(" - No SSL/TLS support (no cert file)") self.logger.info(" - No SSL/TLS support (no cert file)")
else: else:
print(" - No SSL/TLS support (no 'ssl' module)") self.logger.info(" - No SSL/TLS support (no 'ssl' module)")
if self.daemon: if self.daemon:
print(" - Backgrounding (daemon)") self.logger.info(" - Backgrounding (daemon)")
if self.record: if self.record:
print(" - Recording to '%s.*'" % self.record) self.logger.info(" - Recording to '%s.*'", self.record)
# #
# WebSocketServer static methods # WebSocketServer static methods
# #
@staticmethod
def get_logger():
return logging.getLogger("%s.%s" % (
WebSocketServer.log_prefix, WebSocketServer.__class__.__name__))
@staticmethod @staticmethod
def socket(host, port=None, connect=False, prefer_ipv6=False, unix_socket=None, use_ssl=False): def socket(host, port=None, connect=False, prefer_ipv6=False, unix_socket=None, use_ssl=False):
""" Resolve a host (and optional port) to an IPv4 or IPv6 """ Resolve a host (and optional port) to an IPv4 or IPv6
@ -219,8 +231,7 @@ Sec-WebSocket-Accept: %s\r
if os.fork() > 0: os._exit(0) # Parent exits if os.fork() > 0: os._exit(0) # Parent exits
# Signal handling # Signal handling
def terminate(a,b): os._exit(0) signal.signal(signal.SIGTERM, signal.SIG_IGN)
signal.signal(signal.SIGTERM, terminate)
signal.signal(signal.SIGINT, signal.SIG_IGN) signal.signal(signal.SIGINT, signal.SIG_IGN)
# Close open files # Close open files
@ -324,6 +335,8 @@ Sec-WebSocket-Accept: %s\r
'close_code' : 1000, 'close_code' : 1000,
'close_reason' : ''} 'close_reason' : ''}
logger = WebSocketServer.get_logger()
blen = len(buf) blen = len(buf)
f['left'] = blen f['left'] = blen
@ -362,15 +375,16 @@ Sec-WebSocket-Accept: %s\r
f['payload'] = WebSocketServer.unmask(buf, f['hlen'], f['payload'] = WebSocketServer.unmask(buf, f['hlen'],
f['length']) f['length'])
else: else:
print("Unmasked frame: %s" % repr(buf)) logger.debug("Unmasked frame: %s", repr(buf))
f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len] f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len]
if base64 and f['opcode'] in [1, 2]: if base64 and f['opcode'] in [1, 2]:
try: try:
f['payload'] = b64decode(f['payload']) f['payload'] = b64decode(f['payload'])
except: except:
print("Exception while b64decoding buffer: %s" % logger.warning("Exception while b64decoding buffer: %s",
repr(buf)) repr(buf))
logger.debug('Exception', exc_info=True)
raise raise
if f['opcode'] == 0x08: if f['opcode'] == 0x08:
@ -414,20 +428,20 @@ Sec-WebSocket-Accept: %s\r
# #
def traffic(self, token="."): def traffic(self, token="."):
""" Show traffic flow in verbose mode. """ """ Show traffic flow mode. """
if self.verbose and not self.daemon: self.logger.debug("%s", token)
sys.stdout.write(token)
sys.stdout.flush()
def msg(self, msg): def log(self, lvl, msg, *args, **kwargs):
self.logger.log(lvl, "% 3d: %s" % (self.handler_id, msg),
*args, **kwargs)
def msg(self, *args, **kwargs):
""" Output message with handler_id prefix. """ """ Output message with handler_id prefix. """
if not self.daemon: self.log(logging.INFO, *args, **kwargs)
print("% 3d: %s" % (self.handler_id, msg))
def vmsg(self, msg): def vmsg(self, *args, **kwargs):
""" Same as msg() but only if verbose. """ """ Same as msg() but as debug. """
if self.verbose: self.log(logging.DEBUG, *args, **kwargs)
self.msg(msg)
# #
# Main WebSocketServer methods # Main WebSocketServer methods
@ -752,6 +766,9 @@ Sec-WebSocket-Accept: %s\r
#self.vmsg("Running poll()") #self.vmsg("Running poll()")
pass pass
def terminate(self):
raise self.Terminate()
def fallback_SIGCHLD(self, sig, stack): def fallback_SIGCHLD(self, sig, stack):
# Reap zombies when using os.fork() (python 2.4) # Reap zombies when using os.fork() (python 2.4)
self.vmsg("Got SIGCHLD, reaping zombies") self.vmsg("Got SIGCHLD, reaping zombies")
@ -765,7 +782,11 @@ Sec-WebSocket-Accept: %s\r
def do_SIGINT(self, sig, stack): def do_SIGINT(self, sig, stack):
self.msg("Got SIGINT, exiting") self.msg("Got SIGINT, exiting")
sys.exit(0) self.terminate()
def do_SIGTERM(self, sig, stack):
self.msg("Got SIGTERM, exiting")
self.terminate()
def top_new_client(self, startsock, address): def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """ """ Do something with a WebSockets client connection. """
@ -808,8 +829,7 @@ Sec-WebSocket-Accept: %s\r
except Exception: except Exception:
_, exc, _ = sys.exc_info() _, exc, _ = sys.exc_info()
self.msg("handler exception: %s" % str(exc)) self.msg("handler exception: %s" % str(exc))
if self.verbose: self.vmsg("exception", exc_info=True)
self.msg(traceback.format_exc())
finally: finally:
if self.rec: if self.rec:
self.rec.write("'EOF'];\n") self.rec.write("'EOF'];\n")
@ -838,113 +858,125 @@ Sec-WebSocket-Accept: %s\r
self.started() # Some things need to happen after daemonizing self.started() # Some things need to happen after daemonizing
# Allow override of SIGINT # Allow override of signals
original_signals = {
signal.SIGINT: signal.getsignal(signal.SIGINT),
signal.SIGTERM: signal.getsignal(signal.SIGTERM),
signal.SIGCHLD: signal.getsignal(signal.SIGCHLD),
}
signal.signal(signal.SIGINT, self.do_SIGINT) signal.signal(signal.SIGINT, self.do_SIGINT)
signal.signal(signal.SIGTERM, self.do_SIGTERM)
if not multiprocessing: if not multiprocessing:
# os.fork() (python 2.4) child reaper # os.fork() (python 2.4) child reaper
signal.signal(signal.SIGCHLD, self.fallback_SIGCHLD) signal.signal(signal.SIGCHLD, self.fallback_SIGCHLD)
last_active_time = self.launch_time last_active_time = self.launch_time
while True: try:
try: while True:
try: try:
self.client = None try:
startsock = None self.client = None
pid = err = 0 startsock = None
child_count = 0 pid = err = 0
child_count = 0
if multiprocessing and self.idle_timeout: if multiprocessing and self.idle_timeout:
child_count = len(multiprocessing.active_children()) child_count = len(multiprocessing.active_children())
time_elapsed = time.time() - self.launch_time time_elapsed = time.time() - self.launch_time
if self.timeout and time_elapsed > self.timeout: if self.timeout and time_elapsed > self.timeout:
self.msg('listener exit due to --timeout %s' self.msg('listener exit due to --timeout %s'
% self.timeout) % self.timeout)
break
if self.idle_timeout:
idle_time = 0
if child_count == 0:
idle_time = time.time() - last_active_time
else:
idle_time = 0
last_active_time = time.time()
if idle_time > self.idle_timeout and child_count == 0:
self.msg('listener exit due to --idle-timeout %s'
% self.idle_timeout)
break break
try: if self.idle_timeout:
self.poll() idle_time = 0
if child_count == 0:
idle_time = time.time() - last_active_time
else:
idle_time = 0
last_active_time = time.time()
ready = select.select([lsock], [], [], 1)[0] if idle_time > self.idle_timeout and child_count == 0:
if lsock in ready: self.msg('listener exit due to --idle-timeout %s'
startsock, address = lsock.accept() % self.idle_timeout)
break
try:
self.poll()
ready = select.select([lsock], [], [], 1)[0]
if lsock in ready:
startsock, address = lsock.accept()
else:
continue
except self.Terminate:
raise
except Exception:
_, exc, _ = sys.exc_info()
if hasattr(exc, 'errno'):
err = exc.errno
elif hasattr(exc, 'args'):
err = exc.args[0]
else:
err = exc[0]
if err == errno.EINTR:
self.vmsg("Ignoring interrupted syscall")
continue
else:
raise
if self.run_once:
# Run in same process if run_once
self.top_new_client(startsock, address)
if self.ws_connection :
self.msg('%s: exiting due to --run-once'
% address[0])
break
elif multiprocessing:
self.vmsg('%s: new handler Process' % address[0])
p = multiprocessing.Process(
target=self.top_new_client,
args=(startsock, address))
p.start()
# child will not return
else: else:
continue # python 2.4
self.vmsg('%s: forking handler' % address[0])
pid = os.fork()
if pid == 0:
# child handler process
self.top_new_client(startsock, address)
break # child process exits
# parent process
self.handler_id += 1
except KeyboardInterrupt:
_, exc, _ = sys.exc_info()
print("In KeyboardInterrupt")
pass
except (self.Terminate, SystemExit):
_, exc, _ = sys.exc_info()
print("In SystemExit")
break
except Exception: except Exception:
_, exc, _ = sys.exc_info() _, exc, _ = sys.exc_info()
if hasattr(exc, 'errno'): self.msg("handler exception: %s", str(exc))
err = exc.errno self.vmsg("exception", exc_info=True)
elif hasattr(exc, 'args'):
err = exc.args[0]
else:
err = exc[0]
if err == errno.EINTR:
self.vmsg("Ignoring interrupted syscall")
continue
else:
raise
if self.run_once:
# Run in same process if run_once
self.top_new_client(startsock, address)
if self.ws_connection :
self.msg('%s: exiting due to --run-once'
% address[0])
break
elif multiprocessing:
self.vmsg('%s: new handler Process' % address[0])
p = multiprocessing.Process(
target=self.top_new_client,
args=(startsock, address))
p.start()
# child will not return
else:
# python 2.4
self.vmsg('%s: forking handler' % address[0])
pid = os.fork()
if pid == 0:
# child handler process
self.top_new_client(startsock, address)
break # child process exits
# parent process finally:
self.handler_id += 1 if startsock:
startsock.close()
finally:
# Close listen port
self.vmsg("Closing socket listening at %s:%s",
self.listen_host, self.listen_port)
lsock.close()
except KeyboardInterrupt: # Restore signals
_, exc, _ = sys.exc_info() for sig, func in original_signals.items():
print("In KeyboardInterrupt") signal.signal(sig, func)
pass
except SystemExit:
_, exc, _ = sys.exc_info()
print("In SystemExit")
break
except Exception:
_, exc, _ = sys.exc_info()
self.msg("handler exception: %s" % str(exc))
if self.verbose:
self.msg(traceback.format_exc())
finally:
if startsock:
startsock.close()
# Close listen port
self.vmsg("Closing socket listening at %s:%s"
% (self.listen_host, self.listen_port))
lsock.close()
# HTTP handler with WebSocket upgrade support # HTTP handler with WebSocket upgrade support

View File

@ -11,7 +11,7 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import signal, socket, optparse, time, os, sys, subprocess import signal, socket, optparse, time, os, sys, subprocess, logging
from select import select from select import select
import websocket import websocket
try: try:
@ -86,7 +86,7 @@ Traffic Legend:
websocket.WebSocketServer.__init__(self, *args, **kwargs) websocket.WebSocketServer.__init__(self, *args, **kwargs)
def run_wrap_cmd(self): def run_wrap_cmd(self):
print("Starting '%s'" % " ".join(self.wrap_cmd)) self.logger.info("Starting '%s'", " ".join(self.wrap_cmd))
self.wrap_times.append(time.time()) self.wrap_times.append(time.time())
self.wrap_times.pop(0) self.wrap_times.pop(0)
self.cmd = subprocess.Popen( self.cmd = subprocess.Popen(
@ -116,7 +116,7 @@ Traffic Legend:
if self.ssl_target: if self.ssl_target:
msg += " (using SSL)" msg += " (using SSL)"
print(msg + "\n") self.logger.info(msg)
if self.wrap_cmd: if self.wrap_cmd:
self.run_wrap_cmd() self.run_wrap_cmd()
@ -142,7 +142,7 @@ Traffic Legend:
if (now - avg) < 10: if (now - avg) < 10:
# 3 times in the last 10 seconds # 3 times in the last 10 seconds
if self.spawn_message: if self.spawn_message:
print("Command respawning too fast") self.logger.warning("Command respawning too fast")
self.spawn_message = False self.spawn_message = False
else: else:
self.run_wrap_cmd() self.run_wrap_cmd()
@ -182,8 +182,7 @@ Traffic Legend:
tsock = self.socket(self.target_host, self.target_port, tsock = self.socket(self.target_host, self.target_port,
connect=True, use_ssl=self.ssl_target, unix_socket=self.unix_target) connect=True, use_ssl=self.ssl_target, unix_socket=self.unix_target)
if self.verbose and not self.daemon: self.logger.debug(self.traffic_legend)
print(self.traffic_legend)
# Start proxying # Start proxying
try: try:
@ -301,7 +300,19 @@ def _subprocess_setup():
signal.signal(signal.SIGPIPE, signal.SIG_DFL) signal.signal(signal.SIGPIPE, signal.SIG_DFL)
def logger_init():
logger = logging.getLogger(WebSocketProxy.log_prefix)
logger.propagate = False
logger.setLevel(logging.INFO)
h = logging.StreamHandler()
h.setLevel(logging.DEBUG)
h.setFormatter(logging.Formatter("%(levelname)-7s %(message)s"))
logger.addHandler(h)
def websockify_init(): def websockify_init():
logger_init()
usage = "\n %prog [options]" usage = "\n %prog [options]"
usage += " [source_addr:]source_port [target_addr:target_port]" usage += " [source_addr:]source_port [target_addr:target_port]"
usage += "\n %prog [options]" usage += "\n %prog [options]"
@ -347,6 +358,9 @@ def websockify_init():
"directory containing configuration files of this form") "directory containing configuration files of this form")
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
if opts.verbose:
logging.getLogger(WebSocketProxy.log_prefix).setLevel(logging.DEBUG)
# Sanity checks # Sanity checks
if len(args) < 2 and not (opts.target_cfg or opts.unix_target): if len(args) < 2 and not (opts.target_cfg or opts.unix_target):
parser.error("Too few arguments") parser.error("Too few arguments")