Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
Allow passing provider_registry to resolve_impls()
This commit is contained in:
Parent commit: 8a3b64d1be
This commit: b7d2b83d55
3 changed files with 13 additions and 12 deletions
|
@@ -25,10 +25,7 @@ from llama_stack.apis.scoring import Scoring
|
||||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||||
from llama_stack.apis.shields import Shields
|
from llama_stack.apis.shields import Shields
|
||||||
from llama_stack.apis.telemetry import Telemetry
|
from llama_stack.apis.telemetry import Telemetry
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
builtin_automatically_routed_apis,
|
|
||||||
get_provider_registry,
|
|
||||||
)
|
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
|
||||||
|
@@ -65,14 +62,14 @@ class ProviderWithSpec(Provider):
|
||||||
|
|
||||||
|
|
||||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||||
async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
|
async def resolve_impls(
|
||||||
|
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
|
||||||
|
) -> Dict[Api, Any]:
|
||||||
"""
|
"""
|
||||||
Does two things:
|
Does two things:
|
||||||
- flatmaps, sorts and resolves the providers in dependency order
|
- flatmaps, sorts and resolves the providers in dependency order
|
||||||
- for each API, produces either a (local, passthrough or router) implementation
|
- for each API, produces either a (local, passthrough or router) implementation
|
||||||
"""
|
"""
|
||||||
all_api_providers = get_provider_registry()
|
|
||||||
|
|
||||||
routing_table_apis = set(
|
routing_table_apis = set(
|
||||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||||
)
|
)
|
||||||
|
@@ -89,12 +86,12 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||||
|
|
||||||
specs = {}
|
specs = {}
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
if provider.provider_type not in all_api_providers[api]:
|
if provider.provider_type not in provider_registry[api]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
||||||
)
|
)
|
||||||
|
|
||||||
p = all_api_providers[api][provider.provider_type]
|
p = provider_registry[api][provider.provider_type]
|
||||||
p.deps__ = [a.value for a in p.api_dependencies]
|
p.deps__ = [a.value for a in p.api_dependencies]
|
||||||
spec = ProviderWithSpec(
|
spec = ProviderWithSpec(
|
||||||
spec=p,
|
spec=p,
|
||||||
|
|
|
@@ -26,7 +26,10 @@ from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import (
|
||||||
|
builtin_automatically_routed_apis,
|
||||||
|
get_provider_registry,
|
||||||
|
)
|
||||||
|
|
||||||
from llama_stack.providers.utils.telemetry.tracing import (
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
end_trace,
|
end_trace,
|
||||||
|
@@ -276,7 +279,7 @@ def main(
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
impls = asyncio.run(resolve_impls(config))
|
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
|
||||||
|
|
|
@@ -13,6 +13,7 @@ import yaml
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import resolve_impls
|
from llama_stack.distribution.resolver import resolve_impls
|
||||||
|
|
||||||
|
@@ -36,7 +37,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
|
||||||
providers=chosen,
|
providers=chosen,
|
||||||
)
|
)
|
||||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||||
impls = await resolve_impls(run_config)
|
impls = await resolve_impls(run_config, get_provider_registry())
|
||||||
|
|
||||||
if "provider_data" in config_dict:
|
if "provider_data" in config_dict:
|
||||||
provider_id = chosen[api.value][0].provider_id
|
provider_id = chosen[api.value][0].provider_id
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue