allow providers in api_providers

This commit is contained in:
Xi Yan 2024-09-23 00:11:55 -07:00
parent f467a0482c
commit b224fcf9ab
3 changed files with 24 additions and 1 deletions

View file

@ -59,6 +59,12 @@ class GenericProviderConfig(BaseModel):
config: Dict[str, Any] config: Dict[str, Any]
class PlaceholderProviderConfig(BaseModel):
"""Placeholder provider config for API whose provider are defined in routing_table"""
providers: List[str]
class RoutableProviderConfig(GenericProviderConfig): class RoutableProviderConfig(GenericProviderConfig):
routing_key: str routing_key: str
@ -263,7 +269,9 @@ this could be just a hash
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""", The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
) )
api_providers: Dict[str, GenericProviderConfig] = Field( api_providers: Dict[
str, Union[GenericProviderConfig, PlaceholderProviderConfig]
] = Field(
description=""" description="""
Provider configurations for each of the APIs provided by this package. Provider configurations for each of the APIs provided by this package.
""", """,

View file

@ -307,6 +307,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
# TODO: check that these APIs are not in the routing table part of the config # TODO: check that these APIs are not in the routing table part of the config
providers = all_providers[api] providers = all_providers[api]
# skip checks for API whose provider config is specified in routing_table
if isinstance(config, PlaceholderProviderConfig):
continue
if config.provider_id not in providers: if config.provider_id not in providers:
raise ValueError( raise ValueError(
f"Unknown provider `{config.provider_id}` is not available for API `{api}`" f"Unknown provider `{config.provider_id}` is not available for API `{api}`"

View file

@ -10,6 +10,17 @@ apis_to_serve:
- safety - safety
- models - models
api_providers: api_providers:
inference:
providers:
- meta-reference
- remote::ollama
memory:
providers:
- meta-reference
- remote::pgvector
safety:
providers:
- meta-reference
telemetry: telemetry:
provider_id: meta-reference provider_id: meta-reference
config: {} config: {}