Revert "add new resolve_impls_with_routing"

This reverts commit 34f0c11001.
This commit is contained in:
Xi Yan 2024-09-21 12:42:10 -07:00
parent cf8bd10989
commit af8ecac5f5
3 changed files with 6 additions and 18 deletions

View file

@ -209,8 +209,7 @@ class ProviderRoutingEntry(GenericProviderConfig):
routing_key: str routing_key: str
ProviderMapEntry = Union[GenericProviderConfig, str] ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
ProviderRoutingTableEntry = List[ProviderRoutingEntry]
@json_schema_type @json_schema_type
@ -249,12 +248,6 @@ As examples:
The key may support wild-cards alsothe routing_key to route to the correct provider.""", The key may support wild-cards alsothe routing_key to route to the correct provider.""",
) )
provider_routing_table: Dict[str, ProviderRoutingTableEntry] = Field(
description="""
Map of API to a list of providers backing the API.
Each provider is a (routing_key, provider_config) tuple.
"""
)
@json_schema_type @json_schema_type

View file

@ -290,18 +290,18 @@ def snake_to_camel(snake_str):
async def resolve_impls_with_routing( async def resolve_impls_with_routing(
stack_run_config: StackRunConfig, stack_run_config: StackRunConfig,
) -> Dict[Api, Any]: ) -> Dict[Api, Any]:
raise NotImplementedError("This is not implemented yet") raise NotImplementedError("This is not implemented yet")
async def resolve_impls( async def resolve_impls(
provider_map: Dict[str, ProviderMapEntry], stack_run_config: StackRunConfig,
) -> Dict[Api, Any]: ) -> Dict[Api, Any]:
""" """
Does two things: Does two things:
- flatmaps, sorts and resolves the providers in dependency order - flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation - for each API, produces either a (local, passthrough or router) implementation
""" """
provider_map = stack_run_config.provider_map
all_providers = api_providers() all_providers = api_providers()
specs = {} specs = {}
@ -349,15 +349,9 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp: with open(yaml_config, "r") as fp:
config = StackRunConfig(**yaml.safe_load(fp)) config = StackRunConfig(**yaml.safe_load(fp))
cprint(f"StackRunConfig: {config}", "blue")
app = FastAPI() app = FastAPI()
# check if routing table exists impls, specs = asyncio.run(resolve_impls(config))
if config.provider_routing_table is not None:
impls, specs = asyncio.run(resolve_impls_with_routing(config))
else:
impls, specs = asyncio.run(resolve_impls(config.provider_map))
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])

View file

@ -9,7 +9,7 @@ provider_map:
# use builtin-router as dummy field # use builtin-router as dummy field
memory: builtin-router memory: builtin-router
inference: builtin-router inference: builtin-router
provider_routing_table: routing_table:
inference: inference:
- routing_key: Meta-Llama3.1-8B-Instruct - routing_key: Meta-Llama3.1-8B-Instruct
provider_id: meta-reference provider_id: meta-reference
@ -91,3 +91,4 @@ provider_routing_table:
# api: safety # api: safety
# config: # config:
# model: Prompt-Guard-86M # model: Prompt-Guard-86M