160 lines
4.9 KiB
Cython
160 lines
4.9 KiB
Cython
from collections import OrderedDict
|
|
|
|
cdef class Aggregate:
    """Combine several ``Module`` objects behind a single attribute namespace.

    Public (non-underscore) attributes of each linked module are forwarded by
    ``__getattr__`` / ``__setattr__`` to the module that defines them.  Names
    listed in a module's ``_common`` set may be declared by several modules at
    once; they resolve to a ``CommonMethod`` wrapper into which every declaring
    module is linked (how that wrapper dispatches is defined by
    ``CommonMethod``, not here).
    """

    def __init__(Aggregate self, *modules):
        """Create an aggregate and link *modules* in the given order.

        Each module goes through ``_link_module``, so every keyword in a
        module's ``_requires`` must already be provided by an earlier module.
        """
        self._attrs = {}  # attr_name -> module that defines it
        self._modules = OrderedDict()  # module_name -> module, in link order
        self._provided = {}  # provided keyword -> module providing it
        self._common = {}  # method_name -> CommonMethod wrapper

        for module in modules:
            self._link_module(module)

    def _list_modules(Aggregate self):
        """Return ``{module: set_of_attr_names}`` for the forwarded attributes.

        Only non-common attributes are listed; a module that contributes
        nothing but common methods does not appear as a key.
        """
        d = {}
        for attr, module in self._attrs.items():
            d.setdefault(module, set()).add(attr)
        return d

    def _link_module(Aggregate self, Module module):
        """Validate *module* against the current aggregate and link it in.

        Every check runs before any internal table is modified, so a module
        that is rejected leaves the aggregate exactly as it was.

        Raises:
            ModuleCollision: the module's name, one of its common methods, one
                of its attributes, or one of its provided keywords clashes
                with an already linked module.
            ModuleDependencyError: a keyword in ``module._requires`` is not
                provided by any linked module.
            CommonMethodMissing: a name in ``module._common`` is not a public
                attribute of the module.
        """
        # check for name (=type) collision
        if module._name in self._modules:
            raise ModuleCollision(module._name, None, 'name', None)

        # check if requirements are satisfied
        unsatisfied_deps = module._requires - set(self._provided)
        if unsatisfied_deps:
            raise ModuleDependencyError(module._name, unsatisfied_deps)

        # check for new module declaring a common method that we already
        # provide as a non-common attribute
        new_commons = module._common - set(self._common)
        common_collisions = {nc for nc in new_commons if nc in self._attrs}
        if common_collisions:
            colliding_module_names = {self._attrs[x]._name for x in common_collisions}
            raise ModuleCollision(module._name, colliding_module_names, 'non-common method', common_collisions)

        # check for an attr collision
        module_attrs = {x for x in dir(module) if x[0] != '_'}
        own_attrs = module_attrs - module._common
        attr_collisions = own_attrs & (set(self._attrs) | set(self._common))
        if attr_collisions:
            colliding_module_names = set()
            for collision in attr_collisions:
                if collision in self._attrs:
                    colliding_module_names.add(self._attrs[collision]._name)
                if collision in self._common:
                    # The CommonMethod wrapper was built from the method name,
                    # so report the linked modules that declare it instead.
                    colliding_module_names.update(
                        m._name for m in self._modules.values()
                        if collision in m._common)
            raise ModuleCollision(module._name, colliding_module_names, 'attribute', attr_collisions)

        # check for a provided keyword collision
        provided_collisions = module._provides & set(self._provided)
        if provided_collisions:
            colliding_module_names = {self._provided[x]._name for x in provided_collisions}
            raise ModuleCollision(module._name, colliding_module_names, 'provided keyword', provided_collisions)

        # every declared common method must actually exist on the module;
        # checked here, before linking, so a failure cannot half-link it
        for method_name in module._common:
            if method_name not in module_attrs:
                raise CommonMethodMissing(method_name, module._name)

        # link the module
        self._modules[module._name] = module

        for keyword in module._provides:
            self._provided[keyword] = module

        for attr in own_attrs:
            self._attrs[attr] = module

        # create and/or populate CommonMethod wrappers to common methods
        for method_name in module._common:
            if method_name not in self._common:
                self._common[method_name] = CommonMethod(method_name)
            self._common[method_name].link_module(module)

        # hand the module a reference to us
        module._top = self

        # call the module's _on_link method, if it has one
        if hasattr(module, '_on_link'):
            module._on_link()

    def _unlink_module(Aggregate self, str module_name):
        """Remove the module named *module_name* from the aggregate.

        Raises:
            ModuleDoesntExist: no module with that name is linked.
            ModuleDependencyError: another linked module still requires a
                keyword this module provides (raised with ``unlink=True``).
        """
        if module_name not in self._modules:
            raise ModuleDoesntExist(module_name)

        module = self._modules[module_name]

        # check reverse dependencies: keywords still required by the *other*
        # linked modules
        global_deps = set()
        for other in self._modules.values():
            if other is not module:
                global_deps.update(other._requires)

        reverse_deps = module._provides & global_deps
        if reverse_deps:
            raise ModuleDependencyError(module_name, reverse_deps, unlink=True)

        # remove from all pools
        for aname in [a for a, mod in self._attrs.items() if mod._name == module_name]:
            del self._attrs[aname]

        del self._modules[module_name]

        for ename in module._provides:
            del self._provided[ename]

        # remove _common wrappers; drop a wrapper entirely once no remaining
        # module declares that method, so the name stops resolving and can be
        # reused as a plain attribute by a module linked later
        for method_name in module._common:
            self._common[method_name].unlink_module(module_name)
            if not any(method_name in m._common for m in self._modules.values()):
                del self._common[method_name]

        # clear _top reference
        module._top = None

    def _merge_in(Aggregate self, Aggregate other_ag):
        """Link every module of *other_ag* whose name is not already linked.

        Modules are linked in *other_ag*'s link order; modules whose name is
        already present here are skipped silently.
        """
        for module_name, module in other_ag._modules.items():
            if module_name not in self._modules:
                self._link_module(module)

    def __getattr__(Aggregate self, str aname):
        """Forward attribute reads to the owning module or common wrapper."""
        if aname in self._attrs:
            return getattr(self._attrs[aname], aname)
        elif aname in self._common:
            return self._common[aname]
        else:
            raise AttributeError("Aggregate has no attribute '%s'" %(aname))

    def __setattr__(Aggregate self, str aname, avalue):
        """Forward attribute writes to the module that owns *aname*.

        Only non-common attributes can be assigned; anything else raises
        AttributeError.
        """
        if aname not in self._attrs:
            raise AttributeError("Aggregate has no attribute '%s'" %(aname))
        else:
            setattr(self._attrs[aname], aname, avalue)

    def _get_type(Aggregate self):
        """Return the linked module names, in link order, as a tuple."""
        return tuple(self._modules.keys())

    def __repr__(Aggregate self):
        """Return ``Aggregate()`` or one ``repr(module)`` per line, comma-separated."""
        if not self._modules:
            return 'Aggregate()'

        body = ',\n'.join([' %s' % (repr(module),) for module in self._modules.values()])
        return 'Aggregate(\n%s\n)' % (body,)