donot use global state

This commit is contained in:
Dinesh Yeduguru 2024-11-01 14:19:54 -07:00 committed by Dinesh Yeduguru
parent 4b6367838f
commit 26a14c1d92
5 changed files with 19 additions and 11 deletions

View file

@ -11,6 +11,7 @@ from typing import Any, Dict, List, Set
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@ -65,7 +66,9 @@ class ProviderWithSpec(Provider):
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
run_config: StackRunConfig,
provider_registry: Dict[Api, Dict[str, ProviderSpec]],
dist_registry: distribution_store.Registry,
) -> Dict[Api, Any]:
"""
Does two things:
@ -189,6 +192,7 @@ async def resolve_impls(
provider,
deps,
inner_impls,
dist_registry,
)
# TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str:
@ -237,6 +241,7 @@ async def instantiate_provider(
provider: ProviderWithSpec,
deps: Dict[str, Any],
inner_impls: Dict[str, Any],
dist_registry: distribution_store.Registry,
):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
@ -270,7 +275,7 @@ async def instantiate_provider(
method = "get_routing_table_impl"
config = None
args = [provider_spec.api, inner_impls, deps]
args = [provider_spec.api, inner_impls, deps, dist_registry]
else:
method = "get_provider_impl"

View file

@ -7,6 +7,8 @@
from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from .routing_tables import (
DatasetsRoutingTable,
MemoryBanksRoutingTable,
@ -20,6 +22,7 @@ async def get_routing_table_impl(
api: Api,
impls_by_provider_id: Dict[str, RoutedProtocol],
_deps,
dist_registry: distribution_store.Registry,
) -> Any:
api_to_tables = {
"memory_banks": MemoryBanksRoutingTable,
@ -32,7 +35,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](impls_by_provider_id)
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
await impl.initialize()
return impl

View file

@ -53,8 +53,10 @@ class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: Dict[str, RoutedProtocol],
dist_registry: distribution_store.Registry,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
async def initialize(self) -> None:
self.registry: Registry = {}
@ -171,7 +173,7 @@ class CommonRoutingTableImpl(RoutingTable):
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj)
await distribution_store.REGISTRY.register(obj)
await self.dist_registry.register(obj)
class ModelsRoutingTable(CommonRoutingTableImpl, Models):

View file

@ -22,9 +22,6 @@ import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
@ -39,6 +36,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus,
start_trace,
)
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.distribution.request_headers import set_request_provider_data
@ -292,9 +292,9 @@ def main(
)
)
distribution_store.REGISTRY = distribution_store.DiskRegistry(dist_kvstore)
dist_registry = distribution_store.DiskRegistry(dist_kvstore)
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])

View file

@ -5,5 +5,3 @@
# the root directory of this source tree.
from .registry import DiskRegistry, Registry
REGISTRY: Registry = None