Rework Auth Plugins to Support HTTP Auth
This commit reworks auth plugins slightly to enable support for HTTP authentication. By raising an AuthenticationError, auth plugins can now return HTTP responses to the upgrade request (such as 401). Related to kanaka/noVNC#522
This commit is contained in:
parent
6c1543c05b
commit
1e2b5c2256
|
|
@ -106,11 +106,11 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
|
||||||
def lookup(self, token):
|
def lookup(self, token):
|
||||||
return (self.source + token).split(',')
|
return (self.source + token).split(',')
|
||||||
|
|
||||||
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
|
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error',
|
||||||
lambda *args, **kwargs: None)
|
staticmethod(lambda *args, **kwargs: None))
|
||||||
|
|
||||||
self.handler.server.token_plugin = TestPlugin("somehost,")
|
self.handler.server.token_plugin = TestPlugin("somehost,")
|
||||||
self.handler.new_websocket_client()
|
self.handler.validate_connection()
|
||||||
|
|
||||||
self.assertEqual(self.handler.server.target_host, "somehost")
|
self.assertEqual(self.handler.server.target_host, "somehost")
|
||||||
self.assertEqual(self.handler.server.target_port, "blah")
|
self.assertEqual(self.handler.server.target_port, "blah")
|
||||||
|
|
@ -119,9 +119,9 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
|
||||||
class TestPlugin(auth_plugins.BasePlugin):
|
class TestPlugin(auth_plugins.BasePlugin):
|
||||||
def authenticate(self, headers, target_host, target_port):
|
def authenticate(self, headers, target_host, target_port):
|
||||||
if target_host == self.source:
|
if target_host == self.source:
|
||||||
raise auth_plugins.AuthenticationError("some error")
|
raise auth_plugins.AuthenticationError(response_msg="some_error")
|
||||||
|
|
||||||
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
|
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error',
|
||||||
staticmethod(lambda *args, **kwargs: None))
|
staticmethod(lambda *args, **kwargs: None))
|
||||||
|
|
||||||
self.handler.server.auth_plugin = TestPlugin("somehost")
|
self.handler.server.auth_plugin = TestPlugin("somehost")
|
||||||
|
|
@ -129,8 +129,8 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
|
||||||
self.handler.server.target_port = "someport"
|
self.handler.server.target_port = "someport"
|
||||||
|
|
||||||
self.assertRaises(auth_plugins.AuthenticationError,
|
self.assertRaises(auth_plugins.AuthenticationError,
|
||||||
self.handler.new_websocket_client)
|
self.handler.validate_connection)
|
||||||
|
|
||||||
self.handler.server.target_host = "someotherhost"
|
self.handler.server.target_host = "someotherhost"
|
||||||
self.handler.new_websocket_client()
|
self.handler.validate_connection()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,15 @@ class BasePlugin(object):
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationError(Exception):
|
class AuthenticationError(Exception):
|
||||||
pass
|
def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None):
|
||||||
|
self.code = response_code
|
||||||
|
self.headers = response_headers
|
||||||
|
self.msg = response_msg
|
||||||
|
|
||||||
|
if log_msg is None:
|
||||||
|
log_msg = response_msg
|
||||||
|
|
||||||
|
super(AuthenticationError, self).__init__('%s %s' % (self.code, log_msg))
|
||||||
|
|
||||||
|
|
||||||
class InvalidOriginError(AuthenticationError):
|
class InvalidOriginError(AuthenticationError):
|
||||||
|
|
@ -16,9 +24,45 @@ class InvalidOriginError(AuthenticationError):
|
||||||
self.actual_origin = actual
|
self.actual_origin = actual
|
||||||
|
|
||||||
super(InvalidOriginError, self).__init__(
|
super(InvalidOriginError, self).__init__(
|
||||||
"Invalid Origin Header: Expected one of "
|
response_msg='Invalid Origin',
|
||||||
|
log_msg="Invalid Origin Header: Expected one of "
|
||||||
"%s, got '%s'" % (expected, actual))
|
"%s, got '%s'" % (expected, actual))
|
||||||
|
|
||||||
|
|
||||||
|
class BasicHTTPAuth(object):
|
||||||
|
def __init__(self, src=None):
|
||||||
|
self.src = src
|
||||||
|
|
||||||
|
def authenticate(self, headers, target_host, target_port):
|
||||||
|
import base64
|
||||||
|
|
||||||
|
auth_header = headers.get('Authorization')
|
||||||
|
if auth_header:
|
||||||
|
if not auth_header.startswith('Basic '):
|
||||||
|
raise AuthenticationError(response_code=403)
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_pass_raw = base64.b64decode(auth_header[6:])
|
||||||
|
except TypeError:
|
||||||
|
raise AuthenticationError(response_code=403)
|
||||||
|
|
||||||
|
user_pass = user_pass_raw.split(':', 1)
|
||||||
|
if len(user_pass) != 2:
|
||||||
|
raise AuthenticationError(response_code=403)
|
||||||
|
|
||||||
|
if not self.validate_creds:
|
||||||
|
raise AuthenticationError(response_code=403)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise AuthenticationError(response_code=401,
|
||||||
|
response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'})
|
||||||
|
|
||||||
|
def validate_creds(username, password):
|
||||||
|
if '%s:%s' % (username, password) == self.src:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
class ExpectOrigin(object):
|
class ExpectOrigin(object):
|
||||||
def __init__(self, src=None):
|
def __init__(self, src=None):
|
||||||
if src is None:
|
if src is None:
|
||||||
|
|
|
||||||
|
|
@ -474,9 +474,13 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
|
||||||
"""Upgrade a connection to Websocket, if requested. If this succeeds,
|
"""Upgrade a connection to Websocket, if requested. If this succeeds,
|
||||||
new_websocket_client() will be called. Otherwise, False is returned.
|
new_websocket_client() will be called. Otherwise, False is returned.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if (self.headers.get('upgrade') and
|
if (self.headers.get('upgrade') and
|
||||||
self.headers.get('upgrade').lower() == 'websocket'):
|
self.headers.get('upgrade').lower() == 'websocket'):
|
||||||
|
|
||||||
|
# ensure connection is authorized, and determine the target
|
||||||
|
self.validate_connection()
|
||||||
|
|
||||||
if not self.do_websocket_handshake():
|
if not self.do_websocket_handshake():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -549,6 +553,10 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
|
||||||
""" Do something with a WebSockets client connection. """
|
""" Do something with a WebSockets client connection. """
|
||||||
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
|
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
|
||||||
|
|
||||||
|
def validate_connection(self):
|
||||||
|
""" Ensure that the connection is a valid connection, and set the target. """
|
||||||
|
pass
|
||||||
|
|
||||||
def do_HEAD(self):
|
def do_HEAD(self):
|
||||||
if self.only_upgrade:
|
if self.only_upgrade:
|
||||||
self.send_error(405, "Method Not Allowed")
|
self.send_error(405, "Method Not Allowed")
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ try: from http.server import HTTPServer
|
||||||
except: from BaseHTTPServer import HTTPServer
|
except: from BaseHTTPServer import HTTPServer
|
||||||
import select
|
import select
|
||||||
from websockify import websocket
|
from websockify import websocket
|
||||||
|
from websockify import auth_plugins as auth
|
||||||
try:
|
try:
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
except:
|
except:
|
||||||
|
|
@ -38,19 +39,33 @@ Traffic Legend:
|
||||||
<. - Client send partial
|
<. - Client send partial
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def new_websocket_client(self):
|
def send_auth_error(self, ex):
|
||||||
"""
|
self.send_response(ex.code, ex.msg)
|
||||||
Called after a new WebSocket connection has been established.
|
self.send_header('Content-Type', 'text/html')
|
||||||
"""
|
for name, val in ex.headers.items():
|
||||||
# Checks if we receive a token, and look
|
self.send_header(name, val)
|
||||||
# for a valid target for it then
|
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
def validate_connection(self):
|
||||||
if self.server.token_plugin:
|
if self.server.token_plugin:
|
||||||
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)
|
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)
|
||||||
|
|
||||||
if self.server.auth_plugin:
|
if self.server.auth_plugin:
|
||||||
|
try:
|
||||||
self.server.auth_plugin.authenticate(
|
self.server.auth_plugin.authenticate(
|
||||||
headers=self.headers, target_host=self.server.target_host,
|
headers=self.headers, target_host=self.server.target_host,
|
||||||
target_port=self.server.target_port)
|
target_port=self.server.target_port)
|
||||||
|
except auth.AuthenticationError:
|
||||||
|
ex = sys.exc_info()[1]
|
||||||
|
self.send_auth_error(ex)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def new_websocket_client(self):
|
||||||
|
"""
|
||||||
|
Called after a new WebSocket connection has been established.
|
||||||
|
"""
|
||||||
|
# Checking for a token is done in validate_connection()
|
||||||
|
|
||||||
# Connect to the target
|
# Connect to the target
|
||||||
if self.server.wrap_cmd:
|
if self.server.wrap_cmd:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue