"""
Contains OpenID provider functionality
"""
import logging, os
from galaxy.util import parse_xml, string_as_bool
from galaxy.util.odict import odict
# Logger shared by everything in this module.
log = logging.getLogger(__name__)

# Sentinel provider id meaning "no OpenID provider".
NO_PROVIDER_ID = 'None'
# Ids that OpenIDProvider.from_elem refuses to accept for a configured provider.
RESERVED_PROVIDER_IDS = [
    NO_PROVIDER_ID,
]
class OpenIDProvider(object):
    '''An OpenID Provider object.

    Holds a provider's id, display name and OP endpoint URL, the Simple
    Registration (sreg) fields to request from it, and how returned sreg
    values map onto user attributes (``use_for``) and user preferences
    (``store_user_preference``).
    '''

    @classmethod
    def from_file(cls, filename):
        '''Build a provider from the XML file ``filename`` (see ``from_elem``).'''
        return cls.from_elem(parse_xml(filename).getroot())

    @classmethod
    def from_elem(cls, xml_root):
        '''Build a provider from a ``<provider>`` XML element.

        Recognised layout::

            <provider id="..." name="..." never_associate_with_user="...">
                <op_endpoint_url>...</op_endpoint_url>
                <sreg>
                    <field name="..." required="...">
                        <use_for name="..."/>
                        <store_user_preference name="..."/>
                    </field>
                </sreg>
            </provider>

        ``name`` defaults to ``id``.  When no ``<sreg>`` element is present,
        the sreg field lists and ``use_for`` mapping are left to the
        defaults chosen by ``__init__``.

        :raises ValueError: if the id, name or endpoint URL is missing, if the
            id is reserved, or if an sreg ``<field>`` has no name.
        '''
        provider_elem = xml_root
        provider_id = provider_elem.get('id', None)
        provider_name = provider_elem.get('name', provider_id)
        op_endpoint_url = provider_elem.find('op_endpoint_url')
        if op_endpoint_url is not None:
            op_endpoint_url = op_endpoint_url.text
        # Explicit raises rather than ``assert``: assertions are stripped when
        # Python runs with -O, which would let a broken configuration through.
        if not (provider_id and provider_name and op_endpoint_url):
            raise ValueError("OpenID Provider improperly configured")
        if provider_id in RESERVED_PROVIDER_IDS:
            raise ValueError('Specified OpenID Provider uses a reserved id: %s' % (provider_id))
        never_associate_with_user = string_as_bool(provider_elem.get('never_associate_with_user', 'False'))

        sreg_required = []
        sreg_optional = []
        use_for = {}
        store_user_preference = {}
        sreg_elems = provider_elem.findall('sreg')
        for sreg_elem in sreg_elems:
            for field_elem in sreg_elem.findall('field'):
                sreg_name = field_elem.get('name')
                if not sreg_name:
                    raise ValueError('A name is required for a sreg element')
                if string_as_bool(field_elem.get('required')):
                    sreg_required.append(sreg_name)
                else:
                    sreg_optional.append(sreg_name)
                for use_elem in field_elem.findall('use_for'):
                    use_for[use_elem.get('name')] = sreg_name
                for pref_elem in field_elem.findall('store_user_preference'):
                    store_user_preference[pref_elem.get('name')] = sreg_name
        if not sreg_elems:
            # No <sreg> section: pass None so __init__ applies its defaults.
            sreg_required = None
            sreg_optional = None
            use_for = None
        return cls(provider_id, provider_name, op_endpoint_url,
                   sreg_required=sreg_required,
                   sreg_optional=sreg_optional,
                   use_for=use_for,
                   store_user_preference=store_user_preference,
                   never_associate_with_user=never_associate_with_user)

    def __init__(self, id, name, op_endpoint_url, sreg_required=None, sreg_optional=None,
                 use_for=None, store_user_preference=None, never_associate_with_user=None):
        '''When sreg options are not specified, defaults are used.

        :param id: provider id (``None`` for ad-hoc providers).
        :param name: display name.
        :param op_endpoint_url: OpenID provider endpoint URL.
        :param sreg_required: sreg fields to require; defaults to ``[]``.
        :param sreg_optional: sreg fields to request optionally; defaults to
            ``['nickname', 'email']`` when ``None``.
        :param use_for: mapping of user attribute -> sreg field.  When
            ``None`` it is derived: ``username`` <- ``nickname`` and
            ``email`` <- ``email``, for whichever of them are requested.
        :param store_user_preference: mapping of preference name -> sreg
            field; defaults to ``{}``.
        :param never_associate_with_user: truthy to set the flag; stored as a bool.
        '''
        self.id = id
        self.name = name
        self.op_endpoint_url = op_endpoint_url
        if sreg_optional is None:
            self.sreg_optional = ['nickname', 'email']
        else:
            self.sreg_optional = sreg_optional
        self.sreg_required = sreg_required or []
        if use_for is not None:
            self.use_for = use_for
        else:
            requested_sreg = self.sreg_optional + self.sreg_required
            self.use_for = {}
            if 'nickname' in requested_sreg:
                self.use_for['username'] = 'nickname'
            if 'email' in requested_sreg:
                self.use_for['email'] = 'email'
        self.store_user_preference = store_user_preference or {}
        self.never_associate_with_user = bool(never_associate_with_user)

    def post_authentication(self, trans, openid_manager, info):
        '''Copy configured sreg values into ``trans.user.preferences``, then
        add the user to ``trans.sa_session`` and flush it.

        ``openid_manager.get_sreg(info)`` is assumed to return a mapping of
        sreg field name -> value (only its ``.get`` is used here).

        :raises Exception: if a preference is mapped to a field that was not
            requested via sreg (only sreg sources are supported).
        '''
        sreg_attributes = openid_manager.get_sreg(info)
        # Loop-invariant: the set of fields actually requested from the provider.
        requested_sreg = self.sreg_optional + self.sreg_required
        for store_pref_name, store_pref_value_name in self.store_user_preference.items():
            if store_pref_value_name not in requested_sreg:
                raise Exception('Only sreg is currently supported.')
            trans.user.preferences[store_pref_name] = sreg_attributes.get(store_pref_value_name)
        trans.sa_session.add(trans.user)
        trans.sa_session.flush()

    def has_post_authentication_actions(self):
        '''Return True when ``post_authentication`` has preferences to store.'''
        return bool(self.store_user_preference)
