This commit is contained in:
Antoine Martin 2017-01-20 18:01:15 +00:00 committed by GitHub
commit 4ad481c682
3 changed files with 74 additions and 60 deletions

View File

@ -411,9 +411,11 @@ class HyBiEncodeDecodeTestCase(unittest.TestCase):
def test_encode_hybi_basic(self): def test_encode_hybi_basic(self):
res = websocket.WebSocketRequestHandler.encode_hybi(b'Hello', 0x1) res = websocket.WebSocketRequestHandler.encode_hybi(b'Hello', 0x1)
expected = (b'\x81\x05\x48\x65\x6c\x6c\x6f', 2, 0) expected_header = (b'\x81\x05')
expected_data = (b'\x48\x65\x6c\x6c\x6f')
self.assertEqual(res, expected) self.assertEqual(res[0], expected_header)
self.assertEqual(res[1], expected_data)
def test_strict_mode_refuses_unmasked_client_frames(self): def test_strict_mode_refuses_unmasked_client_frames(self):
buf = b'\x81\x05\x48\x65\x6c\x6c\x6f' buf = b'\x81\x05\x48\x65\x6c\x6c\x6f'

View File

@ -50,10 +50,15 @@ except:
return struct.unpack(fmt, slice) return struct.unpack(fmt, slice)
# Degraded functionality if these imports are missing # Degraded functionality if these imports are missing
for mod, msg in [('numpy', 'HyBi protocol will be slower'), OPTIONAL_FEATURES = [
('numpy', 'HyBi protocol will be slower'),
('ssl', 'TLS/SSL/wss is disabled'), ('ssl', 'TLS/SSL/wss is disabled'),
('multiprocessing', 'Multi-Processing is disabled'), ('multiprocessing', 'Multi-Processing is disabled'),
('resource', 'daemonizing is disabled')]: ]
if not sys.platform.startswith("win"):
#resource module is not available on Windows
OPTIONAL_FEATURES.append(('resource', 'daemonizing is disabled'))
for mod, msg in OPTIONAL_FEATURES:
try: try:
globals()[mod] = __import__(mod) globals()[mod] = __import__(mod)
except ImportError: except ImportError:
@ -154,8 +159,8 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
return data.tostring() return data.tostring()
@staticmethod @staticmethod
def encode_hybi(buf, opcode, base64=False): def encode_hybi_header(payload_len, opcode):
""" Encode a HyBi style WebSocket frame. """ Encode a HyBi style WebSocket frame header.
Optional opcode: Optional opcode:
0x0 - continuation 0x0 - continuation
0x1 - text frame (base64 encode buf) 0x1 - text frame (base64 encode buf)
@ -164,21 +169,34 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
0x9 - ping 0x9 - ping
0xA - pong 0xA - pong
""" """
b1 = 0x80 | (opcode & 0x0f) # FIN + opcode
if payload_len <= 125:
return pack('>BB', b1, payload_len)
elif payload_len > 125 and payload_len < 65536:
return pack('>BBH', b1, 126, payload_len)
else:
assert payload_len >= 65536
return pack('>BBQ', b1, 127, payload_len)
@staticmethod
def encode_hybi(buf, opcode, base64=False):
""" Encode a HyBi style WebSocket frame. """
if base64: if base64:
buf = b64encode(buf) buf = b64encode(buf)
header = WebSocketRequestHandler.encode_hybi_header(len(buf), opcode)
return (header, buf)
b1 = 0x80 | (opcode & 0x0f) # FIN + opcode def send_hybi(self, buf, opcode, base64=False, record=None):
payload_len = len(buf) """ Send a HyBi style WebSocket frame. """
if payload_len <= 125: header, buf = self.encode_hybi(buf, opcode, base64)
header = pack('>BB', b1, payload_len) if record:
elif payload_len > 125 and payload_len < 65536: record.write("%s,\n" % repr("{%s{" % tdelta + encbufs[1]))
header = pack('>BBH', b1, 126, payload_len) if len(buf)<=4096:
elif payload_len >= 65536: self.request.send(header+buf)
header = pack('>BBQ', b1, 127, payload_len) else:
self.request.send(header)
self.request.send(buf)
#self.msg("Encoded: %s", repr(header + buf))
return header + buf, len(header), 0
@staticmethod @staticmethod
def decode_hybi(buf, base64=False, logger=None, strict=True): def decode_hybi(buf, base64=False, logger=None, strict=True):
@ -311,30 +329,10 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
if bufs: if bufs:
for buf in bufs: for buf in bufs:
if self.base64: if self.base64:
encbuf, lenhead, lentail = self.encode_hybi(buf, opcode=1, base64=True) opcode = 1
else: else:
encbuf, lenhead, lentail = self.encode_hybi(buf, opcode=2, base64=False) opcode = 2
self.send_hybi(buf, opcode, base64=self.base64, record=self.rec)
if self.rec:
self.rec.write("%s,\n" %
repr("{%s{" % tdelta
+ encbuf[lenhead:len(encbuf)-lentail]))
self.send_parts.append(encbuf)
while self.send_parts:
# Send pending frames
buf = self.send_parts.pop(0)
sent = self.request.send(buf)
if sent == len(buf):
self.print_traffic("<")
else:
self.print_traffic("<.")
self.send_parts.insert(0, buf[sent:])
break
return len(self.send_parts)
def recv_frames(self): def recv_frames(self):
""" Receive and decode WebSocket frames. """ Receive and decode WebSocket frames.
@ -412,18 +410,15 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
""" Send a WebSocket orderly close frame. """ """ Send a WebSocket orderly close frame. """
msg = pack(">H%ds" % len(reason), code, s2b(reason)) msg = pack(">H%ds" % len(reason), code, s2b(reason))
buf, h, t = self.encode_hybi(msg, opcode=0x08, base64=False) self.send_hybi(msg, opcode=0x08)
self.request.send(buf)
def send_pong(self, data=''): def send_pong(self, data=''):
""" Send a WebSocket pong frame. """ """ Send a WebSocket pong frame. """
buf, h, t = self.encode_hybi(s2b(data), opcode=0x0A, base64=False) self.send_hybi(s2b(data), opcode=0x0A)
self.request.send(buf)
def send_ping(self, data=''): def send_ping(self, data=''):
""" Send a WebSocket ping frame. """ """ Send a WebSocket ping frame. """
buf, h, t = self.encode_hybi(s2b(data), opcode=0x09, base64=False) self.send_hybi(s2b(data), opcode=0x09)
self.request.send(buf)
def do_websocket_handshake(self): def do_websocket_handshake(self):
h = self.headers h = self.headers
@ -494,7 +489,6 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
# Indicate to server that a Websocket upgrade was done # Indicate to server that a Websocket upgrade was done
self.server.ws_connection = True self.server.ws_connection = True
# Initialize per client settings # Initialize per client settings
self.send_parts = []
self.recv_part = None self.recv_part = None
self.start_time = int(time.time()*1000) self.start_time = int(time.time()*1000)
@ -550,6 +544,20 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
else: else:
SimpleHTTPRequestHandler.do_GET(self) SimpleHTTPRequestHandler.do_GET(self)
def copyfile(self, source, outputfile):
"""Adds retry code for WSAEWOULDBLOCK on MS Windows"""
if not sys.platform.startswith("win"):
return SimpleHTTPRequestHandler.copyfile(self, source, outputfile)
import shutil
while True:
try:
shutil.copyfileobj(source, outputfile)
return
except (IOError, OSError) as e:
if e[0]==errno.WSAEWOULDBLOCK:
continue
raise
def list_directory(self, path): def list_directory(self, path):
if self.file_only: if self.file_only:
self.send_error(404, "No such file") self.send_error(404, "No such file")
@ -990,17 +998,20 @@ class WebSocketServer(object):
original_signals = { original_signals = {
signal.SIGINT: signal.getsignal(signal.SIGINT), signal.SIGINT: signal.getsignal(signal.SIGINT),
signal.SIGTERM: signal.getsignal(signal.SIGTERM), signal.SIGTERM: signal.getsignal(signal.SIGTERM),
signal.SIGCHLD: signal.getsignal(signal.SIGCHLD),
} }
signal.signal(signal.SIGINT, self.do_SIGINT) signal.signal(signal.SIGINT, self.do_SIGINT)
signal.signal(signal.SIGTERM, self.do_SIGTERM) signal.signal(signal.SIGTERM, self.do_SIGTERM)
#SIGCHLD is only available on posix:
SIGCHLD = getattr(signal, "SIGCHLD", None)
if SIGCHLD:
original_signals[SIGCHLD] = signal.getsignal(SIGCHLD)
if not multiprocessing: if not multiprocessing:
# os.fork() (python 2.4) child reaper # os.fork() (python 2.4) child reaper
signal.signal(signal.SIGCHLD, self.fallback_SIGCHLD) signal.signal(SIGCHLD, self.fallback_SIGCHLD)
else: else:
# make sure that _cleanup is called when children die # make sure that _cleanup is called when children die
# by calling active_children on SIGCHLD # by calling active_children on SIGCHLD
signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD) signal.signal(SIGCHLD, self.multiprocessing_SIGCHLD)
last_active_time = self.launch_time last_active_time = self.launch_time
try: try:

