mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
donot use global state
This commit is contained in:
parent
4b6367838f
commit
26a14c1d92
5 changed files with 19 additions and 11 deletions
|
@ -11,6 +11,7 @@ from typing import Any, Dict, List, Set
|
||||||
from llama_stack.providers.datatypes import * # noqa: F403
|
from llama_stack.providers.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
import llama_stack.distribution.store as distribution_store
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
|
@ -65,7 +66,9 @@ class ProviderWithSpec(Provider):
|
||||||
|
|
||||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||||
async def resolve_impls(
|
async def resolve_impls(
|
||||||
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
|
run_config: StackRunConfig,
|
||||||
|
provider_registry: Dict[Api, Dict[str, ProviderSpec]],
|
||||||
|
dist_registry: distribution_store.Registry,
|
||||||
) -> Dict[Api, Any]:
|
) -> Dict[Api, Any]:
|
||||||
"""
|
"""
|
||||||
Does two things:
|
Does two things:
|
||||||
|
@ -189,6 +192,7 @@ async def resolve_impls(
|
||||||
provider,
|
provider,
|
||||||
deps,
|
deps,
|
||||||
inner_impls,
|
inner_impls,
|
||||||
|
dist_registry,
|
||||||
)
|
)
|
||||||
# TODO: ugh slightly redesign this shady looking code
|
# TODO: ugh slightly redesign this shady looking code
|
||||||
if "inner-" in api_str:
|
if "inner-" in api_str:
|
||||||
|
@ -237,6 +241,7 @@ async def instantiate_provider(
|
||||||
provider: ProviderWithSpec,
|
provider: ProviderWithSpec,
|
||||||
deps: Dict[str, Any],
|
deps: Dict[str, Any],
|
||||||
inner_impls: Dict[str, Any],
|
inner_impls: Dict[str, Any],
|
||||||
|
dist_registry: distribution_store.Registry,
|
||||||
):
|
):
|
||||||
protocols = api_protocol_map()
|
protocols = api_protocol_map()
|
||||||
additional_protocols = additional_protocols_map()
|
additional_protocols = additional_protocols_map()
|
||||||
|
@ -270,7 +275,7 @@ async def instantiate_provider(
|
||||||
method = "get_routing_table_impl"
|
method = "get_routing_table_impl"
|
||||||
|
|
||||||
config = None
|
config = None
|
||||||
args = [provider_spec.api, inner_impls, deps]
|
args = [provider_spec.api, inner_impls, deps, dist_registry]
|
||||||
else:
|
else:
|
||||||
method = "get_provider_impl"
|
method = "get_provider_impl"
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,8 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
import llama_stack.distribution.store as distribution_store
|
||||||
|
|
||||||
from .routing_tables import (
|
from .routing_tables import (
|
||||||
DatasetsRoutingTable,
|
DatasetsRoutingTable,
|
||||||
MemoryBanksRoutingTable,
|
MemoryBanksRoutingTable,
|
||||||
|
@ -20,6 +22,7 @@ async def get_routing_table_impl(
|
||||||
api: Api,
|
api: Api,
|
||||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||||
_deps,
|
_deps,
|
||||||
|
dist_registry: distribution_store.Registry,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
api_to_tables = {
|
api_to_tables = {
|
||||||
"memory_banks": MemoryBanksRoutingTable,
|
"memory_banks": MemoryBanksRoutingTable,
|
||||||
|
@ -32,7 +35,7 @@ async def get_routing_table_impl(
|
||||||
if api.value not in api_to_tables:
|
if api.value not in api_to_tables:
|
||||||
raise ValueError(f"API {api.value} not found in router map")
|
raise ValueError(f"API {api.value} not found in router map")
|
||||||
|
|
||||||
impl = api_to_tables[api.value](impls_by_provider_id)
|
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
|
@ -53,8 +53,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||||
|
dist_registry: distribution_store.Registry,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.impls_by_provider_id = impls_by_provider_id
|
self.impls_by_provider_id = impls_by_provider_id
|
||||||
|
self.dist_registry = dist_registry
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
self.registry: Registry = {}
|
self.registry: Registry = {}
|
||||||
|
@ -171,7 +173,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
if obj.identifier not in self.registry:
|
if obj.identifier not in self.registry:
|
||||||
self.registry[obj.identifier] = []
|
self.registry[obj.identifier] = []
|
||||||
self.registry[obj.identifier].append(obj)
|
self.registry[obj.identifier].append(obj)
|
||||||
await distribution_store.REGISTRY.register(obj)
|
await self.dist_registry.register(obj)
|
||||||
|
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
|
@ -22,9 +22,6 @@ import yaml
|
||||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from pydantic import BaseModel, ValidationError
|
|
||||||
from termcolor import cprint
|
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
builtin_automatically_routed_apis,
|
builtin_automatically_routed_apis,
|
||||||
|
@ -39,6 +36,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
SpanStatus,
|
SpanStatus,
|
||||||
start_trace,
|
start_trace,
|
||||||
)
|
)
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
from termcolor import cprint
|
||||||
|
from typing_extensions import Annotated
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
import llama_stack.distribution.store as distribution_store
|
import llama_stack.distribution.store as distribution_store
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
|
@ -292,9 +292,9 @@ def main(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
distribution_store.REGISTRY = distribution_store.DiskRegistry(dist_kvstore)
|
dist_registry = distribution_store.DiskRegistry(dist_kvstore)
|
||||||
|
|
||||||
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
|
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
|
||||||
|
|
|
@ -5,5 +5,3 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .registry import DiskRegistry, Registry
|
from .registry import DiskRegistry, Registry
|
||||||
|
|
||||||
REGISTRY: Registry = None
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue