add dynamic clients for all APIs (#348)

* add dynamic clients for all APIs

* fix openapi generator

* inference + memory + agents tests now pass with "remote" providers

* Add docstring which fixes openapi generator :/
This commit is contained in:
Ashwin Bharambe 2024-10-31 14:46:25 -07:00 committed by GitHub
parent f04b566c5c
commit 37b330b4ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 350 additions and 84 deletions

View file

@ -40,19 +40,21 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.safety: Safety,
Api.shields: Shields,
Api.telemetry: Telemetry,
Api.datasets: Datasets,
Api.datasetio: DatasetIO,
Api.scoring_functions: ScoringFunctions,
Api.datasets: Datasets,
Api.scoring: Scoring,
Api.scoring_functions: ScoringFunctions,
Api.eval: Eval,
}
def additional_protocols_map() -> Dict[Api, Any]:
return {
Api.inference: ModelsProtocolPrivate,
Api.memory: MemoryBanksProtocolPrivate,
Api.safety: ShieldsProtocolPrivate,
Api.inference: (ModelsProtocolPrivate, Models),
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
Api.safety: (ShieldsProtocolPrivate, Shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets),
Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
}
@ -112,8 +114,6 @@ async def resolve_impls(
if info.router_api.value not in apis_to_serve:
continue
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__routing_table__",
@ -246,14 +246,21 @@ async def instantiate_provider(
args = []
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
args = [config, deps]
if provider_spec.adapter:
method = "get_adapter_impl"
args = [config, deps]
else:
method = "get_client_impl"
protocol = protocols[provider_spec.api]
if provider_spec.api in additional_protocols:
_, additional_protocol = additional_protocols[provider_spec.api]
else:
additional_protocol = None
args = [protocol, additional_protocol, config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
@ -282,7 +289,7 @@ async def instantiate_provider(
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api = additional_protocols[provider_spec.api]
additional_api, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
return impl