View File

@ -137,7 +137,6 @@ Traffic Legend:
Proxy client WebSocket to normal target socket. Proxy client WebSocket to normal target socket.
""" """
cqueue = [] cqueue = []
c_pend = 0
tqueue = [] tqueue = []
rlist = [self.request, target] rlist = [self.request, target]
@ -157,7 +156,7 @@ Traffic Legend:
self.send_ping() self.send_ping()
if tqueue: wlist.append(target) if tqueue: wlist.append(target)
if cqueue or c_pend: wlist.append(self.request) if cqueue: wlist.append(self.request)
try: try:
ins, outs, excepts = select.select(rlist, wlist, [], 1) ins, outs, excepts = select.select(rlist, wlist, [], 1)
except (select.error, OSError): except (select.error, OSError):
@ -176,7 +175,7 @@ Traffic Legend:
if self.request in outs: if self.request in outs:
# Send queued target data to the client # Send queued target data to the client
c_pend = self.send_frames(cqueue) self.send_frames(cqueue)
cqueue = [] cqueue = []
@ -229,6 +228,8 @@ class WebSocketProxy(websocket.WebSocketServer):
def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs): def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs):
# Save off proxy specific options # Save off proxy specific options
if sys.platform.startswith("win"):
kwargs.pop("multiprocessing_fork", None)
self.target_host = kwargs.pop('target_host', None) self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None) self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None) self.wrap_cmd = kwargs.pop('wrap_cmd', None)