donot use global state

This commit is contained in:
Dinesh Yeduguru 2024-11-01 14:19:54 -07:00 committed by Dinesh Yeduguru
parent 4b6367838f
commit 26a14c1d92
5 changed files with 19 additions and 11 deletions

View file

@ -11,6 +11,7 @@ from typing import Any, Dict, List, Set
from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
@ -65,7 +66,9 @@ class ProviderWithSpec(Provider):
# TODO: this code is not very straightforward to follow and needs one more round of refactoring # TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls( async def resolve_impls(
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]] run_config: StackRunConfig,
provider_registry: Dict[Api, Dict[str, ProviderSpec]],
dist_registry: distribution_store.Registry,
) -> Dict[Api, Any]: ) -> Dict[Api, Any]:
""" """
Does two things: Does two things:
@ -189,6 +192,7 @@ async def resolve_impls(
provider, provider,
deps, deps,
inner_impls, inner_impls,
dist_registry,
) )
# TODO: ugh slightly redesign this shady looking code # TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str: if "inner-" in api_str:
@ -237,6 +241,7 @@ async def instantiate_provider(
provider: ProviderWithSpec, provider: ProviderWithSpec,
deps: Dict[str, Any], deps: Dict[str, Any],
inner_impls: Dict[str, Any], inner_impls: Dict[str, Any],
dist_registry: distribution_store.Registry,
): ):
protocols = api_protocol_map() protocols = api_protocol_map()
additional_protocols = additional_protocols_map() additional_protocols = additional_protocols_map()
@ -270,7 +275,7 @@ async def instantiate_provider(
method = "get_routing_table_impl" method = "get_routing_table_impl"
config = None config = None
args = [provider_spec.api, inner_impls, deps] args = [provider_spec.api, inner_impls, deps, dist_registry]
else: else:
method = "get_provider_impl" method = "get_provider_impl"

View file

@ -7,6 +7,8 @@
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from .routing_tables import ( from .routing_tables import (
DatasetsRoutingTable, DatasetsRoutingTable,
MemoryBanksRoutingTable, MemoryBanksRoutingTable,
@ -20,6 +22,7 @@ async def get_routing_table_impl(
api: Api, api: Api,
impls_by_provider_id: Dict[str, RoutedProtocol], impls_by_provider_id: Dict[str, RoutedProtocol],
_deps, _deps,
dist_registry: distribution_store.Registry,
) -> Any: ) -> Any:
api_to_tables = { api_to_tables = {
"memory_banks": MemoryBanksRoutingTable, "memory_banks": MemoryBanksRoutingTable,
@ -32,7 +35,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables: if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map") raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](impls_by_provider_id) impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -53,8 +53,10 @@ class CommonRoutingTableImpl(RoutingTable):
def __init__( def __init__(
self, self,
impls_by_provider_id: Dict[str, RoutedProtocol], impls_by_provider_id: Dict[str, RoutedProtocol],
dist_registry: distribution_store.Registry,
) -> None: ) -> None:
self.impls_by_provider_id = impls_by_provider_id self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
async def initialize(self) -> None: async def initialize(self) -> None:
self.registry: Registry = {} self.registry: Registry = {}
@ -171,7 +173,7 @@ class CommonRoutingTableImpl(RoutingTable):
if obj.identifier not in self.registry: if obj.identifier not in self.registry:
self.registry[obj.identifier] = [] self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj) self.registry[obj.identifier].append(obj)
await distribution_store.REGISTRY.register(obj) await self.dist_registry.register(obj)
class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ModelsRoutingTable(CommonRoutingTableImpl, Models):

View file

@ -22,9 +22,6 @@ import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
@ -39,6 +36,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus, SpanStatus,
start_trace, start_trace,
) )
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store import llama_stack.distribution.store as distribution_store
from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.request_headers import set_request_provider_data
@ -292,9 +292,9 @@ def main(
) )
) )
distribution_store.REGISTRY = distribution_store.DiskRegistry(dist_kvstore) dist_registry = distribution_store.DiskRegistry(dist_kvstore)
impls = asyncio.run(resolve_impls(config, get_provider_registry())) impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])

View file

@ -5,5 +5,3 @@
# the root directory of this source tree. # the root directory of this source tree.
from .registry import DiskRegistry, Registry from .registry import DiskRegistry, Registry
REGISTRY: Registry = None