From b55034c0de17365886ab815742dbc7a2d8a4bce6 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 9 Oct 2024 19:19:26 -0700 Subject: [PATCH] Another round of simplification and clarity for models/shields/memory_banks stuff --- llama_stack/apis/agents/agents.py | 12 +- .../apis/batch_inference/batch_inference.py | 3 +- llama_stack/apis/inference/inference.py | 6 +- llama_stack/apis/inspect/inspect.py | 3 +- llama_stack/apis/memory/memory.py | 27 +-- llama_stack/apis/memory_banks/memory_banks.py | 18 +- llama_stack/apis/models/models.py | 29 +-- llama_stack/apis/safety/safety.py | 6 +- llama_stack/apis/shields/shields.py | 19 +- llama_stack/apis/telemetry/telemetry.py | 3 +- llama_stack/distribution/configure.py | 193 +----------------- llama_stack/distribution/datatypes.py | 26 +-- llama_stack/distribution/resolver.py | 109 ++++++++-- llama_stack/distribution/routers/__init__.py | 5 +- .../distribution/routers/routing_tables.py | 148 ++++++++------ llama_stack/distribution/server/endpoints.py | 23 +-- .../adapters/inference/fireworks/fireworks.py | 7 + .../adapters/inference/ollama/ollama.py | 3 +- .../providers/adapters/inference/tgi/tgi.py | 54 ++++- .../adapters/inference/together/together.py | 7 + llama_stack/providers/datatypes.py | 29 +++ .../meta_reference/inference/inference.py | 40 +++- .../impls/meta_reference/memory/faiss.py | 14 +- .../tests/inference/test_inference.py | 30 ++- .../providers/tests/memory/test_memory.py | 29 ++- llama_stack/providers/tests/resolver.py | 37 ---- .../utils/inference/model_registry.py | 18 +- 27 files changed, 454 insertions(+), 444 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 6efe1b229..de710a94f 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -6,7 +6,16 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Protocol, Union +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) from llama_models.schema_utils import json_schema_type, webmethod @@ -404,6 +413,7 @@ class AgentStepResponse(BaseModel): step: Step +@runtime_checkable class Agents(Protocol): @webmethod(route="/agents/create") async def create_agent( diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py index 0c3132812..45a1a1593 100644 --- a/llama_stack/apis/batch_inference/batch_inference.py +++ b/llama_stack/apis/batch_inference/batch_inference.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
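
Note: every API protocol touched by this patch (Agents above; BatchInference here; Inference, Inspect, Memory, MemoryBanks, Models, Safety, Shields, and Telemetry below) gains @runtime_checkable. A minimal sketch of what that buys, with illustrative names that are not from the patch:

    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class Greeter(Protocol):
        def greet(self, name: str) -> str: ...

    class EnglishGreeter:
        def greet(self, name: str) -> str:
            return f"hello, {name}"

    # isinstance() against a plain Protocol raises TypeError; with
    # @runtime_checkable it performs a structural membership check.
    assert isinstance(EnglishGreeter(), Greeter)

Such isinstance() checks only verify that the named members exist, not that their signatures match, which is why resolver.py below adds an explicit check_protocol_compliance() pass on top.
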
-from typing import List, Optional, Protocol +from typing import List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod @@ -47,6 +47,7 @@ class BatchChatCompletionResponse(BaseModel): completion_message_batch: List[CompletionMessage] +@runtime_checkable class BatchInference(Protocol): @webmethod(route="/batch_inference/completion") async def batch_completion( diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 13a51bc59..588dd37ca 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -6,7 +6,7 @@ from enum import Enum -from typing import List, Literal, Optional, Protocol, Union +from typing import List, Literal, Optional, Protocol, runtime_checkable, Union from llama_models.schema_utils import json_schema_type, webmethod @@ -177,6 +177,7 @@ class ModelStore(Protocol): def get_model(self, identifier: str) -> ModelDef: ... +@runtime_checkable class Inference(Protocol): model_store: ModelStore @@ -214,6 +215,3 @@ class Inference(Protocol): model: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: ... - - @webmethod(route="/inference/register_model") - async def register_model(self, model: ModelDef) -> None: ... diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py index a30f39a16..1dbe80a02 100644 --- a/llama_stack/apis/inspect/inspect.py +++ b/llama_stack/apis/inspect/inspect.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict, List, Protocol +from typing import Dict, List, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel @@ -29,6 +29,7 @@ class HealthInfo(BaseModel): # TODO: add a provider level status +@runtime_checkable class Inspect(Protocol): @webmethod(route="/providers/list", method="GET") async def list_providers(self) -> Dict[str, ProviderInfo]: ... diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index c5161e864..9047820ac 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -8,7 +8,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Optional, Protocol +from typing import List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod @@ -42,6 +42,7 @@ class MemoryBankStore(Protocol): def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ... +@runtime_checkable class Memory(Protocol): memory_bank_store: MemoryBankStore @@ -55,13 +56,6 @@ class Memory(Protocol): ttl_seconds: Optional[int] = None, ) -> None: ... - @webmethod(route="/memory/update") - async def update_documents( - self, - bank_id: str, - documents: List[MemoryBankDocument], - ) -> None: ... - @webmethod(route="/memory/query") async def query_documents( self, @@ -69,20 +63,3 @@ class Memory(Protocol): query: InterleavedTextMedia, params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: ... - - @webmethod(route="/memory/documents/get", method="GET") - async def get_documents( - self, - bank_id: str, - document_ids: List[str], - ) -> List[MemoryBankDocument]: ... 
- - @webmethod(route="/memory/documents/delete", method="DELETE") - async def delete_documents( - self, - bank_id: str, - document_ids: List[str], - ) -> None: ... - - @webmethod(route="/memory/register_memory_bank") - async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ... diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 6d9f2f9f6..df116d3c2 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import List, Literal, Optional, Protocol, Union +from typing import List, Literal, Optional, Protocol, runtime_checkable, Union from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field @@ -22,7 +22,8 @@ class MemoryBankType(Enum): class CommonDef(BaseModel): identifier: str - provider_id: Optional[str] = None + # Hack: move this out later + provider_id: str = "" @json_schema_type @@ -58,13 +59,20 @@ MemoryBankDef = Annotated[ Field(discriminator="type"), ] +MemoryBankDefWithProvider = MemoryBankDef + +@runtime_checkable class MemoryBanks(Protocol): @webmethod(route="/memory_banks/list", method="GET") - async def list_memory_banks(self) -> List[MemoryBankDef]: ... + async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: ... @webmethod(route="/memory_banks/get", method="GET") - async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ... + async def get_memory_bank( + self, identifier: str + ) -> Optional[MemoryBankDefWithProvider]: ... @webmethod(route="/memory_banks/register", method="POST") - async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ... + async def register_memory_bank( + self, memory_bank: MemoryBankDefWithProvider + ) -> None: ... diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 3a770af25..994c8e995 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -4,34 +4,39 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Optional, Protocol +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -@json_schema_type class ModelDef(BaseModel): identifier: str = Field( - description="A unique identifier for the model type", + description="A unique name for the model type", ) llama_model: str = Field( - description="Pointer to the core Llama family model", + description="Pointer to the underlying core Llama family model. 
Each model served by Llama Stack must have a core Llama model.", ) - provider_id: Optional[str] = Field( - default=None, description="The provider instance which serves this model" + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Any additional metadata for this model", ) - # For now, we are only supporting core llama models but as soon as finetuned - # and other custom models (for example various quantizations) are allowed, there - # will be more metadata fields here +@json_schema_type +class ModelDefWithProvider(ModelDef): + provider_id: str = Field( + description="The provider ID for this model", + ) + + +@runtime_checkable class Models(Protocol): @webmethod(route="/models/list", method="GET") - async def list_models(self) -> List[ModelDef]: ... + async def list_models(self) -> List[ModelDefWithProvider]: ... @webmethod(route="/models/get", method="GET") - async def get_model(self, identifier: str) -> Optional[ModelDef]: ... + async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: ... @webmethod(route="/models/register", method="POST") - async def register_model(self, model: ModelDef) -> None: ... + async def register_model(self, model: ModelDefWithProvider) -> None: ... diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 4f4a49407..f3615dc4b 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Dict, List, Protocol +from typing import Any, Dict, List, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel @@ -42,6 +42,7 @@ class ShieldStore(Protocol): def get_shield(self, identifier: str) -> ShieldDef: ... +@runtime_checkable class Safety(Protocol): shield_store: ShieldStore @@ -49,6 +50,3 @@ class Safety(Protocol): async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: ... - - @webmethod(route="/safety/register_shield") - async def register_shield(self, shield: ShieldDef) -> None: ... diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index cec82516e..7f003faa2 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Dict, List, Optional, Protocol +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field @@ -26,21 +26,26 @@ class ShieldDef(BaseModel): type: str = Field( description="The type of shield this is; the value is one of the ShieldType enum" ) - provider_id: Optional[str] = Field( - default=None, description="The provider instance which serves this shield" - ) params: Dict[str, Any] = Field( default_factory=dict, description="Any additional parameters needed for this shield", ) +@json_schema_type +class ShieldDefWithProvider(ShieldDef): + provider_id: str = Field( + description="The provider ID for this shield type", + ) + + +@runtime_checkable class Shields(Protocol): @webmethod(route="/shields/list", method="GET") - async def list_shields(self) -> List[ShieldDef]: ... + async def list_shields(self) -> List[ShieldDefWithProvider]: ... 
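
Note: safety.py above drops its public register_shield webmethod, and the Shields API here becomes the registry-facing surface. The *WithProvider split is the pattern across the whole patch: providers traffic in plain defs (ShieldDef, ModelDef), while registry APIs exchange defs annotated with the provider instance serving them. A hedged sketch of registration through the new API, assuming a resolved Shields implementation named shields_impl and a provider instance id of "meta-reference":

    async def register_demo_shield(shields_impl: Shields) -> None:
        shield = ShieldDefWithProvider(
            identifier="llama_guard",
            type=ShieldType.llama_guard.value,
            params={},
            provider_id="meta-reference",  # assumed provider instance id
        )
        await shields_impl.register_shield(shield)
        listed = await shields_impl.list_shields()
        assert any(s.identifier == "llama_guard" for s in listed)
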
@webmethod(route="/shields/get", method="GET") - async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ... + async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ... @webmethod(route="/shields/register", method="POST") - async def register_shield(self, shield: ShieldDef) -> None: ... + async def register_shield(self, shield: ShieldDefWithProvider) -> None: ... diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 2546c1ede..8374192f2 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -6,7 +6,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, Literal, Optional, Protocol, Union +from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field @@ -123,6 +123,7 @@ Event = Annotated[ ] +@runtime_checkable class Telemetry(Protocol): @webmethod(route="/telemetry/log_event") async def log_event(self, event: Event) -> None: ... diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 2811d4142..7b8c32665 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -7,17 +7,7 @@ import textwrap from typing import Any -from llama_models.sku_list import ( - llama3_1_family, - llama3_2_family, - llama3_family, - resolve_model, - safety_models, -) - from llama_stack.distribution.datatypes import * # noqa: F403 -from prompt_toolkit import prompt -from prompt_toolkit.validation import Validator from termcolor import cprint from llama_stack.distribution.distribution import ( @@ -33,11 +23,6 @@ from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 -ALLOWED_MODELS = ( - llama3_family() + llama3_1_family() + llama3_2_family() + safety_models() -) - - def configure_single_provider( registry: Dict[str, ProviderSpec], provider: Provider ) -> Provider: @@ -133,137 +118,10 @@ def configure_api_providers( config.providers[api_str] = updated_providers - if is_nux: - print( - textwrap.dedent( - """ - ========================================================================================= - Now let's configure the `objects` you will be serving via the stack. These are: - - - Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct) - - Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B) - - Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores) - - This wizard will guide you through setting up one of each of these objects. You can - always add more later by editing the run.yaml file. - """ - ) - ) - - object_types = { - "models": (ModelDef, configure_models, "inference"), - "shields": (ShieldDef, configure_shields, "safety"), - "memory_banks": (MemoryBankDef, configure_memory_banks, "memory"), - } - safety_providers = config.providers.get("safety", []) - - for otype, (odef, config_method, api_str) in object_types.items(): - existing_objects = getattr(config, otype) - - if existing_objects: - cprint( - f"{len(existing_objects)} {otype} exist. 
Skipping...", - "blue", - attrs=["bold"], - ) - updated_objects = existing_objects - else: - providers = config.providers.get(api_str, []) - if not providers: - updated_objects = [] - else: - # we are newly configuring this API - cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"]) - updated_objects = config_method( - config.providers[api_str], safety_providers - ) - - setattr(config, otype, updated_objects) - print("") - return config -def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]: - if not safety_providers: - return None - - provider = safety_providers[0] - assert provider.provider_type == "meta-reference" - - cfg = provider.config["llama_guard_shield"] - if not cfg: - return None - return cfg["model"] - - -def configure_models( - providers: List[Provider], safety_providers: List[Provider] -) -> List[ModelDef]: - model = prompt( - "> Please enter the model you want to serve: ", - default="Llama3.2-1B-Instruct", - validator=Validator.from_callable( - lambda x: resolve_model(x) is not None, - error_message="Model must be: {}".format( - [x.descriptor() for x in ALLOWED_MODELS] - ), - ), - ) - model = ModelDef( - identifier=model, - llama_model=model, - provider_id=providers[0].provider_id, - ) - - ret = [model] - if llama_guard := get_llama_guard_model(safety_providers): - ret.append( - ModelDef( - identifier=llama_guard, - llama_model=llama_guard, - provider_id=providers[0].provider_id, - ) - ) - - return ret - - -def configure_shields( - providers: List[Provider], safety_providers: List[Provider] -) -> List[ShieldDef]: - if get_llama_guard_model(safety_providers): - return [ - ShieldDef( - identifier="llama_guard", - type="llama_guard", - provider_id=providers[0].provider_id, - params={}, - ) - ] - - return [] - - -def configure_memory_banks( - providers: List[Provider], safety_providers: List[Provider] -) -> List[MemoryBankDef]: - bank_name = prompt( - "> Please enter a name for your memory bank: ", - default="my-memory-bank", - ) - - return [ - VectorMemoryBankDef( - identifier=bank_name, - provider_id=providers[0].provider_id, - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - ) - ] - - -def upgrade_from_routing_table_to_registry( +def upgrade_from_routing_table( config_dict: Dict[str, Any], ) -> Dict[str, Any]: def get_providers(entries): @@ -281,57 +139,12 @@ def upgrade_from_routing_table_to_registry( ] providers_by_api = {} - models = [] - shields = [] - memory_banks = [] routing_table = config_dict.get("routing_table", {}) for api_str, entries in routing_table.items(): providers = get_providers(entries) providers_by_api[api_str] = providers - if api_str == "inference": - for entry, provider in zip(entries, providers): - key = entry["routing_key"] - keys = key if isinstance(key, list) else [key] - for key in keys: - models.append( - ModelDef( - identifier=key, - provider_id=provider.provider_id, - llama_model=key, - ) - ) - elif api_str == "safety": - for entry, provider in zip(entries, providers): - key = entry["routing_key"] - keys = key if isinstance(key, list) else [key] - for key in keys: - shields.append( - ShieldDef( - identifier=key, - type=ShieldType.llama_guard.value, - provider_id=provider.provider_id, - ) - ) - elif api_str == "memory": - for entry, provider in zip(entries, providers): - key = entry["routing_key"] - keys = key if isinstance(key, list) else [key] - for key in keys: - # we currently only support Vector memory banks so this is OK - memory_banks.append( - VectorMemoryBankDef( - identifier=key, - 
provider_id=provider.provider_id, - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - ) - ) - config_dict["models"] = models - config_dict["shields"] = shields - config_dict["memory_banks"] = memory_banks - provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {})) if provider_map: for api_str, provider in provider_map.items(): @@ -361,9 +174,9 @@ def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfi if version == LLAMA_STACK_RUN_CONFIG_VERSION: return StackRunConfig(**config_dict) - if "models" not in config_dict: + if "routing_table" in config_dict: print("Upgrading config...") - config_dict = upgrade_from_routing_table_to_registry(config_dict) + config_dict = upgrade_from_routing_table(config_dict) config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION config_dict["built_at"] = datetime.now().isoformat() diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index e09a6939c..5212e6da1 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -32,6 +32,12 @@ RoutableObject = Union[ MemoryBankDef, ] +RoutableObjectWithProvider = Union[ + ModelDefWithProvider, + ShieldDef, + MemoryBankDef, +] + RoutedProtocol = Union[ Inference, Safety, @@ -63,7 +69,6 @@ class RoutingTableProviderSpec(ProviderSpec): docker_image: Optional[str] = None router_api: Api - registry: List[RoutableObject] module: str pip_packages: List[str] = Field(default_factory=list) @@ -121,25 +126,6 @@ can be instantiated multiple times (with different configs) if necessary. """, ) - models: List[ModelDef] = Field( - description=""" -List of model definitions to serve. This list may get extended by -/models/register API calls at runtime. -""", - ) - shields: List[ShieldDef] = Field( - description=""" -List of shield definitions to serve. This list may get extended by -/shields/register API calls at runtime. -""", - ) - memory_banks: List[MemoryBankDef] = Field( - description=""" -List of memory bank definitions to serve. This list may get extended by -/memory_banks/register API calls at runtime. -""", - ) - class BuildConfig(BaseModel): version: str = LLAMA_STACK_BUILD_CONFIG_VERSION diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 857eef757..1de52817f 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -4,10 +4,22 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
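
Note: resolver.py is the heart of this round. Routing tables no longer receive a static registry parsed out of run.yaml; instead, each routed provider is expected to implement a "private" companion protocol (ModelsProtocolPrivate, ShieldsProtocolPrivate, or MemoryBanksProtocolPrivate) and advertise its own objects. A sketch of the provider side, patterned on the meta-reference inference implementation later in this patch (the class name is hypothetical):

    class StaticInferenceProvider(Inference, ModelsProtocolPrivate):
        async def list_models(self) -> List[ModelDef]:
            # advertise the single model this process was started with
            return [
                ModelDef(
                    identifier="Llama3.2-1B-Instruct",
                    llama_model="Llama3.2-1B-Instruct",
                )
            ]

        async def get_model(self, identifier: str) -> Optional[ModelDef]:
            models = await self.list_models()
            return next((m for m in models if m.identifier == identifier), None)

        async def register_model(self, model: ModelDef) -> None:
            raise ValueError("Dynamic model registration is not supported")
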
import importlib +import inspect from typing import Any, Dict, List, Set +from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 + +from llama_stack.apis.agents import Agents +from llama_stack.apis.inference import Inference +from llama_stack.apis.inspect import Inspect +from llama_stack.apis.memory import Memory +from llama_stack.apis.memory_banks import MemoryBanks +from llama_stack.apis.models import Models +from llama_stack.apis.safety import Safety +from llama_stack.apis.shields import Shields +from llama_stack.apis.telemetry import Telemetry from llama_stack.distribution.distribution import ( builtin_automatically_routed_apis, get_provider_registry, @@ -15,6 +27,28 @@ from llama_stack.distribution.distribution import ( from llama_stack.distribution.utils.dynamic import instantiate_class_type +def api_protocol_map() -> Dict[Api, Any]: + return { + Api.agents: Agents, + Api.inference: Inference, + Api.inspect: Inspect, + Api.memory: Memory, + Api.memory_banks: MemoryBanks, + Api.models: Models, + Api.safety: Safety, + Api.shields: Shields, + Api.telemetry: Telemetry, + } + + +def additional_protocols_map() -> Dict[Api, Any]: + return { + Api.inference: ModelsProtocolPrivate, + Api.memory: MemoryBanksProtocolPrivate, + Api.safety: ShieldsProtocolPrivate, + } + + # TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF! class ProviderWithSpec(Provider): spec: ProviderSpec @@ -73,17 +107,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An available_providers = providers_with_specs[f"inner-{info.router_api.value}"] - inner_deps = [] - registry = getattr(run_config, info.routing_table_api.value) - for entry in registry: - if entry.provider_id not in available_providers: - raise ValueError( - f"Provider `{entry.provider_id}` not found. 
Available providers: {list(available_providers.keys())}" - ) - - provider = available_providers[entry.provider_id] - inner_deps.extend(provider.spec.api_dependencies) - providers_with_specs[info.routing_table_api.value] = { "__builtin__": ProviderWithSpec( provider_id="__builtin__", @@ -92,13 +115,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An spec=RoutingTableProviderSpec( api=info.routing_table_api, router_api=info.router_api, - registry=registry, module="llama_stack.distribution.routers", - api_dependencies=inner_deps, - deps__=( - [x.value for x in inner_deps] - + [f"inner-{info.router_api.value}"] - ), + api_dependencies=[], + deps__=([f"inner-{info.router_api.value}"]), ), ) } @@ -212,6 +231,9 @@ async def instantiate_provider( deps: Dict[str, Any], inner_impls: Dict[str, Any], ): + protocols = api_protocol_map() + additional_protocols = additional_protocols_map() + provider_spec = provider.spec module = importlib.import_module(provider_spec.module) @@ -234,7 +256,7 @@ async def instantiate_provider( method = "get_routing_table_impl" config = None - args = [provider_spec.api, provider_spec.registry, inner_impls, deps] + args = [provider_spec.api, inner_impls, deps] else: method = "get_provider_impl" @@ -247,4 +269,55 @@ async def instantiate_provider( impl.__provider_id__ = provider.provider_id impl.__provider_spec__ = provider_spec impl.__provider_config__ = config + + check_protocol_compliance(impl, protocols[provider_spec.api]) + if ( + not isinstance(provider_spec, AutoRoutedProviderSpec) + and provider_spec.api in additional_protocols + ): + additional_api = additional_protocols[provider_spec.api] + check_protocol_compliance(impl, additional_api) + return impl + + +def check_protocol_compliance(obj: Any, protocol: Any) -> None: + missing_methods = [] + + mro = type(obj).__mro__ + for name, value in inspect.getmembers(protocol): + if inspect.isfunction(value) and hasattr(value, "__webmethod__"): + if not hasattr(obj, name): + missing_methods.append((name, "missing")) + elif not callable(getattr(obj, name)): + missing_methods.append((name, "not_callable")) + else: + # Check if the method signatures are compatible + obj_method = getattr(obj, name) + proto_sig = inspect.signature(value) + obj_sig = inspect.signature(obj_method) + + proto_params = set(proto_sig.parameters) + proto_params.discard("self") + obj_params = set(obj_sig.parameters) + obj_params.discard("self") + if not (proto_params <= obj_params): + print( + f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}" + ) + missing_methods.append((name, "signature_mismatch")) + else: + # Check if the method is actually implemented in the class + method_owner = next( + (cls for cls in mro if name in cls.__dict__), None + ) + if ( + method_owner is None + or method_owner.__name__ == protocol.__name__ + ): + missing_methods.append((name, "not_actually_implemented")) + + if missing_methods: + raise ValueError( + f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}" + ) diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 9935ecd7d..28851390c 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
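
Note: get_routing_table_impl below loses its registry argument; the table now populates itself during initialize() by calling each provider's list_*() method. The registry it builds maps an identifier to every provider serving it. A standalone sketch of the lookup rule, with names mirroring CommonRoutingTableImpl in routing_tables.py (the module-level dicts are stand-ins for instance state):

    from typing import Any, Dict, List, Optional

    registry: Dict[str, List[RoutableObjectWithProvider]] = {}
    impls_by_provider_id: Dict[str, Any] = {}

    def get_provider_impl(routing_key: str, provider_id: Optional[str] = None) -> Any:
        # first matching provider wins unless the caller pins a provider_id
        for obj in registry.get(routing_key, []):
            if provider_id is None or obj.provider_id == provider_id:
                return impls_by_provider_id[obj.provider_id]
        raise ValueError(f"Provider not found for `{routing_key}`")
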
-from typing import Any, List +from typing import Any from llama_stack.distribution.datatypes import * # noqa: F403 from .routing_tables import ( @@ -16,7 +16,6 @@ from .routing_tables import ( async def get_routing_table_impl( api: Api, - registry: List[RoutableObject], impls_by_provider_id: Dict[str, RoutedProtocol], _deps, ) -> Any: @@ -28,7 +27,7 @@ async def get_routing_table_impl( if api.value not in api_to_tables: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_tables[api.value](registry, impls_by_provider_id) + impl = api_to_tables[api.value](impls_by_provider_id) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 7cb6e8432..17755f0e4 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from llama_models.llama3.api.datatypes import * # noqa: F403 @@ -29,115 +29,145 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: await p.register_memory_bank(obj) +Registry = Dict[str, List[RoutableObjectWithProvider]] + + # TODO: this routing table maintains state in memory purely. We need to # add persistence to it when we add dynamic registration of objects. class CommonRoutingTableImpl(RoutingTable): def __init__( self, - registry: List[RoutableObject], impls_by_provider_id: Dict[str, RoutedProtocol], ) -> None: - for obj in registry: - if obj.provider_id not in impls_by_provider_id: - print(f"{impls_by_provider_id=}") - raise ValueError( - f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found" - ) - self.impls_by_provider_id = impls_by_provider_id - self.registry = registry - for p in self.impls_by_provider_id.values(): + async def initialize(self) -> None: + self.registry: Registry = {} + + def add_objects(objs: List[RoutableObjectWithProvider]) -> None: + for obj in objs: + if obj.identifier not in self.registry: + self.registry[obj.identifier] = [] + + self.registry[obj.identifier].append(obj) + + for pid, p in self.impls_by_provider_id.items(): api = get_impl_api(p) if api == Api.inference: p.model_store = self + models = await p.list_models() + add_objects( + [ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models] + ) + elif api == Api.safety: p.shield_store = self + shields = await p.list_shields() + add_objects( + [ + ShieldDefWithProvider(**s.dict(), provider_id=pid) + for s in shields + ] + ) + elif api == Api.memory: p.memory_bank_store = self + memory_banks = await p.list_memory_banks() - self.routing_key_to_object = {} - for obj in self.registry: - self.routing_key_to_object[obj.identifier] = obj + # do in-memory updates due to pesky Annotated unions + for m in memory_banks: + m.provider_id = pid - async def initialize(self) -> None: - for obj in self.registry: - p = self.impls_by_provider_id[obj.provider_id] - await register_object_with_provider(obj, p) + add_objects(memory_banks) async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): await p.shutdown() - def get_provider_impl(self, routing_key: str) -> Any: - if routing_key not in self.routing_key_to_object: + def get_provider_impl( + self, routing_key: str, provider_id: Optional[str] = None + ) -> Any: + if routing_key not in 
self.registry: raise ValueError(f"`{routing_key}` not registered") - obj = self.routing_key_to_object[routing_key] + objs = self.registry[routing_key] + for obj in objs: + if not provider_id or provider_id == obj.provider_id: + return self.impls_by_provider_id[obj.provider_id] + + raise ValueError(f"Provider not found for `{routing_key}`") + + def get_object_by_identifier( + self, identifier: str + ) -> Optional[RoutableObjectWithProvider]: + objs = self.registry.get(identifier, []) + if not objs: + return None + + # kind of ill-defined behavior here, but we'll just return the first one + return objs[0] + + async def register_object(self, obj: RoutableObjectWithProvider): + entries = self.registry.get(obj.identifier, []) + for entry in entries: + if entry.provider_id == obj.provider_id: + print(f"`{obj.identifier}` already registered with `{obj.provider_id}`") + return + if obj.provider_id not in self.impls_by_provider_id: raise ValueError(f"Provider `{obj.provider_id}` not found") - return self.impls_by_provider_id[obj.provider_id] - - def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]: - for obj in self.registry: - if obj.identifier == identifier: - return obj - return None - - async def register_object(self, obj: RoutableObject): - if obj.identifier in self.routing_key_to_object: - print(f"`{obj.identifier}` is already registered") - return - - if not obj.provider_id: - provider_ids = list(self.impls_by_provider_id.keys()) - if not provider_ids: - raise ValueError("No providers found") - - print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}") - obj.provider_id = provider_ids[0] - else: - if obj.provider_id not in self.impls_by_provider_id: - raise ValueError(f"Provider `{obj.provider_id}` not found") - p = self.impls_by_provider_id[obj.provider_id] await register_object_with_provider(obj, p) - self.routing_key_to_object[obj.identifier] = obj - self.registry.append(obj) + if obj.identifier not in self.registry: + self.registry[obj.identifier] = [] + self.registry[obj.identifier].append(obj) # TODO: persist this to a store class ModelsRoutingTable(CommonRoutingTableImpl, Models): - async def list_models(self) -> List[ModelDef]: - return self.registry + async def list_models(self) -> List[ModelDefWithProvider]: + objects = [] + for objs in self.registry.values(): + objects.extend(objs) + return objects - async def get_model(self, identifier: str) -> Optional[ModelDef]: + async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: return self.get_object_by_identifier(identifier) - async def register_model(self, model: ModelDef) -> None: + async def register_model(self, model: ModelDefWithProvider) -> None: await self.register_object(model) class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[ShieldDef]: - return self.registry + objects = [] + for objs in self.registry.values(): + objects.extend(objs) + return objects - async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: + async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: return self.get_object_by_identifier(shield_type) - async def register_shield(self, shield: ShieldDef) -> None: + async def register_shield(self, shield: ShieldDefWithProvider) -> None: await self.register_object(shield) class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): - async def list_memory_banks(self) -> List[MemoryBankDef]: - return self.registry + async def list_memory_banks(self) -> 
List[MemoryBankDefWithProvider]: + objects = [] + for objs in self.registry.values(): + objects.extend(objs) + return objects - async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: + async def get_memory_bank( + self, identifier: str + ) -> Optional[MemoryBankDefWithProvider]: return self.get_object_by_identifier(identifier) - async def register_memory_bank(self, bank: MemoryBankDef) -> None: - await self.register_object(bank) + async def register_memory_bank( + self, memory_bank: MemoryBankDefWithProvider + ) -> None: + await self.register_object(memory_bank) diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py index 601e80e5d..93432abe1 100644 --- a/llama_stack/distribution/server/endpoints.py +++ b/llama_stack/distribution/server/endpoints.py @@ -9,15 +9,7 @@ from typing import Dict, List from pydantic import BaseModel -from llama_stack.apis.agents import Agents -from llama_stack.apis.inference import Inference -from llama_stack.apis.inspect import Inspect -from llama_stack.apis.memory import Memory -from llama_stack.apis.memory_banks import MemoryBanks -from llama_stack.apis.models import Models -from llama_stack.apis.safety import Safety -from llama_stack.apis.shields import Shields -from llama_stack.apis.telemetry import Telemetry +from llama_stack.distribution.resolver import api_protocol_map from llama_stack.providers.datatypes import Api @@ -31,18 +23,7 @@ class ApiEndpoint(BaseModel): def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]: apis = {} - protocols = { - Api.inference: Inference, - Api.safety: Safety, - Api.agents: Agents, - Api.memory: Memory, - Api.telemetry: Telemetry, - Api.models: Models, - Api.shields: Shields, - Api.memory_banks: MemoryBanks, - Api.inspect: Inspect, - } - + protocols = api_protocol_map() for api, protocol in protocols.items(): endpoints = [] protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py index c0edc836a..c85ee00f9 100644 --- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py +++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py @@ -121,3 +121,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): "stream": request.stream, **options, } + + async def embeddings( + self, + model: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index fe5e39c30..7f8046202 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -15,6 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer from ollama import AsyncClient from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.models import * # noqa: F403 from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, OpenAICompatCompletionChoice, @@ -35,7 +36,7 @@ OLLAMA_SUPPORTED_MODELS = { } -class OllamaInferenceAdapter(Inference): +class OllamaInferenceAdapter(Inference, Models): def __init__(self, url: str) -> None: self.url = url self.formatter = ChatFormat(Tokenizer.get_instance()) diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py 
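
Note on tgi.py: instead of validating register_model() calls with resolve_model(), the HF adapter now precomputes a huggingface_repo -> llama descriptor map and reports its single served model via list_models(). A sketch of that mapping, using all_registered_models exactly as the hunk below does (the example pair is an assumption for illustration):

    from llama_models.sku_list import all_registered_models

    repo_to_descriptor = {
        m.huggingface_repo: m.descriptor()
        for m in all_registered_models()
        if m.huggingface_repo
    }
    # e.g. repo_to_descriptor.get("meta-llama/Llama-3.2-1B-Instruct")
    #      -> "Llama3.2-1B-Instruct"  (assumed pair, for illustration only)
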
index 59eb7f3f1..e939bed62 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -6,14 +6,18 @@ import logging -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional from huggingface_hub import AsyncInferenceClient, HfApi from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import resolve_model +from llama_models.sku_list import all_registered_models from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.models import * # noqa: F403 + +from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate + from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, OpenAICompatCompletionChoice, @@ -30,26 +34,47 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl logger = logging.getLogger(__name__) -class _HfAdapter(Inference): +class _HfAdapter(Inference, ModelsProtocolPrivate): client: AsyncInferenceClient max_tokens: int model_id: str def __init__(self) -> None: self.formatter = ChatFormat(Tokenizer.get_instance()) + self.huggingface_repo_to_llama_model_id = { + model.huggingface_repo: model.descriptor() + for model in all_registered_models() + if model.huggingface_repo + } async def register_model(self, model: ModelDef) -> None: - resolved_model = resolve_model(model.identifier) - if resolved_model is None: - raise ValueError(f"Unknown model: {model.identifier}") + raise ValueError("Model registration is not supported for HuggingFace models") - if not resolved_model.huggingface_repo: - raise ValueError( - f"Model {model.identifier} does not have a HuggingFace repo" + async def list_models(self) -> List[ModelDef]: + repo = self.model_id + identifier = self.huggingface_repo_to_llama_model_id[repo] + return [ + ModelDef( + identifier=identifier, + llama_model=identifier, + metadata={ + "huggingface_repo": repo, + }, ) + ] - if self.model_id != resolved_model.huggingface_repo: - raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}") + async def get_model(self, identifier: str) -> Optional[ModelDef]: + model = self.huggingface_repo_to_llama_model_id.get(self.model_id) + if model != identifier: + return None + + return ModelDef( + identifier=model, + llama_model=model, + metadata={ + "huggingface_repo": self.model_id, + }, + ) async def shutdown(self) -> None: pass @@ -145,6 +170,13 @@ class _HfAdapter(Inference): **options, ) + async def embeddings( + self, + model: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + raise NotImplementedError() + class TGIAdapter(_HfAdapter): async def initialize(self, config: TGIImplConfig) -> None: diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 0ef5bc593..3231f4657 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -134,3 +134,10 @@ class TogetherInferenceAdapter( "stream": request.stream, **get_sampling_options(request), } + + async def embeddings( + self, + model: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 44ecb5355..5c782287e 100644 --- a/llama_stack/providers/datatypes.py +++ 
b/llama_stack/providers/datatypes.py @@ -10,6 +10,11 @@ from typing import Any, List, Optional, Protocol from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.apis.memory_banks import MemoryBankDef + +from llama_stack.apis.models import ModelDef +from llama_stack.apis.shields import ShieldDef + @json_schema_type class Api(Enum): @@ -28,6 +33,30 @@ class Api(Enum): inspect = "inspect" +class ModelsProtocolPrivate(Protocol): + async def list_models(self) -> List[ModelDef]: ... + + async def get_model(self, identifier: str) -> Optional[ModelDef]: ... + + async def register_model(self, model: ModelDef) -> None: ... + + +class ShieldsProtocolPrivate(Protocol): + async def list_shields(self) -> List[ShieldDef]: ... + + async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ... + + async def register_shield(self, shield: ShieldDef) -> None: ... + + +class MemoryBanksProtocolPrivate(Protocol): + async def list_memory_banks(self) -> List[MemoryBankDef]: ... + + async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ... + + async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ... + + @json_schema_type class ProviderSpec(BaseModel): api: Api diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py index bda5e54c1..26036350e 100644 --- a/llama_stack/providers/impls/meta_reference/inference/inference.py +++ b/llama_stack/providers/impls/meta_reference/inference/inference.py @@ -12,6 +12,7 @@ from llama_models.sku_list import resolve_model from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, ) @@ -24,7 +25,7 @@ from .model_parallel import LlamaModelParallelGenerator SEMAPHORE = asyncio.Semaphore(1) -class MetaReferenceInferenceImpl(Inference): +class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): def __init__(self, config: MetaReferenceImplConfig) -> None: self.config = config model = resolve_model(config.model) @@ -39,14 +40,38 @@ class MetaReferenceInferenceImpl(Inference): self.generator.start() async def register_model(self, model: ModelDef) -> None: - if model.identifier != self.model.descriptor(): - raise RuntimeError( - f"Model mismatch: {model.identifier} != {self.model.descriptor()}" + raise ValueError("Dynamic model registration is not supported") + + async def list_models(self) -> List[ModelDef]: + return [ + ModelDef( + identifier=self.model.descriptor(), + llama_model=self.model.descriptor(), ) + ] + + async def get_model(self, identifier: str) -> Optional[ModelDef]: + if self.model.descriptor() != identifier: + return None + + return ModelDef( + identifier=self.model.descriptor(), + llama_model=self.model.descriptor(), + ) async def shutdown(self) -> None: self.generator.stop() + def completion( + self, + model: str, + content: InterleavedTextMedia, + sampling_params: Optional[SamplingParams] = SamplingParams(), + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: + raise NotImplementedError() + def chat_completion( self, model: str, @@ -255,3 +280,10 @@ class MetaReferenceInferenceImpl(Inference): stop_reason=stop_reason, ) ) + + async def 
embeddings( + self, + model: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index 7c59f5d59..adac03342 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -15,6 +15,8 @@ from numpy.typing import NDArray from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate + from llama_stack.providers.utils.memory.vector_store import ( ALL_MINILM_L6_V2_DIMENSION, BankWithIndex, @@ -61,7 +63,7 @@ class FaissIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) -class FaissMemoryImpl(Memory): +class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): def __init__(self, config: FaissImplConfig) -> None: self.config = config self.cache = {} @@ -83,6 +85,16 @@ class FaissMemoryImpl(Memory): ) self.cache[memory_bank.identifier] = index + async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: + banks = await self.list_memory_banks() + for bank in banks: + if bank.identifier == identifier: + return bank + return None + + async def list_memory_banks(self) -> List[MemoryBankDef]: + return [i.bank for i in self.cache.values()] + async def insert_documents( self, bank_id: str, diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index f52de0df1..156fde2dd 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -28,7 +28,7 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test # ```bash # PROVIDER_ID= \ # PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/memory/test_inference.py \ +# pytest -s llama_stack/providers/tests/inference/test_inference.py \ # --tb=short --disable-warnings # ``` @@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str): scope="session", params=[ {"model": Llama_8B}, - {"model": Llama_3B}, + # {"model": Llama_3B}, ], ids=lambda d: d["model"], ) @@ -64,16 +64,11 @@ async def inference_settings(request): model = request.param["model"] impls = await resolve_impls_for_test( Api.inference, - models=[ - ModelDef( - identifier=model, - llama_model=model, - ) - ], ) return { "impl": impls[Api.inference], + "models_impl": impls[Api.models], "common_params": { "model": model, "tool_choice": ToolChoice.auto, @@ -108,6 +103,25 @@ def sample_tool_definition(): ) +@pytest.mark.asyncio +async def test_model_list(inference_settings): + params = inference_settings["common_params"] + models_impl = inference_settings["models_impl"] + response = await models_impl.list_models() + assert isinstance(response, list) + assert len(response) >= 1 + assert all(isinstance(model, ModelDefWithProvider) for model in response) + + model_def = None + for model in response: + if model.identifier == params["model"]: + model_def = model + break + + assert model_def is not None + assert model_def.identifier == params["model"] + + @pytest.mark.asyncio async def test_chat_completion_non_streaming(inference_settings, sample_messages): inference_impl = inference_settings["impl"] diff --git a/llama_stack/providers/tests/memory/test_memory.py 
b/llama_stack/providers/tests/memory/test_memory.py index 70f8ba4aa..2566199ae 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -3,6 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import os import pytest import pytest_asyncio @@ -30,12 +31,14 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test @pytest_asyncio.fixture(scope="session") -async def memory_impl(): +async def memory_settings(): impls = await resolve_impls_for_test( Api.memory, - memory_banks=[], ) - return impls[Api.memory] + return { + "memory_impl": impls[Api.memory], + "memory_banks_impl": impls[Api.memory_banks], + } @pytest.fixture @@ -64,23 +67,35 @@ def sample_documents(): ] -async def register_memory_bank(memory_impl: Memory): +async def register_memory_bank(banks_impl: MemoryBanks): bank = VectorMemoryBankDef( identifier="test_bank", embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, overlap_size_in_tokens=64, + provider_id=os.environ["PROVIDER_ID"], ) - await memory_impl.register_memory_bank(bank) + await banks_impl.register_memory_bank(bank) @pytest.mark.asyncio -async def test_query_documents(memory_impl, sample_documents): +async def test_banks_list(memory_settings): + banks_impl = memory_settings["memory_banks_impl"] + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert len(response) == 0 + + +@pytest.mark.asyncio +async def test_query_documents(memory_settings, sample_documents): + memory_impl = memory_settings["memory_impl"] + banks_impl = memory_settings["memory_banks_impl"] + with pytest.raises(ValueError): await memory_impl.insert_documents("test_bank", sample_documents) - await register_memory_bank(memory_impl) + await register_memory_bank(banks_impl) await memory_impl.insert_documents("test_bank", sample_documents) query1 = "programming language" diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 266f252e4..c9ae2bd81 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -18,9 +18,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing async def resolve_impls_for_test( api: Api, - models: List[ModelDef] = None, - memory_banks: List[MemoryBankDef] = None, - shields: List[ShieldDef] = None, ): if "PROVIDER_CONFIG" not in os.environ: raise ValueError( @@ -47,45 +44,11 @@ async def resolve_impls_for_test( provider_id = provider["provider_id"] print(f"No provider ID specified, picking first `{provider_id}`") - models = models or [] - shields = shields or [] - memory_banks = memory_banks or [] - - models = [ - ModelDef( - **{ - **m.dict(), - "provider_id": provider_id, - } - ) - for m in models - ] - shields = [ - ShieldDef( - **{ - **s.dict(), - "provider_id": provider_id, - } - ) - for s in shields - ] - memory_banks = [ - MemoryBankDef( - **{ - **m.dict(), - "provider_id": provider_id, - } - ) - for m in memory_banks - ] run_config = dict( built_at=datetime.now(), image_name="test-fixture", apis=[api], providers={api.value: [Provider(**provider)]}, - models=models, - memory_banks=memory_banks, - shields=shields, ) run_config = parse_and_maybe_upgrade_config(run_config) impls = await resolve_impls_with_routing(run_config) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 
744a89084..e48fcad42 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict +from typing import Dict, List, Optional from llama_models.sku_list import resolve_model -from llama_stack.apis.models import * # noqa: F403 +from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate -class ModelRegistryHelper: +class ModelRegistryHelper(ModelsProtocolPrivate): def __init__(self, stack_to_provider_models_map: Dict[str, str]): self.stack_to_provider_models_map = stack_to_provider_models_map @@ -33,3 +33,15 @@ class ModelRegistryHelper: raise ValueError( f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}" ) + + async def list_models(self) -> List[ModelDef]: + models = [] + for llama_model, provider_model in self.stack_to_provider_models_map.items(): + models.append(ModelDef(identifier=llama_model, llama_model=llama_model)) + return models + + async def get_model(self, identifier: str) -> Optional[ModelDef]: + if identifier not in self.stack_to_provider_models_map: + return None + + return ModelDef(identifier=identifier, llama_model=identifier)
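
Note: taken together, the provider-side list_models()/list_shields()/list_memory_banks() hooks mean run.yaml no longer carries models, shields, or memory_banks sections at all; the stack's registry is assembled from the providers at startup. A sketch of the net effect, in the spirit of the updated tests (all names come from this patch; run_config is assumed to be an already-parsed StackRunConfig):

    async def smoke_test(run_config: StackRunConfig) -> None:
        impls = await resolve_impls_with_routing(run_config)

        models = await impls[Api.models].list_models()
        assert all(isinstance(m, ModelDefWithProvider) for m in models)

        # provider_id is attached by the routing table at startup,
        # not hand-written in run.yaml anymore
        for m in models:
            print(m.identifier, "->", m.provider_id)
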