Further generalize Xi's changes (#88)

* Further generalize Xi's changes

- introduce a slightly more general notion of an AutoRouted provider
- the AutoRouted provider is associated with a RoutingTable provider
- e.g. inference -> models
- Introduced safety -> shields and memory -> memory_banks
  correspondences

* typo

* Basic build and run succeeded
This commit is contained in:
Ashwin Bharambe 2024-09-22 16:31:18 -07:00 committed by GitHub
parent b8914bb56f
commit c1ab66f1e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 597 additions and 418 deletions

View file

@ -16,36 +16,11 @@ def instantiate_class_type(fully_qualified_name):
return getattr(module, class_name)
async def instantiate_router(
provider_spec: RouterProviderSpec,
api: str,
provider_routing_table: Dict[str, Any],
):
module = importlib.import_module(provider_spec.module)
fn = getattr(module, "get_router_impl")
impl = await fn(api, provider_routing_table)
impl.__provider_spec__ = provider_spec
return impl
async def instantiate_builtin_provider(
provider_spec: BuiltinProviderSpec,
run_config: StackRunConfig,
):
print("!!! instantiate_builtin_provider")
module = importlib.import_module(provider_spec.module)
fn = getattr(module, "get_builtin_impl")
impl = await fn(run_config)
impl.__provider_spec__ = provider_spec
return impl
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
deps: Dict[str, Any],
provider_config: ProviderMapEntry,
provider_config: Union[GenericProviderConfig, RoutingTable],
):
module = importlib.import_module(provider_spec.module)
@ -60,6 +35,29 @@ async def instantiate_provider(
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, RoutingTableConfig)
routing_table = provider_config
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table.entries:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_id],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [provider_spec.api, inner_impls, routing_table, deps]
else:
method = "get_provider_impl"