class OpenIDProviders(object):
    '''Collection of OpenID Providers, keyed by provider id.'''
    NO_PROVIDER_ID = NO_PROVIDER_ID

    @classmethod
    def from_file(cls, filename):
        '''Load the collection described by the XML file ``filename``.

        Best effort: on any failure the error is logged with its traceback and
        an empty collection is returned instead of raising.
        '''
        try:
            return cls.from_elem(parse_xml(filename).getroot())
        except Exception as e:
            log.exception('Failed to load OpenID Providers: %s' % (e))
            return cls()

    @classmethod
    def from_elem(cls, xml_root, base_dir='openid'):
        '''Build the collection from a root element with ``<provider file="..."/>`` children.

        Each ``file`` is joined onto ``base_dir`` (default ``'openid'``, which is
        relative to the current working directory) and loaded with
        ``OpenIDProvider.from_file``.  Providers that fail to load are logged
        and skipped; a later provider with an already-seen id replaces the
        earlier one (a warning is logged).
        '''
        providers = odict()
        for elem in xml_root.findall('provider'):
            provider_file = elem.get('file')
            if not provider_file:
                log.error('Failed to add OpenID provider: <provider> element has no "file" attribute')
                continue
            try:
                provider = OpenIDProvider.from_file(os.path.join(base_dir, provider_file))
            except Exception as e:
                log.exception('Failed to add OpenID provider: %s' % (e))
                continue
            if provider.id in providers:
                log.warning('Duplicate OpenID provider id "%s"; replacing the previously loaded provider' % (provider.id))
            providers[provider.id] = provider
            log.debug('Loaded OpenID provider: %s (%s)' % (provider.name, provider.id))
        return cls(providers)

    def __init__(self, providers=None):
        '''
        :param providers: mapping of provider id -> OpenIDProvider; an empty
            ``odict`` is used when it is ``None`` or empty.
        '''
        if providers:
            self.providers = providers
        else:
            self.providers = odict()
        # Endpoint URLs of providers flagged never_associate_with_user; used by
        # new_provider_from_identifier to flag matching ad-hoc providers.
        self._banned_identifiers = [provider.op_endpoint_url
                                    for provider in self.providers.values()
                                    if provider.never_associate_with_user]

    def __iter__(self):
        '''Iterate over the providers in mapping order.'''
        for provider in self.providers.values():
            yield provider

    def get(self, name, default=None):
        '''Return the provider whose id is ``name``, or ``default`` if none.'''
        if name in self.providers:
            return self.providers[name]
        return default

    def new_provider_from_identifier(self, identifier):
        '''Create an ad-hoc provider (id ``None``) for a user-supplied identifier.

        ``identifier`` is used as both name and endpoint URL.  The provider is
        flagged ``never_associate_with_user`` when the identifier equals the
        endpoint URL of a configured provider carrying that flag.
        '''
        return OpenIDProvider(None, identifier, identifier,
                              never_associate_with_user=identifier in self._banned_identifiers)