mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
donot use global state
This commit is contained in:
parent
4b6367838f
commit
26a14c1d92
5 changed files with 19 additions and 11 deletions
|
@ -11,6 +11,7 @@ from typing import Any, Dict, List, Set
|
|||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
import llama_stack.distribution.store as distribution_store
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -65,7 +66,9 @@ class ProviderWithSpec(Provider):
|
|||
|
||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||
async def resolve_impls(
|
||||
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
|
||||
run_config: StackRunConfig,
|
||||
provider_registry: Dict[Api, Dict[str, ProviderSpec]],
|
||||
dist_registry: distribution_store.Registry,
|
||||
) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
|
@ -189,6 +192,7 @@ async def resolve_impls(
|
|||
provider,
|
||||
deps,
|
||||
inner_impls,
|
||||
dist_registry,
|
||||
)
|
||||
# TODO: ugh slightly redesign this shady looking code
|
||||
if "inner-" in api_str:
|
||||
|
@ -237,6 +241,7 @@ async def instantiate_provider(
|
|||
provider: ProviderWithSpec,
|
||||
deps: Dict[str, Any],
|
||||
inner_impls: Dict[str, Any],
|
||||
dist_registry: distribution_store.Registry,
|
||||
):
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
@ -270,7 +275,7 @@ async def instantiate_provider(
|
|||
method = "get_routing_table_impl"
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, inner_impls, deps]
|
||||
args = [provider_spec.api, inner_impls, deps, dist_registry]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
|
|
|
@ -7,6 +7,8 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
import llama_stack.distribution.store as distribution_store
|
||||
|
||||
from .routing_tables import (
|
||||
DatasetsRoutingTable,
|
||||
MemoryBanksRoutingTable,
|
||||
|
@ -20,6 +22,7 @@ async def get_routing_table_impl(
|
|||
api: Api,
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
_deps,
|
||||
dist_registry: distribution_store.Registry,
|
||||
) -> Any:
|
||||
api_to_tables = {
|
||||
"memory_banks": MemoryBanksRoutingTable,
|
||||
|
@ -32,7 +35,7 @@ async def get_routing_table_impl(
|
|||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](impls_by_provider_id)
|
||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
|
|
@ -53,8 +53,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
def __init__(
|
||||
self,
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
dist_registry: distribution_store.Registry,
|
||||
) -> None:
|
||||
self.impls_by_provider_id = impls_by_provider_id
|
||||
self.dist_registry = dist_registry
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.registry: Registry = {}
|
||||
|
@ -171,7 +173,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if obj.identifier not in self.registry:
|
||||
self.registry[obj.identifier] = []
|
||||
self.registry[obj.identifier].append(obj)
|
||||
await distribution_store.REGISTRY.register(obj)
|
||||
await self.dist_registry.register(obj)
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
|
|
@ -22,9 +22,6 @@ import yaml
|
|||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
|
@ -39,6 +36,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
import llama_stack.distribution.store as distribution_store
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
|
@ -292,9 +292,9 @@ def main(
|
|||
)
|
||||
)
|
||||
|
||||
distribution_store.REGISTRY = distribution_store.DiskRegistry(dist_kvstore)
|
||||
dist_registry = distribution_store.DiskRegistry(dist_kvstore)
|
||||
|
||||
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
|
||||
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
|
|
|
@ -5,5 +5,3 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .registry import DiskRegistry, Registry
|
||||
|
||||
REGISTRY: Registry = None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue