diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 545fa1c..22802b0 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -411,9 +411,11 @@ class HyBiEncodeDecodeTestCase(unittest.TestCase): def test_encode_hybi_basic(self): res = websocket.WebSocketRequestHandler.encode_hybi(b'Hello', 0x1) - expected = (b'\x81\x05\x48\x65\x6c\x6c\x6f', 2, 0) + expected_header = (b'\x81\x05') + expected_data = (b'\x48\x65\x6c\x6c\x6f') - self.assertEqual(res, expected) + self.assertEqual(res[0], expected_header) + self.assertEqual(res[1], expected_data) def test_strict_mode_refuses_unmasked_client_frames(self): buf = b'\x81\x05\x48\x65\x6c\x6c\x6f' diff --git a/websockify/websocket.py b/websockify/websocket.py index ebb3a53..16d39b7 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -50,10 +50,15 @@ except: return struct.unpack(fmt, slice) # Degraded functionality if these imports are missing -for mod, msg in [('numpy', 'HyBi protocol will be slower'), - ('ssl', 'TLS/SSL/wss is disabled'), - ('multiprocessing', 'Multi-Processing is disabled'), - ('resource', 'daemonizing is disabled')]: +OPTIONAL_FEATURES = [ + ('numpy', 'HyBi protocol will be slower'), + ('ssl', 'TLS/SSL/wss is disabled'), + ('multiprocessing', 'Multi-Processing is disabled'), + ] +if not sys.platform.startswith("win"): + #resource module is not available on Windows + OPTIONAL_FEATURES.append(('resource', 'daemonizing is disabled')) +for mod, msg in OPTIONAL_FEATURES: try: globals()[mod] = __import__(mod) except ImportError: @@ -154,8 +159,8 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): return data.tostring() @staticmethod - def encode_hybi(buf, opcode, base64=False): - """ Encode a HyBi style WebSocket frame. + def encode_hybi_header(payload_len, opcode): + """ Encode a HyBi style WebSocket frame header. Optional opcode: 0x0 - continuation 0x1 - text frame (base64 encode buf) @@ -164,21 +169,34 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): 0x9 - ping 0xA - pong """ + b1 = 0x80 | (opcode & 0x0f) # FIN + opcode + if payload_len <= 125: + return pack('>BB', b1, payload_len) + elif payload_len > 125 and payload_len < 65536: + return pack('>BBH', b1, 126, payload_len) + else: + assert payload_len >= 65536 + return pack('>BBQ', b1, 127, payload_len) + + @staticmethod + def encode_hybi(buf, opcode, base64=False): + """ Encode a HyBi style WebSocket frame. """ if base64: buf = b64encode(buf) + header = WebSocketRequestHandler.encode_hybi_header(len(buf), opcode) + return (header, buf) - b1 = 0x80 | (opcode & 0x0f) # FIN + opcode - payload_len = len(buf) - if payload_len <= 125: - header = pack('>BB', b1, payload_len) - elif payload_len > 125 and payload_len < 65536: - header = pack('>BBH', b1, 126, payload_len) - elif payload_len >= 65536: - header = pack('>BBQ', b1, 127, payload_len) + def send_hybi(self, buf, opcode, base64=False, record=None): + """ Send a HyBi style WebSocket frame. """ + header, buf = self.encode_hybi(buf, opcode, base64) + if record: + record.write("%s,\n" % repr("{%s{" % tdelta + encbufs[1])) + if len(buf)<=4096: + self.request.send(header+buf) + else: + self.request.send(header) + self.request.send(buf) - #self.msg("Encoded: %s", repr(header + buf)) - - return header + buf, len(header), 0 @staticmethod def decode_hybi(buf, base64=False, logger=None, strict=True): @@ -311,30 +329,10 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): if bufs: for buf in bufs: if self.base64: - encbuf, lenhead, lentail = self.encode_hybi(buf, opcode=1, base64=True) + opcode = 1 else: - encbuf, lenhead, lentail = self.encode_hybi(buf, opcode=2, base64=False) - - if self.rec: - self.rec.write("%s,\n" % - repr("{%s{" % tdelta - + encbuf[lenhead:len(encbuf)-lentail])) - - self.send_parts.append(encbuf) - - while self.send_parts: - # Send pending frames - buf = self.send_parts.pop(0) - sent = self.request.send(buf) - - if sent == len(buf): - self.print_traffic("<") - else: - self.print_traffic("<.") - self.send_parts.insert(0, buf[sent:]) - break - - return len(self.send_parts) + opcode = 2 + self.send_hybi(buf, opcode, base64=self.base64, record=self.rec) def recv_frames(self): """ Receive and decode WebSocket frames. @@ -412,18 +410,15 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): """ Send a WebSocket orderly close frame. """ msg = pack(">H%ds" % len(reason), code, s2b(reason)) - buf, h, t = self.encode_hybi(msg, opcode=0x08, base64=False) - self.request.send(buf) + self.send_hybi(msg, opcode=0x08) def send_pong(self, data=''): """ Send a WebSocket pong frame. """ - buf, h, t = self.encode_hybi(s2b(data), opcode=0x0A, base64=False) - self.request.send(buf) + self.send_hybi(s2b(data), opcode=0x0A) def send_ping(self, data=''): """ Send a WebSocket ping frame. """ - buf, h, t = self.encode_hybi(s2b(data), opcode=0x09, base64=False) - self.request.send(buf) + self.send_hybi(s2b(data), opcode=0x09) def do_websocket_handshake(self): h = self.headers @@ -494,7 +489,6 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): # Indicate to server that a Websocket upgrade was done self.server.ws_connection = True # Initialize per client settings - self.send_parts = [] self.recv_part = None self.start_time = int(time.time()*1000) @@ -550,6 +544,20 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): else: SimpleHTTPRequestHandler.do_GET(self) + def copyfile(self, source, outputfile): + """Adds retry code for WSAEWOULDBLOCK on MS Windows""" + if not sys.platform.startswith("win"): + return SimpleHTTPRequestHandler.copyfile(self, source, outputfile) + import shutil + while True: + try: + shutil.copyfileobj(source, outputfile) + return + except (IOError, OSError) as e: + if e[0]==errno.WSAEWOULDBLOCK: + continue + raise + def list_directory(self, path): if self.file_only: self.send_error(404, "No such file") @@ -990,17 +998,20 @@ class WebSocketServer(object): original_signals = { signal.SIGINT: signal.getsignal(signal.SIGINT), signal.SIGTERM: signal.getsignal(signal.SIGTERM), - signal.SIGCHLD: signal.getsignal(signal.SIGCHLD), } signal.signal(signal.SIGINT, self.do_SIGINT) signal.signal(signal.SIGTERM, self.do_SIGTERM) - if not multiprocessing: - # os.fork() (python 2.4) child reaper - signal.signal(signal.SIGCHLD, self.fallback_SIGCHLD) - else: - # make sure that _cleanup is called when children die - # by calling active_children on SIGCHLD - signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD) + #SIGCHLD is only available on posix: + SIGCHLD = getattr(signal, "SIGCHLD", None) + if SIGCHLD: + original_signals[SIGCHLD] = signal.getsignal(SIGCHLD) + if not multiprocessing: + # os.fork() (python 2.4) child reaper + signal.signal(SIGCHLD, self.fallback_SIGCHLD) + else: + # make sure that _cleanup is called when children die + # by calling active_children on SIGCHLD + signal.signal(SIGCHLD, self.multiprocessing_SIGCHLD) last_active_time = self.launch_time try: diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index 837d801..14b0e06 100755 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -137,7 +137,6 @@ Traffic Legend: Proxy client WebSocket to normal target socket. """ cqueue = [] - c_pend = 0 tqueue = [] rlist = [self.request, target] @@ -157,7 +156,7 @@ Traffic Legend: self.send_ping() if tqueue: wlist.append(target) - if cqueue or c_pend: wlist.append(self.request) + if cqueue: wlist.append(self.request) try: ins, outs, excepts = select.select(rlist, wlist, [], 1) except (select.error, OSError): @@ -176,7 +175,7 @@ Traffic Legend: if self.request in outs: # Send queued target data to the client - c_pend = self.send_frames(cqueue) + self.send_frames(cqueue) cqueue = [] @@ -229,6 +228,8 @@ class WebSocketProxy(websocket.WebSocketServer): def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs): # Save off proxy specific options + if sys.platform.startswith("win"): + kwargs.pop("multiprocessing_fork", None) self.target_host = kwargs.pop('target_host', None) self.target_port = kwargs.pop('target_port', None) self.wrap_cmd = kwargs.pop('wrap_cmd', None)