donot use global state

This commit is contained in:
Dinesh Yeduguru 2024-11-01 14:19:54 -07:00 committed by Dinesh Yeduguru
parent 4b6367838f
commit 26a14c1d92
5 changed files with 19 additions and 11 deletions

View file

@ -11,6 +11,7 @@ from typing import Any, Dict, List, Set
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@ -65,7 +66,9 @@ class ProviderWithSpec(Provider):
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
run_config: StackRunConfig,
provider_registry: Dict[Api, Dict[str, ProviderSpec]],
dist_registry: distribution_store.Registry,
) -> Dict[Api, Any]:
"""
Does two things:
@ -189,6 +192,7 @@ async def resolve_impls(
provider,
deps,
inner_impls,
dist_registry,
)
# TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str:
@ -237,6 +241,7 @@ async def instantiate_provider(
provider: ProviderWithSpec,
deps: Dict[str, Any],
inner_impls: Dict[str, Any],
dist_registry: distribution_store.Registry,
):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
@ -270,7 +275,7 @@ async def instantiate_provider(
method = "get_routing_table_impl"
config = None
args = [provider_spec.api, inner_impls, deps]
args = [provider_spec.api, inner_impls, deps, dist_registry]
else:
method = "get_provider_impl"