from collections import OrderedDict


cdef class Aggregate:
    """Composite object that exposes the public attributes of its linked modules.

    Each linked module contributes:
      * its public (non-underscore) attributes, except those it lists in
        ``_common``; these are forwarded to the owning module;
      * the keywords in its ``_provides`` set, which other modules may
        ``_require``;
      * its ``_common`` methods, which are gathered into shared
        ``CommonMethod`` wrappers keyed by method name.
    """

    def __init__(Aggregate self, *modules):
        """Create an aggregate and link ``modules`` in the order given."""
        self._attrs = {}               # attr_name -> owning module (non-common attrs only)
        self._modules = OrderedDict()  # module_name -> module, in link order
        self._provided = {}            # provided keyword -> module
        self._common = {}              # common method name -> CommonMethod wrapper
        for module in modules:
            self._link_module(module)

    def _list_modules(Aggregate self):
        """Return a dict mapping each module to the set of attribute names it owns.

        Built from ``self._attrs``, so only non-common attributes are listed and
        a module that contributes nothing but common methods does not appear.
        """
        d = {}
        for attr, module in self._attrs.items():
            if module not in d:
                d[module] = set()
            d[module].add(attr)
        return d

    def _link_module(Aggregate self, Module module):
        """Validate ``module`` against the current state and link it in.

        Every validation check runs before the aggregate is modified, so a
        module that fails validation leaves the aggregate unchanged.  (An
        exception raised from the module's own ``_on_link`` hook happens after
        linking and is not rolled back.)

        Raises:
            ModuleCollision: a module with the same name is already linked, or
                the module's common methods / attributes / provided keywords
                clash with ones already present.
            ModuleDependencyError: a keyword in ``module._requires`` is not
                provided by any linked module.
            CommonMethodMissing: the module declares a common method that is
                not among its public attributes.
        """
        # check for name (=type) collision
        if module._name in self._modules:
            raise ModuleCollision(module._name, None, 'name', None)

        # check if requirements are satisfied
        unsatisfied_deps = module._requires - set(self._provided.keys())
        if unsatisfied_deps:
            raise ModuleDependencyError(module._name, unsatisfied_deps)

        # check for new module declaring a common method that we already provide as non-common
        # (new_commons already excludes names present in self._common)
        new_commons = module._common - set(self._common.keys())
        common_collisions = {nc for nc in new_commons if nc in self._attrs}
        if common_collisions:
            colliding_module_names = {self._attrs[x]._name for x in common_collisions}
            raise ModuleCollision(module._name, colliding_module_names,
                                  'non-common method', common_collisions)

        # check for an attr collision
        module_attrs = {x for x in dir(module) if x[0] != '_'}
        own_attrs = module_attrs - module._common  # attributes this module will own outright
        attr_collisions = own_attrs & (set(self._attrs.keys()) | set(self._common.keys()))
        if attr_collisions:
            colliding_module_names = set()
            for collision in attr_collisions:
                if collision in self._attrs:
                    colliding_module_names.add(self._attrs[collision]._name)
                if collision in self._common:
                    # NOTE(review): this reads the CommonMethod wrapper's _name, which is
                    # presumably the method name rather than a module name -- confirm.
                    colliding_module_names.add(self._common[collision]._name)
            raise ModuleCollision(module._name, colliding_module_names,
                                  'attribute', attr_collisions)

        # check for a provided keyword collision
        provided_collisions = module._provides & set(self._provided.keys())
        if provided_collisions:
            colliding_module_names = {self._provided[x]._name for x in provided_collisions}
            raise ModuleCollision(module._name, colliding_module_names,
                                  'provided keyword', provided_collisions)

        # every declared common method must exist on the module; checked before
        # any mutation so a failure cannot leave a half-linked module behind
        for method_name in module._common:
            if method_name not in module_attrs:
                raise CommonMethodMissing(method_name, module._name)

        # link the module
        self._modules[module._name] = module
        for keyword in module._provides:
            self._provided[keyword] = module
        for attr in own_attrs:
            self._attrs[attr] = module

        # create and/or populate CommonMethod wrappers to common methods
        for method_name in module._common:
            if method_name not in self._common:
                self._common[method_name] = CommonMethod(method_name)
            self._common[method_name].link_module(module)

        # hand the module a reference to us
        module._top = self

        # call the module's _on_link method, if it has one
        if hasattr(module, '_on_link'):
            module._on_link()

    def _unlink_module(Aggregate self, str module_name):
        """Remove the module named ``module_name`` from the aggregate.

        Raises:
            ModuleDoesntExist: no module with that name is linked.
            ModuleDependencyError: (with ``unlink=True``) some linked module
                requires a keyword this module provides.
        """
        if module_name not in self._modules:
            raise ModuleDoesntExist(module_name)
        module = self._modules[module_name]

        # check reverse dependencies
        global_deps = set()
        for m in self._modules.values():
            global_deps.update(m._requires)
        reverse_deps = module._provides & global_deps
        if reverse_deps:
            raise ModuleDependencyError(module_name, reverse_deps, unlink=True)

        # remove from all pools (iterate over a copy since we delete while looping)
        for aname, mod in list(self._attrs.items()):
            if mod._name == module_name:
                del self._attrs[aname]
        del self._modules[module_name]
        for ename in module._provides:
            del self._provided[ename]

        # remove _common wrappers
        # NOTE(review): the wrapper itself stays in self._common even when no
        # module remains linked to it -- confirm CommonMethod copes with that.
        for method_name in module._common:
            self._common[method_name].unlink_module(module_name)

        # clear _top reference
        module._top = None

    def _merge_in(Aggregate self, Aggregate other_ag):
        """Link every module of ``other_ag`` whose name is not already linked here.

        Modules are linked in ``other_ag``'s link order.  Each one's ``_top``
        is repointed at this aggregate.
        """
        for module_name, module in other_ag._modules.items():
            if module_name not in self._modules:
                self._link_module(module)

    def __getattr__(Aggregate self, str aname):
        """Forward lookup of ``aname`` to its owning module or common wrapper.

        Non-common attributes are read live from the owning module.  Common
        method names return the ``CommonMethod`` wrapper itself.
        """
        if aname in self._attrs:
            return getattr(self._attrs[aname], aname)
        elif aname in self._common:
            return self._common[aname]
        else:
            raise AttributeError("Aggregate has no attribute '%s'" %(aname))

    def __setattr__(Aggregate self, str aname, avalue):
        """Set non-common attribute ``aname`` on the module that owns it.

        Raises AttributeError for any name not owned by a linked module,
        including common method names.
        """
        if aname not in self._attrs:
            raise AttributeError("Aggregate has no attribute '%s'" %(aname))
        else:
            setattr(self._attrs[aname], aname, avalue)

    def _get_type(Aggregate self):
        """Return the linked module names as a tuple, in link order."""
        return tuple(self._modules.keys())

    def __repr__(Aggregate self):
        """Multi-line repr listing each linked module's repr in link order."""
        if not self._modules:
            return 'Aggregate()'
        lines = ['Aggregate(']
        # every module line except the last gets a trailing comma
        lines.append(',\n'.join([' %s' % repr(module) for module in self._modules.values()]))
        lines.append(')')
        return '\n'.join(lines)