forked from phoenix-oss/llama-stack-mirror
Allow passing provider_registry to resolve_impls()
This commit is contained in:
parent
8a3b64d1be
commit
b7d2b83d55
3 changed files with 13 additions and 12 deletions
|
@ -25,10 +25,7 @@ from llama_stack.apis.scoring import Scoring
|
|||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
|
||||
|
@ -65,14 +62,14 @@ class ProviderWithSpec(Provider):
|
|||
|
||||
|
||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||
async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||
async def resolve_impls(
|
||||
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
|
||||
) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
"""
|
||||
all_api_providers = get_provider_registry()
|
||||
|
||||
routing_table_apis = set(
|
||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||
)
|
||||
|
@ -89,12 +86,12 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
|
|||
|
||||
specs = {}
|
||||
for provider in providers:
|
||||
if provider.provider_type not in all_api_providers[api]:
|
||||
if provider.provider_type not in provider_registry[api]:
|
||||
raise ValueError(
|
||||
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
|
||||
p = all_api_providers[api][provider.provider_type]
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
p.deps__ = [a.value for a in p.api_dependencies]
|
||||
spec = ProviderWithSpec(
|
||||
spec=p,
|
||||
|
|
|
@ -26,7 +26,10 @@ from pydantic import BaseModel, ValidationError
|
|||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
|
@ -276,7 +279,7 @@ def main(
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
impls = asyncio.run(resolve_impls(config))
|
||||
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import yaml
|
|||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
|
||||
|
@ -36,7 +37,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
|
|||
providers=chosen,
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||
impls = await resolve_impls(run_config)
|
||||
impls = await resolve_impls(run_config, get_provider_registry())
|
||||
|
||||
if "provider_data" in config_dict:
|
||||
provider_id = chosen[api.value][0].provider_id
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue