Merge pull request #194 from kanaka/feature/http-auth-plugins
Rework Auth Plugins to Support HTTP Auth
This commit is contained in:
commit
714aa34e4e
|
|
@ -106,11 +106,11 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
|
||||||
def lookup(self, token):
|
def lookup(self, token):
|
||||||
return (self.source + token).split(',')
|
return (self.source + token).split(',')
|
||||||
|
|
||||||
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
|
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error',
|
||||||
lambda *args, **kwargs: None)
|
staticmethod(lambda *args, **kwargs: None))
|
||||||
|
|
||||||
self.handler.server.token_plugin = TestPlugin("somehost,")
|
self.handler.server.token_plugin = TestPlugin("somehost,")
|
||||||
self.handler.new_websocket_client()
|
self.handler.validate_connection()
|
||||||
|
|
||||||
self.assertEqual(self.handler.server.target_host, "somehost")
|
self.assertEqual(self.handler.server.target_host, "somehost")
|
||||||
self.assertEqual(self.handler.server.target_port, "blah")
|
self.assertEqual(self.handler.server.target_port, "blah")
|
||||||
|
|
@ -119,9 +119,9 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
|
||||||
class TestPlugin(auth_plugins.BasePlugin):
|
class TestPlugin(auth_plugins.BasePlugin):
|
||||||
def authenticate(self, headers, target_host, target_port):
|
def authenticate(self, headers, target_host, target_port):
|
||||||
if target_host == self.source:
|
if target_host == self.source:
|
||||||
raise auth_plugins.AuthenticationError("some error")
|
raise auth_plugins.AuthenticationError(response_msg="some_error")
|
||||||
|
|
||||||
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
|
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error',
|
||||||
staticmethod(lambda *args, **kwargs: None))
|
staticmethod(lambda *args, **kwargs: None))
|
||||||
|
|
||||||
self.handler.server.auth_plugin = TestPlugin("somehost")
|
self.handler.server.auth_plugin = TestPlugin("somehost")
|
||||||
|
|
@ -129,8 +129,8 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
|
||||||
self.handler.server.target_port = "someport"
|
self.handler.server.target_port = "someport"
|
||||||
|
|
||||||
self.assertRaises(auth_plugins.AuthenticationError,
|
self.assertRaises(auth_plugins.AuthenticationError,
|
||||||
self.handler.new_websocket_client)
|
self.handler.validate_connection)
|
||||||
|
|
||||||
self.handler.server.target_host = "someotherhost"
|
self.handler.server.target_host = "someotherhost"
|
||||||
self.handler.new_websocket_client()
|
self.handler.validate_connection()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,15 @@ class BasePlugin(object):
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationError(Exception):
|
class AuthenticationError(Exception):
|
||||||
pass
|
def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None):
|
||||||
|
self.code = response_code
|
||||||
|
self.headers = response_headers
|
||||||
|
self.msg = response_msg
|
||||||
|
|
||||||
|
if log_msg is None:
|
||||||
|
log_msg = response_msg
|
||||||
|
|
||||||
|
super(AuthenticationError, self).__init__('%s %s' % (self.code, log_msg))
|
||||||
|
|
||||||
|
|
||||||
class InvalidOriginError(AuthenticationError):
|
class InvalidOriginError(AuthenticationError):
|
||||||
|
|
@ -16,9 +24,45 @@ class InvalidOriginError(AuthenticationError):
|
||||||
self.actual_origin = actual
|
self.actual_origin = actual
|
||||||
|
|
||||||
super(InvalidOriginError, self).__init__(
|
super(InvalidOriginError, self).__init__(
|
||||||
"Invalid Origin Header: Expected one of "
|
response_msg='Invalid Origin',
|
||||||
|
log_msg="Invalid Origin Header: Expected one of "
|
||||||
"%s, got '%s'" % (expected, actual))
|
"%s, got '%s'" % (expected, actual))
|
||||||
|
|
||||||
|
|
||||||
|
class BasicHTTPAuth(object):
|
||||||
|
def __init__(self, src=None):
|
||||||
|
self.src = src
|
||||||
|
|
||||||
|
def authenticate(self, headers, target_host, target_port):
|
||||||
|
import base64
|
||||||
|
|
||||||
|
auth_header = headers.get('Authorization')
|
||||||
|
if auth_header:
|
||||||
|
if not auth_header.startswith('Basic '):
|
||||||
|
raise AuthenticationError(response_code=403)
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_pass_raw = base64.b64decode(auth_header[6:])
|
||||||
|
except TypeError:
|
||||||
|
raise AuthenticationError(response_code=403)
|
||||||
|
|
||||||
|
user_pass = user_pass_raw.split(':', 1)
|
||||||
|
if len(user_pass) != 2:
|
||||||
|
raise AuthenticationError(response_code=403)
|
||||||
|
|
||||||
|
if not self.validate_creds:
|
||||||
|
raise AuthenticationError(response_code=403)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise AuthenticationError(response_code=401,
|
||||||
|
response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'})
|
||||||
|
|
||||||
|
def validate_creds(username, password):
|
||||||
|
if '%s:%s' % (username, password) == self.src:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
class ExpectOrigin(object):
|
class ExpectOrigin(object):
|
||||||
def __init__(self, src=None):
|
def __init__(self, src=None):
|
||||||
if src is None:
|
if src is None:
|
||||||
|
|
|
||||||
|
|
@ -474,9 +474,13 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
|
||||||
"""Upgrade a connection to Websocket, if requested. If this succeeds,
|
"""Upgrade a connection to Websocket, if requested. If this succeeds,
|
||||||
new_websocket_client() will be called. Otherwise, False is returned.
|
new_websocket_client() will be called. Otherwise, False is returned.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if (self.headers.get('upgrade') and
|
if (self.headers.get('upgrade') and
|
||||||
self.headers.get('upgrade').lower() == 'websocket'):
|
self.headers.get('upgrade').lower() == 'websocket'):
|
||||||
|
|
||||||
|
# ensure connection is authorized, and determine the target
|
||||||
|
self.validate_connection()
|
||||||
|
|
||||||
if not self.do_websocket_handshake():
|
if not self.do_websocket_handshake():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -549,6 +553,10 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
|
||||||
""" Do something with a WebSockets client connection. """
|
""" Do something with a WebSockets client connection. """
|
||||||
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
|
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
|
||||||
|
|
||||||
|
def validate_connection(self):
|
||||||
|
""" Ensure that the connection is a valid connection, and set the target. """
|
||||||
|
pass
|
||||||
|
|
||||||
def do_HEAD(self):
|
def do_HEAD(self):
|
||||||
if self.only_upgrade:
|
if self.only_upgrade:
|
||||||
self.send_error(405, "Method Not Allowed")
|
self.send_error(405, "Method Not Allowed")
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ try: from http.server import HTTPServer
|
||||||
except: from BaseHTTPServer import HTTPServer
|
except: from BaseHTTPServer import HTTPServer
|
||||||
import select
|
import select
|
||||||
from websockify import websocket
|
from websockify import websocket
|
||||||
|
from websockify import auth_plugins as auth
|
||||||
try:
|
try:
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
except:
|
except:
|
||||||
|
|
@ -38,19 +39,33 @@ Traffic Legend:
|
||||||
<. - Client send partial
|
<. - Client send partial
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def new_websocket_client(self):
|
def send_auth_error(self, ex):
|
||||||
"""
|
self.send_response(ex.code, ex.msg)
|
||||||
Called after a new WebSocket connection has been established.
|
self.send_header('Content-Type', 'text/html')
|
||||||
"""
|
for name, val in ex.headers.items():
|
||||||
# Checks if we receive a token, and look
|
self.send_header(name, val)
|
||||||
# for a valid target for it then
|
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
def validate_connection(self):
|
||||||
if self.server.token_plugin:
|
if self.server.token_plugin:
|
||||||
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)
|
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)
|
||||||
|
|
||||||
if self.server.auth_plugin:
|
if self.server.auth_plugin:
|
||||||
|
try:
|
||||||
self.server.auth_plugin.authenticate(
|
self.server.auth_plugin.authenticate(
|
||||||
headers=self.headers, target_host=self.server.target_host,
|
headers=self.headers, target_host=self.server.target_host,
|
||||||
target_port=self.server.target_port)
|
target_port=self.server.target_port)
|
||||||
|
except auth.AuthenticationError:
|
||||||
|
ex = sys.exc_info()[1]
|
||||||
|
self.send_auth_error(ex)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def new_websocket_client(self):
|
||||||
|
"""
|
||||||
|
Called after a new WebSocket connection has been established.
|
||||||
|
"""
|
||||||
|
# Checking for a token is done in validate_connection()
|
||||||
|
|
||||||
# Connect to the target
|
# Connect to the target
|
||||||
if self.server.wrap_cmd:
|
if self.server.wrap_cmd:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue