diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 879738c00..d3b807d4a 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -117,18 +117,18 @@ def configure_api_providers( if api_str == "safety": # TODO: add support for other safety providers, and simplify safety provider config if p == "meta-reference": - for shield_type in MetaReferenceShieldType: - routing_entries.append( - RoutableProviderConfig( - routing_key=shield_type.value, - provider_id=p, - config=cfg.dict(), - ) + routing_entries.append( + RoutableProviderConfig( + routing_key=[s.value for s in MetaReferenceShieldType], + provider_id=p, + config=cfg.dict(), ) + ) else: cprint( - f"[WARN] Interactive configuration of safety provider {p} is not supported, please manually configure safety shields types in routing_table of run.yaml", + f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.", "yellow", + attrs=["bold"], ) routing_entries.append( RoutableProviderConfig( diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 619b5b078..fa88ad5cf 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -5,228 +5,16 @@ # the root directory of this source tree. from datetime import datetime -from enum import Enum -from typing import Any, Dict, List, Optional, Protocol, Union -from llama_models.schema_utils import json_schema_type +from typing import Dict, List, Optional, Union from pydantic import BaseModel, Field - -@json_schema_type -class Api(Enum): - inference = "inference" - safety = "safety" - agents = "agents" - memory = "memory" - - telemetry = "telemetry" - - models = "models" - shields = "shields" - memory_banks = "memory_banks" +from llama_stack.providers.datatypes import * # noqa: F403 -@json_schema_type -class ApiEndpoint(BaseModel): - route: str - method: str - name: str - - -@json_schema_type -class ProviderSpec(BaseModel): - api: Api - provider_id: str - config_class: str = Field( - ..., - description="Fully-qualified classname of the config for this provider", - ) - api_dependencies: List[Api] = Field( - default_factory=list, - description="Higher-level API surfaces may depend on other providers to provide their functionality", - ) - - -class RoutingTable(Protocol): - def get_routing_keys(self) -> List[str]: ... - - def get_provider_impl(self, routing_key: str) -> Any: ... - - -class GenericProviderConfig(BaseModel): - provider_id: str - config: Dict[str, Any] - - -class PlaceholderProviderConfig(BaseModel): - """Placeholder provider config for API whose provider are defined in routing_table""" - - providers: List[str] - - -class RoutableProviderConfig(GenericProviderConfig): - routing_key: str - - -# Example: /inference, /safety -@json_schema_type -class AutoRoutedProviderSpec(ProviderSpec): - provider_id: str = "router" - config_class: str = "" - - docker_image: Optional[str] = None - routing_table_api: Api - module: str = Field( - ..., - description=""" - Fully-qualified name of the module to import. 
The module is expected to have: - - - `get_router_impl(config, provider_specs, deps)`: returns the router implementation - """, - ) - provider_data_validator: Optional[str] = Field( - default=None, - ) - - @property - def pip_packages(self) -> List[str]: - raise AssertionError("Should not be called on AutoRoutedProviderSpec") - - -# Example: /models, /shields -@json_schema_type -class RoutingTableProviderSpec(ProviderSpec): - provider_id: str = "routing_table" - config_class: str = "" - docker_image: Optional[str] = None - - inner_specs: List[ProviderSpec] - module: str = Field( - ..., - description=""" - Fully-qualified name of the module to import. The module is expected to have: - - - `get_router_impl(config, provider_specs, deps)`: returns the router implementation - """, - ) - pip_packages: List[str] = Field(default_factory=list) - - -@json_schema_type -class AdapterSpec(BaseModel): - adapter_id: str = Field( - ..., - description="Unique identifier for this adapter", - ) - module: str = Field( - ..., - description=""" -Fully-qualified name of the module to import. The module is expected to have: - - - `get_adapter_impl(config, deps)`: returns the adapter implementation -""", - ) - pip_packages: List[str] = Field( - default_factory=list, - description="The pip dependencies needed for this implementation", - ) - config_class: Optional[str] = Field( - default=None, - description="Fully-qualified classname of the config for this provider", - ) - provider_data_validator: Optional[str] = Field( - default=None, - ) - - -@json_schema_type -class InlineProviderSpec(ProviderSpec): - pip_packages: List[str] = Field( - default_factory=list, - description="The pip dependencies needed for this implementation", - ) - docker_image: Optional[str] = Field( - default=None, - description=""" -The docker image to use for this implementation. If one is provided, pip_packages will be ignored. -If a provider depends on other providers, the dependencies MUST NOT specify a docker image. -""", - ) - module: str = Field( - ..., - description=""" -Fully-qualified name of the module to import. The module is expected to have: - - - `get_provider_impl(config, deps)`: returns the local implementation -""", - ) - provider_data_validator: Optional[str] = Field( - default=None, - ) - - -class RemoteProviderConfig(BaseModel): - host: str = "localhost" - port: int - - @property - def url(self) -> str: - return f"http://{self.host}:{self.port}" - - -def remote_provider_id(adapter_id: str) -> str: - return f"remote::{adapter_id}" - - -@json_schema_type -class RemoteProviderSpec(ProviderSpec): - adapter: Optional[AdapterSpec] = Field( - default=None, - description=""" -If some code is needed to convert the remote responses into Llama Stack compatible -API responses, specify the adapter here. 
If not specified, it indicates the remote -as being "Llama Stack compatible" -""", - ) - - @property - def docker_image(self) -> Optional[str]: - return None - - @property - def module(self) -> str: - if self.adapter: - return self.adapter.module - return f"llama_stack.apis.{self.api.value}.client" - - @property - def pip_packages(self) -> List[str]: - if self.adapter: - return self.adapter.pip_packages - return [] - - @property - def provider_data_validator(self) -> Optional[str]: - if self.adapter: - return self.adapter.provider_data_validator - return None - - -# Can avoid this by using Pydantic computed_field -def remote_provider_spec( - api: Api, adapter: Optional[AdapterSpec] = None -) -> RemoteProviderSpec: - config_class = ( - adapter.config_class - if adapter and adapter.config_class - else "llama_stack.distribution.datatypes.RemoteProviderConfig" - ) - provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote" - - return RemoteProviderSpec( - api=api, provider_id=provider_id, config_class=config_class, adapter=adapter - ) +LLAMA_STACK_BUILD_CONFIG_VERSION = "v1" +LLAMA_STACK_RUN_CONFIG_VERSION = "v1" @json_schema_type @@ -247,6 +35,7 @@ in the runtime configuration to help route to the correct provider.""", @json_schema_type class StackRunConfig(BaseModel): + version: str = LLAMA_STACK_RUN_CONFIG_VERSION built_at: datetime image_name: str = Field( @@ -295,6 +84,7 @@ Provider configurations for each of the APIs provided by this package. @json_schema_type class BuildConfig(BaseModel): + version: str = LLAMA_STACK_BUILD_CONFIG_VERSION name: str distribution_spec: DistributionSpec = Field( description="The distribution spec to build including API providers. " diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 5ed04a13a..990fa66d5 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -23,7 +23,7 @@ class NeedsRequestProviderData: if not validator_class: raise ValueError(f"Provider {provider_id} does not have a validator") - val = _THREAD_LOCAL.provider_data_header_value + val = getattr(_THREAD_LOCAL, "provider_data_header_value", None) if not val: return None diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py new file mode 100644 index 000000000..f7d51c64a --- /dev/null +++ b/llama_stack/distribution/resolver.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any, Dict, List, Set + +from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.distribution.distribution import ( + api_providers, + builtin_automatically_routed_apis, +) +from llama_stack.distribution.utils.dynamic import instantiate_provider + + +async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]: + """ + Does two things: + - flatmaps, sorts and resolves the providers in dependency order + - for each API, produces either a (local, passthrough or router) implementation + """ + all_providers = api_providers() + specs = {} + configs = {} + + for api_str, config in run_config.api_providers.items(): + api = Api(api_str) + + # TODO: check that these APIs are not in the routing table part of the config + providers = all_providers[api] + + # skip checks for API whose provider config is specified in routing_table + if isinstance(config, PlaceholderProviderConfig): + continue + + if config.provider_id not in providers: + raise ValueError( + f"Unknown provider `{config.provider_id}` is not available for API `{api}`" + ) + specs[api] = providers[config.provider_id] + configs[api] = config + + apis_to_serve = run_config.apis_to_serve or set( + list(specs.keys()) + list(run_config.routing_table.keys()) + ) + for info in builtin_automatically_routed_apis(): + source_api = info.routing_table_api + + assert ( + source_api not in specs + ), f"Routing table API {source_api} specified in wrong place?" + assert ( + info.router_api not in specs + ), f"Auto-routed API {info.router_api} specified in wrong place?" + + if info.router_api.value not in apis_to_serve: + continue + + print("router_api", info.router_api) + if info.router_api.value not in run_config.routing_table: + raise ValueError(f"Routing table for `{source_api.value}` is not provided?") + + routing_table = run_config.routing_table[info.router_api.value] + + providers = all_providers[info.router_api] + + inner_specs = [] + inner_deps = [] + for rt_entry in routing_table: + if rt_entry.provider_id not in providers: + raise ValueError( + f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" + ) + inner_specs.append(providers[rt_entry.provider_id]) + inner_deps.extend(providers[rt_entry.provider_id].api_dependencies) + + specs[source_api] = RoutingTableProviderSpec( + api=source_api, + module="llama_stack.distribution.routers", + api_dependencies=inner_deps, + inner_specs=inner_specs, + ) + configs[source_api] = routing_table + + specs[info.router_api] = AutoRoutedProviderSpec( + api=info.router_api, + module="llama_stack.distribution.routers", + routing_table_api=source_api, + api_dependencies=[source_api], + ) + configs[info.router_api] = {} + + sorted_specs = topological_sort(specs.values()) + print(f"Resolved {len(sorted_specs)} providers in topological order") + for spec in sorted_specs: + print(f" {spec.api}: {spec.provider_id}") + print("") + impls = {} + for spec in sorted_specs: + api = spec.api + deps = {api: impls[api] for api in spec.api_dependencies} + impl = await instantiate_provider(spec, deps, configs[api]) + + impls[api] = impl + + return impls, specs + + +def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]: + by_id = {x.api: x for x in providers} + + def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]): + visited.add(a.api) + + for api in a.api_dependencies: + if api not in visited: + dfs(by_id[api], visited, stack) + + stack.append(a.api) + + visited = set() + stack = [] + + for a in providers: + if 
a.api not in visited: + dfs(a, visited, stack) + + return [by_id[x] for x in stack] diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 89db71fa7..02dc942e8 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -19,18 +19,35 @@ from llama_stack.distribution.datatypes import * # noqa: F403 class CommonRoutingTableImpl(RoutingTable): def __init__( self, - inner_impls: List[Tuple[str, Any]], + inner_impls: List[Tuple[RoutingKey, Any]], routing_table_config: Dict[str, List[RoutableProviderConfig]], ) -> None: - self.providers = {k: v for k, v in inner_impls} - self.routing_keys = list(self.providers.keys()) + self.unique_providers = [] + self.providers = {} + self.routing_keys = [] + + for key, impl in inner_impls: + keys = key if isinstance(key, list) else [key] + self.unique_providers.append((keys, impl)) + + for k in keys: + if k in self.providers: + raise ValueError(f"Duplicate routing key {k}") + self.providers[k] = impl + self.routing_keys.append(k) + self.routing_table_config = routing_table_config async def initialize(self) -> None: - pass + for keys, p in self.unique_providers: + spec = p.__provider_spec__ + if isinstance(spec, RemoteProviderSpec) and spec.adapter is None: + continue + + await p.validate_routing_keys(keys) async def shutdown(self) -> None: - for p in self.providers.values(): + for _, p in self.unique_providers: await p.shutdown() def get_provider_impl(self, routing_key: str) -> Any: diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 28301264c..16b1fb619 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -17,16 +17,7 @@ from collections.abc import ( from contextlib import asynccontextmanager from http import HTTPStatus from ssl import SSLError -from typing import ( - Any, - AsyncGenerator, - AsyncIterator, - Dict, - get_type_hints, - List, - Optional, - Set, -) +from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional import fire import httpx @@ -48,13 +39,9 @@ from llama_stack.providers.utils.telemetry.tracing import ( ) from llama_stack.distribution.datatypes import * # noqa: F403 -from llama_stack.distribution.distribution import ( - api_endpoints, - api_providers, - builtin_automatically_routed_apis, -) +from llama_stack.distribution.distribution import api_endpoints from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.utils.dynamic import instantiate_provider +from llama_stack.distribution.resolver import resolve_impls_with_routing def is_async_iterator_type(typ): @@ -289,125 +276,6 @@ def create_dynamic_typed_route(func: Any, method: str): return endpoint -def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]: - by_id = {x.api: x for x in providers} - - def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]): - visited.add(a.api) - - for api in a.api_dependencies: - if api not in visited: - dfs(by_id[api], visited, stack) - - stack.append(a.api) - - visited = set() - stack = [] - - for a in providers: - if a.api not in visited: - dfs(a, visited, stack) - - return [by_id[x] for x in stack] - - -def snake_to_camel(snake_str): - return "".join(word.capitalize() for word in snake_str.split("_")) - - -async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]: - """ - Does two things: - - 
flatmaps, sorts and resolves the providers in dependency order - - for each API, produces either a (local, passthrough or router) implementation - """ - all_providers = api_providers() - specs = {} - configs = {} - - for api_str, config in run_config.api_providers.items(): - api = Api(api_str) - - # TODO: check that these APIs are not in the routing table part of the config - providers = all_providers[api] - - # skip checks for API whose provider config is specified in routing_table - if isinstance(config, PlaceholderProviderConfig): - continue - - if config.provider_id not in providers: - raise ValueError( - f"Unknown provider `{config.provider_id}` is not available for API `{api}`" - ) - specs[api] = providers[config.provider_id] - configs[api] = config - - apis_to_serve = run_config.apis_to_serve or set( - list(specs.keys()) + list(run_config.routing_table.keys()) - ) - for info in builtin_automatically_routed_apis(): - source_api = info.routing_table_api - - assert ( - source_api not in specs - ), f"Routing table API {source_api} specified in wrong place?" - assert ( - info.router_api not in specs - ), f"Auto-routed API {info.router_api} specified in wrong place?" - - if info.router_api.value not in apis_to_serve: - continue - - print("router_api", info.router_api) - if info.router_api.value not in run_config.routing_table: - raise ValueError(f"Routing table for `{source_api.value}` is not provided?") - - routing_table = run_config.routing_table[info.router_api.value] - - providers = all_providers[info.router_api] - - inner_specs = [] - inner_deps = [] - for rt_entry in routing_table: - if rt_entry.provider_id not in providers: - raise ValueError( - f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" - ) - inner_specs.append(providers[rt_entry.provider_id]) - inner_deps.extend(providers[rt_entry.provider_id].api_dependencies) - - specs[source_api] = RoutingTableProviderSpec( - api=source_api, - module="llama_stack.distribution.routers", - api_dependencies=inner_deps, - inner_specs=inner_specs, - ) - configs[source_api] = routing_table - - specs[info.router_api] = AutoRoutedProviderSpec( - api=info.router_api, - module="llama_stack.distribution.routers", - routing_table_api=source_api, - api_dependencies=[source_api], - ) - configs[info.router_api] = {} - - sorted_specs = topological_sort(specs.values()) - print(f"Resolved {len(sorted_specs)} providers in topological order") - for spec in sorted_specs: - print(f" {spec.api}: {spec.provider_id}") - print("") - impls = {} - for spec in sorted_specs: - api = spec.api - deps = {api: impls[api] for api in spec.api_dependencies} - impl = await instantiate_provider(spec, deps, configs[api]) - - impls[api] = impl - - return impls, specs - - def main( yaml_config: str = "llamastack-run.yaml", port: int = 5000, diff --git a/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml b/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml index 6a4b2e464..0a845582c 100644 --- a/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml +++ b/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml @@ -42,22 +42,7 @@ routing_table: config: llama_guard_shield: null prompt_guard_shield: null - routing_key: llama_guard - - provider_id: meta-reference - config: - llama_guard_shield: null - prompt_guard_shield: null - routing_key: code_scanner_guard - - provider_id: meta-reference - config: - llama_guard_shield: null - prompt_guard_shield: null - routing_key: 
injection_shield - - provider_id: meta-reference - config: - llama_guard_shield: null - prompt_guard_shield: null - routing_key: jailbreak_shield + routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"] memory: - provider_id: meta-reference config: {} diff --git a/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml b/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml index 2969479dc..66f6cfcef 100644 --- a/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml +++ b/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml @@ -45,22 +45,7 @@ routing_table: config: llama_guard_shield: null prompt_guard_shield: null - routing_key: llama_guard - - provider_id: meta-reference - config: - llama_guard_shield: null - prompt_guard_shield: null - routing_key: code_scanner_guard - - provider_id: meta-reference - config: - llama_guard_shield: null - prompt_guard_shield: null - routing_key: injection_shield - - provider_id: meta-reference - config: - llama_guard_shield: null - prompt_guard_shield: null - routing_key: jailbreak_shield + routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"] memory: - provider_id: meta-reference config: {} diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index cf4891f20..9c1db4bdb 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -12,20 +12,21 @@ from botocore.config import Config from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import resolve_model + +from llama_stack.providers.utils.inference.routable import RoutableProviderForModels from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig -# mapping of Model SKUs to ollama models + BEDROCK_SUPPORTED_MODELS = { - "Meta-Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0", - "Meta-Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0", - "Meta-Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0", + "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0", + "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0", + "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0", } -class BedrockInferenceAdapter(Inference): +class BedrockInferenceAdapter(Inference, RoutableProviderForModels): @staticmethod def _create_bedrock_client(config: BedrockConfig) -> BaseClient: @@ -68,6 +69,9 @@ class BedrockInferenceAdapter(Inference): return boto3_session.client("bedrock-runtime", config=boto3_config) def __init__(self, config: BedrockConfig) -> None: + RoutableProviderForModels.__init__( + self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS + ) self._config = config self._client = BedrockInferenceAdapter._create_bedrock_client(config) @@ -94,22 +98,6 @@ class BedrockInferenceAdapter(Inference): ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: raise NotImplementedError() - @staticmethod - def resolve_bedrock_model(model_name: str) -> str: - model = resolve_model(model_name) - assert ( - model is not None - and model.descriptor(shorten_default_variant=True) - in BEDROCK_SUPPORTED_MODELS - ), ( - f"Unsupported model: {model_name}, use one of the supported models: " - 
f"{','.join(BEDROCK_SUPPORTED_MODELS.keys())}" - ) - - return BEDROCK_SUPPORTED_MODELS.get( - model.descriptor(shorten_default_variant=True) - ) - @staticmethod def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason: if bedrock_stop_reason == "max_tokens": @@ -350,7 +338,7 @@ class BedrockInferenceAdapter(Inference): ) -> ( AsyncGenerator ): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: - bedrock_model = BedrockInferenceAdapter.resolve_bedrock_model(model) + bedrock_model = self.map_to_provider_model(model) inference_config = BedrockInferenceAdapter.get_bedrock_inference_config( sampling_params ) diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py index 47e1449f2..f6949cbdc 100644 --- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py +++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py @@ -12,7 +12,8 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message, StopReason from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import resolve_model + +from llama_stack.providers.utils.inference.routable import RoutableProviderForModels from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.utils.inference.augment_messages import ( @@ -21,6 +22,7 @@ from llama_stack.providers.utils.inference.augment_messages import ( from .config import FireworksImplConfig + FIREWORKS_SUPPORTED_MODELS = { "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct", "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct", @@ -28,8 +30,11 @@ FIREWORKS_SUPPORTED_MODELS = { } -class FireworksInferenceAdapter(Inference): +class FireworksInferenceAdapter(Inference, RoutableProviderForModels): def __init__(self, config: FireworksImplConfig) -> None: + RoutableProviderForModels.__init__( + self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS + ) self.config = config tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(tokenizer) @@ -65,18 +70,6 @@ class FireworksInferenceAdapter(Inference): return fireworks_messages - def resolve_fireworks_model(self, model_name: str) -> str: - model = resolve_model(model_name) - assert ( - model is not None - and model.descriptor(shorten_default_variant=True) - in FIREWORKS_SUPPORTED_MODELS - ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(FIREWORKS_SUPPORTED_MODELS.keys())}" - - return FIREWORKS_SUPPORTED_MODELS.get( - model.descriptor(shorten_default_variant=True) - ) - def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict: options = {} if request.sampling_params is not None: @@ -112,7 +105,7 @@ class FireworksInferenceAdapter(Inference): # accumulate sampling params and other options to pass to fireworks options = self.get_fireworks_chat_options(request) - fireworks_model = self.resolve_fireworks_model(request.model) + fireworks_model = self.map_to_provider_model(request.model) if not request.stream: r = await self.client.chat.completions.acreate( diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index c67bb8ce1..c4d48af81 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -11,7 +11,6 @@ import httpx from llama_models.llama3.api.chat_format import ChatFormat 
from llama_models.llama3.api.datatypes import Message, StopReason from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import resolve_model from ollama import AsyncClient @@ -19,6 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.utils.inference.augment_messages import ( augment_messages_for_tools, ) +from llama_stack.providers.utils.inference.routable import RoutableProviderForModels # TODO: Eventually this will move to the llama cli model list command # mapping of Model SKUs to ollama models @@ -29,8 +29,11 @@ OLLAMA_SUPPORTED_SKUS = { } -class OllamaInferenceAdapter(Inference): +class OllamaInferenceAdapter(Inference, RoutableProviderForModels): def __init__(self, url: str) -> None: + RoutableProviderForModels.__init__( + self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS + ) self.url = url tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(tokenizer) @@ -72,15 +75,6 @@ class OllamaInferenceAdapter(Inference): return ollama_messages - def resolve_ollama_model(self, model_name: str) -> str: - model = resolve_model(model_name) - assert ( - model is not None - and model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS - ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(OLLAMA_SUPPORTED_SKUS.keys())}" - - return OLLAMA_SUPPORTED_SKUS.get(model.descriptor(shorten_default_variant=True)) - def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict: options = {} if request.sampling_params is not None: @@ -120,7 +114,7 @@ class OllamaInferenceAdapter(Inference): messages = augment_messages_for_tools(request) # accumulate sampling params and other options to pass to ollama options = self.get_ollama_chat_options(request) - ollama_model = self.resolve_ollama_model(request.model) + ollama_model = self.map_to_provider_model(request.model) res = await self.client.ps() need_model_pull = True diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py index 66f57442f..a5e5a99be 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -13,6 +13,8 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import StopReason from llama_models.llama3.api.tokenizer import Tokenizer +from llama_stack.distribution.datatypes import RoutableProvider + from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.utils.inference.augment_messages import ( augment_messages_for_tools, @@ -23,7 +25,7 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl logger = logging.getLogger(__name__) -class _HfAdapter(Inference): +class _HfAdapter(Inference, RoutableProvider): client: AsyncInferenceClient max_tokens: int model_id: str @@ -32,6 +34,11 @@ class _HfAdapter(Inference): self.tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(self.tokenizer) + async def validate_routing_keys(self, routing_keys: list[str]) -> None: + # these are the model names the Llama Stack will use to route requests to this provider + # perform validation here if necessary + pass + async def shutdown(self) -> None: pass diff --git a/llama_stack/providers/adapters/inference/together/config.py b/llama_stack/providers/adapters/inference/together/config.py index 03ee047d2..e928a771d 100644 --- a/llama_stack/providers/adapters/inference/together/config.py +++ 
b/llama_stack/providers/adapters/inference/together/config.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Optional + from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field @@ -14,7 +16,7 @@ class TogetherImplConfig(BaseModel): default="https://api.together.xyz/v1", description="The URL for the Together AI server", ) - api_key: str = Field( - default="", + api_key: Optional[str] = Field( + default=None, description="The Together AI API Key", ) diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 7053834bd..9f73a81d1 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -10,7 +10,6 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message, StopReason from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import resolve_model from together import Together @@ -19,9 +18,11 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.augment_messages import ( augment_messages_for_tools, ) +from llama_stack.providers.utils.inference.routable import RoutableProviderForModels from .config import TogetherImplConfig + TOGETHER_SUPPORTED_MODELS = { "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", @@ -32,8 +33,13 @@ TOGETHER_SUPPORTED_MODELS = { } -class TogetherInferenceAdapter(Inference, NeedsRequestProviderData): +class TogetherInferenceAdapter( + Inference, NeedsRequestProviderData, RoutableProviderForModels +): def __init__(self, config: TogetherImplConfig) -> None: + RoutableProviderForModels.__init__( + self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS + ) self.config = config tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(tokenizer) @@ -69,18 +75,6 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData): return together_messages - def resolve_together_model(self, model_name: str) -> str: - model = resolve_model(model_name) - assert ( - model is not None - and model.descriptor(shorten_default_variant=True) - in TOGETHER_SUPPORTED_MODELS - ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(TOGETHER_SUPPORTED_MODELS.keys())}" - - return TOGETHER_SUPPORTED_MODELS.get( - model.descriptor(shorten_default_variant=True) - ) - def get_together_chat_options(self, request: ChatCompletionRequest) -> dict: options = {} if request.sampling_params is not None: @@ -103,12 +97,15 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData): ) -> AsyncGenerator: together_api_key = None - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.together_api_key: - raise ValueError( - 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' - ) - together_api_key = provider_data.together_api_key + if self.config.api_key is not None: + together_api_key = self.config.api_key + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.together_api_key: + raise ValueError( + 'Pass Together API Key in the header X-LlamaStack-ProviderData as { 
"together_api_key": }' + ) + together_api_key = provider_data.together_api_key client = Together(api_key=together_api_key) # wrapper request to make it easier to pass around (internal only, not exposed to API) @@ -125,7 +122,7 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData): # accumulate sampling params and other options to pass to together options = self.get_together_chat_options(request) - together_model = self.resolve_together_model(request.model) + together_model = self.map_to_provider_model(request.model) messages = augment_messages_for_tools(request) if not request.stream: @@ -171,17 +168,10 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData): stream=True, **options, ): - if chunk.choices[0].finish_reason: - if ( - stop_reason is None and chunk.choices[0].finish_reason == "stop" - ) or ( - stop_reason is None and chunk.choices[0].finish_reason == "eos" - ): + if finish_reason := chunk.choices[0].finish_reason: + if stop_reason is None and finish_reason in ["stop", "eos"]: stop_reason = StopReason.end_of_turn - elif ( - stop_reason is None - and chunk.choices[0].finish_reason == "length" - ): + elif stop_reason is None and finish_reason == "length": stop_reason = StopReason.out_of_tokens break diff --git a/llama_stack/providers/adapters/memory/chroma/chroma.py b/llama_stack/providers/adapters/memory/chroma/chroma.py index 0a5f5bcd6..afa13111f 100644 --- a/llama_stack/providers/adapters/memory/chroma/chroma.py +++ b/llama_stack/providers/adapters/memory/chroma/chroma.py @@ -13,7 +13,7 @@ import chromadb from numpy.typing import NDArray from llama_stack.apis.memory import * # noqa: F403 - +from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, @@ -65,7 +65,7 @@ class ChromaIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) -class ChromaMemoryAdapter(Memory): +class ChromaMemoryAdapter(Memory, RoutableProvider): def __init__(self, url: str) -> None: print(f"Initializing ChromaMemoryAdapter with url: {url}") url = url.rstrip("/") @@ -93,6 +93,10 @@ class ChromaMemoryAdapter(Memory): async def shutdown(self) -> None: pass + async def validate_routing_keys(self, routing_keys: List[str]) -> None: + print(f"[chroma] Registering memory bank routing keys: {routing_keys}") + pass + async def create_memory_bank( self, name: str, diff --git a/llama_stack/providers/adapters/memory/pgvector/pgvector.py b/llama_stack/providers/adapters/memory/pgvector/pgvector.py index 9cf0771ab..5864aa7dc 100644 --- a/llama_stack/providers/adapters/memory/pgvector/pgvector.py +++ b/llama_stack/providers/adapters/memory/pgvector/pgvector.py @@ -5,16 +5,17 @@ # the root directory of this source tree. 
import uuid - from typing import List, Tuple import psycopg2 from numpy.typing import NDArray from psycopg2 import sql from psycopg2.extras import execute_values, Json -from pydantic import BaseModel -from llama_stack.apis.memory import * # noqa: F403 +from pydantic import BaseModel + +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.providers.utils.memory.vector_store import ( ALL_MINILM_L6_V2_DIMENSION, @@ -118,7 +119,7 @@ class PGVectorIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) -class PGVectorMemoryAdapter(Memory): +class PGVectorMemoryAdapter(Memory, RoutableProvider): def __init__(self, config: PGVectorConfig) -> None: print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}") self.config = config @@ -160,6 +161,10 @@ class PGVectorMemoryAdapter(Memory): async def shutdown(self) -> None: pass + async def validate_routing_keys(self, routing_keys: List[str]) -> None: + print(f"[pgvector] Registering memory bank routing keys: {routing_keys}") + pass + async def create_memory_bank( self, name: str, diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py index a3acda1ce..814704e2c 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -4,47 +4,58 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json +import logging import traceback from typing import Any, Dict, List -from .config import BedrockSafetyConfig +import boto3 + from llama_stack.apis.safety import * # noqa from llama_models.llama3.api.datatypes import * # noqa: F403 -import json -import logging +from llama_stack.distribution.datatypes import RoutableProvider -import boto3 +from .config import BedrockSafetyConfig logger = logging.getLogger(__name__) -class BedrockSafetyAdapter(Safety): +SUPPORTED_SHIELD_TYPES = [ + "bedrock_guardrail", +] + + +class BedrockSafetyAdapter(Safety, RoutableProvider): def __init__(self, config: BedrockSafetyConfig) -> None: + if not config.aws_profile: + raise ValueError(f"Missing boto_client aws_profile in model info::{config}") self.config = config async def initialize(self) -> None: - if not self.config.aws_profile: - raise RuntimeError( - f"Missing boto_client aws_profile in model info::{self.config}" - ) - try: - print(f"initializing with profile --- > {self.config}::") - self.boto_client_profile = self.config.aws_profile + print(f"initializing with profile --- > {self.config}") self.boto_client = boto3.Session( - profile_name=self.boto_client_profile + profile_name=self.config.aws_profile ).client("bedrock-runtime") except Exception as e: - raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e + raise RuntimeError("Error initializing BedrockSafetyAdapter") from e async def shutdown(self) -> None: pass + async def validate_routing_keys(self, routing_keys: List[str]) -> None: + for key in routing_keys: + if key not in SUPPORTED_SHIELD_TYPES: + raise ValueError(f"Unknown safety shield type: {key}") + async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: + if shield_type not in SUPPORTED_SHIELD_TYPES: + raise ValueError(f"Unknown safety shield type: {shield_type}") + """This is the implementation for the bedrock guardrails. 
The input to the guardrails is to be of this format ```content = [ { diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py index 24fcc63b1..c7a667e01 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -3,7 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_models.sku_list import resolve_model from together import Together from llama_models.llama3.api.datatypes import * # noqa: F403 @@ -13,53 +12,52 @@ from llama_stack.apis.safety import ( SafetyViolation, ViolationLevel, ) +from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.distribution.request_headers import NeedsRequestProviderData from .config import TogetherSafetyConfig + SAFETY_SHIELD_TYPES = { + "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B", "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B", "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo", } -def shield_type_to_model_name(shield_type: str) -> str: - if shield_type == "llama_guard": - shield_type = "Llama-Guard-3-8B" - - model = resolve_model(shield_type) - if ( - model is None - or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES - or model.model_family is not ModelFamily.safety - ): - raise ValueError( - f"{shield_type} is not supported, please use of {','.join(SAFETY_SHIELD_TYPES.keys())}" - ) - - return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True)) - - -class TogetherSafetyImpl(Safety, NeedsRequestProviderData): +class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider): def __init__(self, config: TogetherSafetyConfig) -> None: self.config = config async def initialize(self) -> None: pass + async def shutdown(self) -> None: + pass + + async def validate_routing_keys(self, routing_keys: List[str]) -> None: + for key in routing_keys: + if key not in SAFETY_SHIELD_TYPES: + raise ValueError(f"Unknown safety shield type: {key}") + async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: + if shield_type not in SAFETY_SHIELD_TYPES: + raise ValueError(f"Unknown safety shield type: {shield_type}") together_api_key = None - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.together_api_key: - raise ValueError( - 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' - ) - together_api_key = provider_data.together_api_key + if self.config.api_key is not None: + together_api_key = self.config.api_key + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.together_api_key: + raise ValueError( + 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' + ) + together_api_key = provider_data.together_api_key - model_name = shield_type_to_model_name(shield_type) + model_name = SAFETY_SHIELD_TYPES[shield_type] # messages can have role assistant or user api_messages = [] diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py new file mode 100644 index 000000000..a9a3d86e9 --- /dev/null +++ b/llama_stack/providers/datatypes.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Any, Dict, List, Optional, Protocol, Union + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class Api(Enum): + inference = "inference" + safety = "safety" + agents = "agents" + memory = "memory" + + telemetry = "telemetry" + + models = "models" + shields = "shields" + memory_banks = "memory_banks" + + +@json_schema_type +class ApiEndpoint(BaseModel): + route: str + method: str + name: str + + +@json_schema_type +class ProviderSpec(BaseModel): + api: Api + provider_id: str + config_class: str = Field( + ..., + description="Fully-qualified classname of the config for this provider", + ) + api_dependencies: List[Api] = Field( + default_factory=list, + description="Higher-level API surfaces may depend on other providers to provide their functionality", + ) + + +class RoutingTable(Protocol): + def get_routing_keys(self) -> List[str]: ... + + def get_provider_impl(self, routing_key: str) -> Any: ... + + +class RoutableProvider(Protocol): + """ + A provider which sits behind the RoutingTable and can get routed to. + + All Inference / Safety / Memory providers fall into this bucket. + """ + + async def validate_routing_keys(self, keys: List[str]) -> None: ... + + +class GenericProviderConfig(BaseModel): + provider_id: str + config: Dict[str, Any] + + +class PlaceholderProviderConfig(BaseModel): + """Placeholder provider config for API whose provider are defined in routing_table""" + + providers: List[str] + + +RoutingKey = Union[str, List[str]] + + +class RoutableProviderConfig(GenericProviderConfig): + routing_key: RoutingKey + + +# Example: /inference, /safety +@json_schema_type +class AutoRoutedProviderSpec(ProviderSpec): + provider_id: str = "router" + config_class: str = "" + + docker_image: Optional[str] = None + routing_table_api: Api + module: str = Field( + ..., + description=""" + Fully-qualified name of the module to import. The module is expected to have: + + - `get_router_impl(config, provider_specs, deps)`: returns the router implementation + """, + ) + provider_data_validator: Optional[str] = Field( + default=None, + ) + + @property + def pip_packages(self) -> List[str]: + raise AssertionError("Should not be called on AutoRoutedProviderSpec") + + +# Example: /models, /shields +@json_schema_type +class RoutingTableProviderSpec(ProviderSpec): + provider_id: str = "routing_table" + config_class: str = "" + docker_image: Optional[str] = None + + inner_specs: List[ProviderSpec] + module: str = Field( + ..., + description=""" + Fully-qualified name of the module to import. The module is expected to have: + + - `get_router_impl(config, provider_specs, deps)`: returns the router implementation + """, + ) + pip_packages: List[str] = Field(default_factory=list) + + +@json_schema_type +class AdapterSpec(BaseModel): + adapter_id: str = Field( + ..., + description="Unique identifier for this adapter", + ) + module: str = Field( + ..., + description=""" +Fully-qualified name of the module to import. 
The module is expected to have: + + - `get_adapter_impl(config, deps)`: returns the adapter implementation +""", + ) + pip_packages: List[str] = Field( + default_factory=list, + description="The pip dependencies needed for this implementation", + ) + config_class: Optional[str] = Field( + default=None, + description="Fully-qualified classname of the config for this provider", + ) + provider_data_validator: Optional[str] = Field( + default=None, + ) + + +@json_schema_type +class InlineProviderSpec(ProviderSpec): + pip_packages: List[str] = Field( + default_factory=list, + description="The pip dependencies needed for this implementation", + ) + docker_image: Optional[str] = Field( + default=None, + description=""" +The docker image to use for this implementation. If one is provided, pip_packages will be ignored. +If a provider depends on other providers, the dependencies MUST NOT specify a docker image. +""", + ) + module: str = Field( + ..., + description=""" +Fully-qualified name of the module to import. The module is expected to have: + + - `get_provider_impl(config, deps)`: returns the local implementation +""", + ) + provider_data_validator: Optional[str] = Field( + default=None, + ) + + +class RemoteProviderConfig(BaseModel): + host: str = "localhost" + port: int + + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" + + +def remote_provider_id(adapter_id: str) -> str: + return f"remote::{adapter_id}" + + +@json_schema_type +class RemoteProviderSpec(ProviderSpec): + adapter: Optional[AdapterSpec] = Field( + default=None, + description=""" +If some code is needed to convert the remote responses into Llama Stack compatible +API responses, specify the adapter here. If not specified, it indicates the remote +as being "Llama Stack compatible" +""", + ) + + @property + def docker_image(self) -> Optional[str]: + return None + + @property + def module(self) -> str: + if self.adapter: + return self.adapter.module + return f"llama_stack.apis.{self.api.value}.client" + + @property + def pip_packages(self) -> List[str]: + if self.adapter: + return self.adapter.pip_packages + return [] + + @property + def provider_data_validator(self) -> Optional[str]: + if self.adapter: + return self.adapter.provider_data_validator + return None + + +# Can avoid this by using Pydantic computed_field +def remote_provider_spec( + api: Api, adapter: Optional[AdapterSpec] = None +) -> RemoteProviderSpec: + config_class = ( + adapter.config_class + if adapter and adapter.config_class + else "llama_stack.distribution.datatypes.RemoteProviderConfig" + ) + provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote" + + return RemoteProviderSpec( + api=api, provider_id=provider_id, config_class=config_class, adapter=adapter + ) diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py index e9b790dd5..e89d8ec4c 100644 --- a/llama_stack/providers/impls/meta_reference/inference/inference.py +++ b/llama_stack/providers/impls/meta_reference/inference/inference.py @@ -6,21 +6,13 @@ import asyncio -from typing import AsyncIterator, Union +from typing import AsyncIterator, List, Union -from llama_models.llama3.api.datatypes import StopReason from llama_models.sku_list import resolve_model -from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseEvent, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - Inference, - 
ToolCallDelta, - ToolCallParseStatus, -) +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.providers.utils.inference.augment_messages import ( augment_messages_for_tools, ) @@ -28,15 +20,12 @@ from llama_stack.providers.utils.inference.augment_messages import ( from .config import MetaReferenceImplConfig from .model_parallel import LlamaModelParallelGenerator -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 - # there's a single model parallel process running serving the model. for now, # we don't support multiple concurrent requests to this process. SEMAPHORE = asyncio.Semaphore(1) -class MetaReferenceInferenceImpl(Inference): +class MetaReferenceInferenceImpl(Inference, RoutableProvider): def __init__(self, config: MetaReferenceImplConfig) -> None: self.config = config model = resolve_model(config.model) @@ -49,6 +38,12 @@ class MetaReferenceInferenceImpl(Inference): self.generator = LlamaModelParallelGenerator(self.config) self.generator.start() + async def validate_routing_keys(self, routing_keys: List[str]) -> None: + assert ( + len(routing_keys) == 1 + ), f"Only one routing key is supported {routing_keys}" + assert routing_keys[0] == self.config.model + async def shutdown(self) -> None: self.generator.stop() diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index 30b7245e6..b9a00908e 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -14,6 +14,7 @@ import numpy as np from numpy.typing import NDArray from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.apis.memory import * # noqa: F403 from llama_stack.providers.utils.memory.vector_store import ( @@ -62,7 +63,7 @@ class FaissIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) -class FaissMemoryImpl(Memory): +class FaissMemoryImpl(Memory, RoutableProvider): def __init__(self, config: FaissImplConfig) -> None: self.config = config self.cache = {} @@ -71,6 +72,10 @@ class FaissMemoryImpl(Memory): async def shutdown(self) -> None: ... + async def validate_routing_keys(self, routing_keys: List[str]) -> None: + print(f"[faiss] Registering memory bank routing keys: {routing_keys}") + pass + async def create_memory_bank( self, name: str, diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index 6bb851596..f02574f19 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -4,13 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+from typing import Any, Dict, List + from llama_models.sku_list import resolve_model from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.distribution.datatypes import Api +from llama_stack.distribution.datatypes import Api, RoutableProvider from llama_stack.providers.impls.meta_reference.safety.shields.base import ( OnViolationAction, @@ -35,7 +37,7 @@ def resolve_and_get_path(model_name: str) -> str: return model_dir -class MetaReferenceSafetyImpl(Safety): +class MetaReferenceSafetyImpl(Safety, RoutableProvider): def __init__(self, config: SafetyConfig, deps) -> None: self.config = config self.inference_api = deps[Api.inference] @@ -46,6 +48,15 @@ class MetaReferenceSafetyImpl(Safety): model_dir = resolve_and_get_path(shield_cfg.model) _ = PromptGuardShield.instance(model_dir) + async def shutdown(self) -> None: + pass + + async def validate_routing_keys(self, routing_keys: List[str]) -> None: + available_shields = [v.value for v in MetaReferenceShieldType] + for key in routing_keys: + if key not in available_shields: + raise ValueError(f"Unknown safety shield type: {key}") + async def run_shield( self, shield_type: str, diff --git a/llama_stack/providers/utils/inference/routable.py b/llama_stack/providers/utils/inference/routable.py new file mode 100644 index 000000000..a36631208 --- /dev/null +++ b/llama_stack/providers/utils/inference/routable.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Dict, List + +from llama_models.sku_list import resolve_model + +from llama_stack.distribution.datatypes import RoutableProvider + + +class RoutableProviderForModels(RoutableProvider): + + def __init__(self, stack_to_provider_models_map: Dict[str, str]): + self.stack_to_provider_models_map = stack_to_provider_models_map + + async def validate_routing_keys(self, routing_keys: List[str]): + for routing_key in routing_keys: + if routing_key not in self.stack_to_provider_models_map: + raise ValueError( + f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}" + ) + + def map_to_provider_model(self, routing_key: str) -> str: + model = resolve_model(routing_key) + if not model: + raise ValueError(f"Unknown model: `{routing_key}`") + + if routing_key not in self.stack_to_provider_models_map: + raise ValueError( + f"Model {routing_key} not found in map {self.stack_to_provider_models_map}" + ) + + return self.stack_to_provider_models_map[routing_key] diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml index 2ae975cdc..cbe36193c 100644 --- a/tests/examples/local-run.yaml +++ b/tests/examples/local-run.yaml @@ -50,37 +50,7 @@ routing_table: disable_output_check: false prompt_guard_shield: model: Prompt-Guard-86M - routing_key: llama_guard - - provider_id: meta-reference - config: - llama_guard_shield: - model: Llama-Guard-3-8B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M - routing_key: code_scanner_guard - - provider_id: meta-reference - config: - llama_guard_shield: - model: Llama-Guard-3-8B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M - routing_key: injection_shield - - provider_id: meta-reference - config: - llama_guard_shield: - model: Llama-Guard-3-8B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M - routing_key: jailbreak_shield + routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"] memory: - provider_id: meta-reference config: {}
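
Note on the run.yaml changes above: the four per-shield entries collapse into a single entry whose `routing_key` is a list. A minimal, self-contained sketch of how `CommonRoutingTableImpl.__init__` fans such a list out into per-key lookups — names here are simplified stand-ins (`safety_impl` is just a placeholder object), not the real provider classes:

```python
# Sketch of the routing-key fan-out added to CommonRoutingTableImpl: a routing_key
# that is a list maps several keys onto one shared provider instance.
from typing import Any, Dict, List, Tuple, Union

RoutingKey = Union[str, List[str]]


def build_routing_maps(inner_impls: List[Tuple[RoutingKey, Any]]):
    unique_providers: List[Tuple[List[str], Any]] = []
    providers: Dict[str, Any] = {}
    routing_keys: List[str] = []
    for key, impl in inner_impls:
        # normalize: a bare string becomes a one-element list
        keys = key if isinstance(key, list) else [key]
        unique_providers.append((keys, impl))
        for k in keys:
            if k in providers:
                raise ValueError(f"Duplicate routing key {k}")
            providers[k] = impl
            routing_keys.append(k)
    return unique_providers, providers, routing_keys


safety_impl = object()  # stand-in for a meta-reference safety provider instance
_, providers, keys = build_routing_maps(
    [(["llama_guard", "injection_shield", "jailbreak_shield"], safety_impl)]
)
assert providers["llama_guard"] is providers["jailbreak_shield"]
print(keys)  # ['llama_guard', 'injection_shield', 'jailbreak_shield']
```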
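
The inference adapters (Bedrock, Fireworks, Ollama, Together) now share `RoutableProviderForModels` from `llama_stack/providers/utils/inference/routable.py`. A rough, dependency-free sketch of that pattern — the class and the model map below are illustrative stand-ins rather than the real adapters, and the real mixin additionally calls `resolve_model` before mapping:

```python
# Illustrative sketch of the RoutableProviderForModels mixin: validate routing keys
# at startup, then translate a Llama Stack model name to the provider's model id.
import asyncio
from typing import Dict, List


class ToyRoutableProviderForModels:
    def __init__(self, stack_to_provider_models_map: Dict[str, str]) -> None:
        self.stack_to_provider_models_map = stack_to_provider_models_map

    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
        # Called once by the routing table during initialize(); unknown keys fail fast.
        for key in routing_keys:
            if key not in self.stack_to_provider_models_map:
                raise ValueError(
                    f"Routing key {key} not found in map {self.stack_to_provider_models_map}"
                )

    def map_to_provider_model(self, routing_key: str) -> str:
        return self.stack_to_provider_models_map[routing_key]


# Hypothetical adapter wiring, mirroring how the adapters in this diff pass their
# supported-models map to __init__.
SUPPORTED_MODELS = {"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct"}

adapter = ToyRoutableProviderForModels(SUPPORTED_MODELS)
asyncio.run(adapter.validate_routing_keys(["Llama3.1-8B-Instruct"]))
print(adapter.map_to_provider_model("Llama3.1-8B-Instruct"))
# -> fireworks/llama-v3p1-8b-instruct
```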
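
Provider resolution moves from `server.py` into `llama_stack/distribution/resolver.py`; its core is a DFS topological sort over `api_dependencies` so each provider is instantiated only after its dependencies. A simplified, runnable sketch (a toy `Spec` dataclass stands in for `ProviderSpec`):

```python
# Simplified version of the topological_sort in resolver.py, using string APIs.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Spec:
    api: str
    api_dependencies: List[str] = field(default_factory=list)


def topological_sort(providers: List[Spec]) -> List[Spec]:
    by_id = {x.api: x for x in providers}

    def dfs(a: Spec, visited: set, stack: List[str]) -> None:
        visited.add(a.api)
        for api in a.api_dependencies:
            if api not in visited:
                dfs(by_id[api], visited, stack)
        # appended only after all dependencies have been pushed
        stack.append(a.api)

    visited: set = set()
    stack: List[str] = []
    for a in providers:
        if a.api not in visited:
            dfs(a, visited, stack)
    return [by_id[x] for x in stack]


specs = [Spec("safety", ["inference"]), Spec("inference"), Spec("memory")]
print([s.api for s in topological_sort(specs)])
# -> ['inference', 'safety', 'memory']  (dependencies come first)
```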