diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 619b5b078..e124c17cb 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -13,6 +13,10 @@
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
+LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
+LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
+
+
 @json_schema_type
 class Api(Enum):
     inference = "inference"
@@ -54,6 +58,12 @@ class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...
 
 
+class RoutableProvider(Protocol):
+    async def register_routing_keys(self, keys: List[str]) -> None: ...
+
+    def get_routing_keys(self) -> List[str]: ...
+
+
 class GenericProviderConfig(BaseModel):
     provider_id: str
     config: Dict[str, Any]
@@ -65,8 +75,11 @@ class PlaceholderProviderConfig(BaseModel):
     providers: List[str]
 
 
+RoutingKey = Union[str, List[str]]
+
+
 class RoutableProviderConfig(GenericProviderConfig):
-    routing_key: str
+    routing_key: RoutingKey
 
 
 # Example: /inference, /safety
@@ -247,6 +260,7 @@ in the runtime configuration to help route to the correct provider.""",
 
 @json_schema_type
 class StackRunConfig(BaseModel):
+    version: str = LLAMA_STACK_RUN_CONFIG_VERSION
     built_at: datetime
 
     image_name: str = Field(
@@ -295,6 +309,7 @@ Provider configurations for each of the APIs provided by this package.
 
 @json_schema_type
 class BuildConfig(BaseModel):
+    version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
     name: str
     distribution_spec: DistributionSpec = Field(
         description="The distribution spec to build including API providers. "
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
new file mode 100644
index 000000000..f7d51c64a
--- /dev/null
+++ b/llama_stack/distribution/resolver.py
@@ -0,0 +1,129 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List, Set
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.distribution.distribution import (
+    api_providers,
+    builtin_automatically_routed_apis,
+)
+from llama_stack.distribution.utils.dynamic import instantiate_provider
+
+
+async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
+    """
+    Does two things:
+    - flatmaps, sorts and resolves the providers in dependency order
+    - for each API, produces either a (local, passthrough or router) implementation
+    """
+    all_providers = api_providers()
+    specs = {}
+    configs = {}
+
+    for api_str, config in run_config.api_providers.items():
+        api = Api(api_str)
+
+        # TODO: check that these APIs are not in the routing table part of the config
+        providers = all_providers[api]
+
+        # skip checks for API whose provider config is specified in routing_table
+        if isinstance(config, PlaceholderProviderConfig):
+            continue
+
+        if config.provider_id not in providers:
+            raise ValueError(
+                f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
+            )
+        specs[api] = providers[config.provider_id]
+        configs[api] = config
+
+    apis_to_serve = run_config.apis_to_serve or set(
+        list(specs.keys()) + list(run_config.routing_table.keys())
+    )
+    for info in builtin_automatically_routed_apis():
+        source_api = info.routing_table_api
+
+        assert (
+            source_api not in specs
+        ), f"Routing table API {source_api} specified in wrong place?"
+        assert (
+            info.router_api not in specs
+        ), f"Auto-routed API {info.router_api} specified in wrong place?"
+
+        if info.router_api.value not in apis_to_serve:
+            continue
+
+        print("router_api", info.router_api)
+        if info.router_api.value not in run_config.routing_table:
+            raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
+
+        routing_table = run_config.routing_table[info.router_api.value]
+
+        providers = all_providers[info.router_api]
+
+        inner_specs = []
+        inner_deps = []
+        for rt_entry in routing_table:
+            if rt_entry.provider_id not in providers:
+                raise ValueError(
+                    f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
+                )
+            inner_specs.append(providers[rt_entry.provider_id])
+            inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
+
+        specs[source_api] = RoutingTableProviderSpec(
+            api=source_api,
+            module="llama_stack.distribution.routers",
+            api_dependencies=inner_deps,
+            inner_specs=inner_specs,
+        )
+        configs[source_api] = routing_table
+
+        specs[info.router_api] = AutoRoutedProviderSpec(
+            api=info.router_api,
+            module="llama_stack.distribution.routers",
+            routing_table_api=source_api,
+            api_dependencies=[source_api],
+        )
+        configs[info.router_api] = {}
+
+    sorted_specs = topological_sort(specs.values())
+    print(f"Resolved {len(sorted_specs)} providers in topological order")
+    for spec in sorted_specs:
+        print(f"  {spec.api}: {spec.provider_id}")
+    print("")
+    impls = {}
+    for spec in sorted_specs:
+        api = spec.api
+        deps = {api: impls[api] for api in spec.api_dependencies}
+        impl = await instantiate_provider(spec, deps, configs[api])
+
+        impls[api] = impl
+
+    return impls, specs
+
+
+def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
+    by_id = {x.api: x for x in providers}
+
+    def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
+        visited.add(a.api)
+
+        for api in a.api_dependencies:
+            if api not in visited:
+                dfs(by_id[api], visited, stack)
+
+        stack.append(a.api)
+
+    visited = set()
+    stack = []
+
+    for a in providers:
+        if a.api not in visited:
+            dfs(a, visited, stack)
+
+    return [by_id[x] for x in stack]
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 89db71fa7..e0b778345 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -19,18 +19,31 @@ from llama_stack.distribution.datatypes import *  # noqa: F403
 class CommonRoutingTableImpl(RoutingTable):
     def __init__(
         self,
-        inner_impls: List[Tuple[str, Any]],
+        inner_impls: List[Tuple[RoutingKey, Any]],
         routing_table_config: Dict[str, List[RoutableProviderConfig]],
     ) -> None:
-        self.providers = {k: v for k, v in inner_impls}
-        self.routing_keys = list(self.providers.keys())
+        self.unique_providers = []
+        self.providers = {}
+        self.routing_keys = []
+
+        for key, impl in inner_impls:
+            keys = key if isinstance(key, list) else [key]
+            self.unique_providers.append((keys, impl))
+
+            for k in keys:
+                if k in self.providers:
+                    raise ValueError(f"Duplicate routing key {k}")
+                self.providers[k] = impl
+                self.routing_keys.append(k)
+
         self.routing_table_config = routing_table_config
 
     async def initialize(self) -> None:
-        pass
+        for keys, p in self.unique_providers:
+            await p.register_routing_keys(keys)
 
     async def shutdown(self) -> None:
-        for p in self.providers.values():
+        for _, p in self.unique_providers:
             await p.shutdown()
 
     def get_provider_impl(self, routing_key: str) -> Any:
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 28301264c..16b1fb619 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -17,16 +17,7 @@ from collections.abc import (
 from contextlib import asynccontextmanager
 from http import HTTPStatus
 from ssl import SSLError
-from typing import (
-    Any,
-    AsyncGenerator,
-    AsyncIterator,
-    Dict,
-    get_type_hints,
-    List,
-    Optional,
-    Set,
-)
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
 
 import fire
 import httpx
@@ -48,13 +39,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
 )
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
-from llama_stack.distribution.distribution import (
-    api_endpoints,
-    api_providers,
-    builtin_automatically_routed_apis,
-)
+from llama_stack.distribution.distribution import api_endpoints
 from llama_stack.distribution.request_headers import set_request_provider_data
-from llama_stack.distribution.utils.dynamic import instantiate_provider
+from llama_stack.distribution.resolver import resolve_impls_with_routing
 
 
 def is_async_iterator_type(typ):
@@ -289,125 +276,6 @@ def create_dynamic_typed_route(func: Any, method: str):
     return endpoint
 
 
-def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
-    by_id = {x.api: x for x in providers}
-
-    def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
-        visited.add(a.api)
-
-        for api in a.api_dependencies:
-            if api not in visited:
-                dfs(by_id[api], visited, stack)
-
-        stack.append(a.api)
-
-    visited = set()
-    stack = []
-
-    for a in providers:
-        if a.api not in visited:
-            dfs(a, visited, stack)
-
-    return [by_id[x] for x in stack]
-
-
-def snake_to_camel(snake_str):
-    return "".join(word.capitalize() for word in snake_str.split("_"))
-
-
-async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
-    """
-    Does two things:
-    - flatmaps, sorts and resolves the providers in dependency order
-    - for each API, produces either a (local, passthrough or router) implementation
-    """
-    all_providers = api_providers()
-    specs = {}
-    configs = {}
-
-    for api_str, config in run_config.api_providers.items():
-        api = Api(api_str)
-
-        # TODO: check that these APIs are not in the routing table part of the config
-        providers = all_providers[api]
-
-        # skip checks for API whose provider config is specified in routing_table
-        if isinstance(config, PlaceholderProviderConfig):
-            continue
-
-        if config.provider_id not in providers:
-            raise ValueError(
-                f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
-            )
-        specs[api] = providers[config.provider_id]
-        configs[api] = config
-
-    apis_to_serve = run_config.apis_to_serve or set(
-        list(specs.keys()) + list(run_config.routing_table.keys())
-    )
-    for info in builtin_automatically_routed_apis():
-        source_api = info.routing_table_api
-
-        assert (
-            source_api not in specs
-        ), f"Routing table API {source_api} specified in wrong place?"
-        assert (
-            info.router_api not in specs
-        ), f"Auto-routed API {info.router_api} specified in wrong place?"
-
-        if info.router_api.value not in apis_to_serve:
-            continue
-
-        print("router_api", info.router_api)
-        if info.router_api.value not in run_config.routing_table:
-            raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
-
-        routing_table = run_config.routing_table[info.router_api.value]
-
-        providers = all_providers[info.router_api]
-
-        inner_specs = []
-        inner_deps = []
-        for rt_entry in routing_table:
-            if rt_entry.provider_id not in providers:
-                raise ValueError(
-                    f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
-                )
-            inner_specs.append(providers[rt_entry.provider_id])
-            inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
-
-        specs[source_api] = RoutingTableProviderSpec(
-            api=source_api,
-            module="llama_stack.distribution.routers",
-            api_dependencies=inner_deps,
-            inner_specs=inner_specs,
-        )
-        configs[source_api] = routing_table
-
-        specs[info.router_api] = AutoRoutedProviderSpec(
-            api=info.router_api,
-            module="llama_stack.distribution.routers",
-            routing_table_api=source_api,
-            api_dependencies=[source_api],
-        )
-        configs[info.router_api] = {}
-
-    sorted_specs = topological_sort(specs.values())
-    print(f"Resolved {len(sorted_specs)} providers in topological order")
-    for spec in sorted_specs:
-        print(f"  {spec.api}: {spec.provider_id}")
-    print("")
-    impls = {}
-    for spec in sorted_specs:
-        api = spec.api
-        deps = {api: impls[api] for api in spec.api_dependencies}
-        impl = await instantiate_provider(spec, deps, configs[api])
-
-        impls[api] = impl
-
-    return impls, specs
-
-
 def main(
     yaml_config: str = "llamastack-run.yaml",
     port: int = 5000,
diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
index cf4891f20..14b506964 100644
--- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
@@ -12,7 +12,8 @@
 from botocore.config import Config
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
+
+from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
@@ -25,7 +26,7 @@ BEDROCK_SUPPORTED_MODELS = {
 }
 
 
-class BedrockInferenceAdapter(Inference):
+class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
 
     @staticmethod
     def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
@@ -68,6 +69,9 @@ class BedrockInferenceAdapter(Inference):
         return boto3_session.client("bedrock-runtime", config=boto3_config)
 
     def __init__(self, config: BedrockConfig) -> None:
+        RoutableProviderForModels.__init__(
+            self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
+        )
         self._config = config
 
         self._client = BedrockInferenceAdapter._create_bedrock_client(config)
@@ -94,22 +98,6 @@ class BedrockInferenceAdapter(Inference):
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
         raise NotImplementedError()
 
-    @staticmethod
-    def resolve_bedrock_model(model_name: str) -> str:
-        model = resolve_model(model_name)
-        assert (
-            model is not None
-            and model.descriptor(shorten_default_variant=True)
-            in BEDROCK_SUPPORTED_MODELS
-        ), (
-            f"Unsupported model: {model_name}, use one of the supported models: "
-            f"{','.join(BEDROCK_SUPPORTED_MODELS.keys())}"
-        )
-
-        return BEDROCK_SUPPORTED_MODELS.get(
-            model.descriptor(shorten_default_variant=True)
-        )
-
     @staticmethod
     def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
         if bedrock_stop_reason == "max_tokens":
@@ -350,7 +338,7 @@ class BedrockInferenceAdapter(Inference):
     ) -> (
         AsyncGenerator
     ):  # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        bedrock_model = BedrockInferenceAdapter.resolve_bedrock_model(model)
+        bedrock_model = self.map_to_provider_model(model)
         inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
             sampling_params
         )
diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index 47e1449f2..df8cee189 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -12,7 +12,8 @@
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
+
+from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
@@ -28,8 +29,11 @@ FIREWORKS_SUPPORTED_MODELS = {
 }
 
 
-class FireworksInferenceAdapter(Inference):
+class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
     def __init__(self, config: FireworksImplConfig) -> None:
+        RoutableProviderForModels.__init__(
+            self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
+        )
         self.config = config
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
@@ -65,18 +69,6 @@ class FireworksInferenceAdapter(Inference):
 
         return fireworks_messages
 
-    def resolve_fireworks_model(self, model_name: str) -> str:
-        model = resolve_model(model_name)
-        assert (
-            model is not None
-            and model.descriptor(shorten_default_variant=True)
-            in FIREWORKS_SUPPORTED_MODELS
-        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(FIREWORKS_SUPPORTED_MODELS.keys())}"
-
-        return FIREWORKS_SUPPORTED_MODELS.get(
-            model.descriptor(shorten_default_variant=True)
-        )
-
     def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -112,7 +104,7 @@ class FireworksInferenceAdapter(Inference):
 
         # accumulate sampling params and other options to pass to fireworks
         options = self.get_fireworks_chat_options(request)
-        fireworks_model = self.resolve_fireworks_model(request.model)
+        fireworks_model = self.map_to_provider_model(request.model)
 
         if not request.stream:
             r = await self.client.chat.completions.acreate(
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index c67bb8ce1..c4d48af81 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -11,7 +11,6 @@
 import httpx
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
 
 from ollama import AsyncClient
@@ -19,6 +18,7 @@ from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
+from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
 
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
@@ -29,8 +29,11 @@ OLLAMA_SUPPORTED_SKUS = {
 }
 
 
-class OllamaInferenceAdapter(Inference):
+class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
     def __init__(self, url: str) -> None:
+        RoutableProviderForModels.__init__(
+            self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
+        )
         self.url = url
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
@@ -72,15 +75,6 @@ class OllamaInferenceAdapter(Inference):
 
         return ollama_messages
 
-    def resolve_ollama_model(self, model_name: str) -> str:
-        model = resolve_model(model_name)
-        assert (
-            model is not None
-            and model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS
-        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(OLLAMA_SUPPORTED_SKUS.keys())}"
-
-        return OLLAMA_SUPPORTED_SKUS.get(model.descriptor(shorten_default_variant=True))
-
     def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -120,7 +114,7 @@ class OllamaInferenceAdapter(Inference):
         messages = augment_messages_for_tools(request)
         # accumulate sampling params and other options to pass to ollama
         options = self.get_ollama_chat_options(request)
-        ollama_model = self.resolve_ollama_model(request.model)
+        ollama_model = self.map_to_provider_model(request.model)
 
         res = await self.client.ps()
         need_model_pull = True
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 66f57442f..67cfa21b5 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -13,6 +13,8 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 
+from llama_stack.distribution.datatypes import RoutableProvider
+
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
@@ -23,7 +25,7 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
 logger = logging.getLogger(__name__)
 
 
-class _HfAdapter(Inference):
+class _HfAdapter(Inference, RoutableProvider):
     client: AsyncInferenceClient
     max_tokens: int
     model_id: str
@@ -32,6 +34,14 @@ class _HfAdapter(Inference):
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
 
+    async def register_routing_keys(self, routing_keys: list[str]) -> None:
+        # these are the model names the Llama Stack will use to route requests to this provider
+        # perform validation here if necessary
+        self.routing_keys = routing_keys
+
+    def get_routing_keys(self) -> list[str]:
+        return self.routing_keys
+
     async def shutdown(self) -> None:
         pass
 
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 7053834bd..2c2c0c4d8 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -10,7 +10,6 @@
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
 
 from together import Together
 
@@ -19,6 +18,7 @@
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
+from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
 
 from .config import TogetherImplConfig
@@ -32,8 +32,13 @@ TOGETHER_SUPPORTED_MODELS = {
 }
 
 
-class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
+class TogetherInferenceAdapter(
+    Inference, NeedsRequestProviderData, RoutableProviderForModels
+):
     def __init__(self, config: TogetherImplConfig) -> None:
+        RoutableProviderForModels.__init__(
+            self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
+        )
         self.config = config
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
@@ -69,18 +74,6 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
 
         return together_messages
 
-    def resolve_together_model(self, model_name: str) -> str:
-        model = resolve_model(model_name)
-        assert (
-            model is not None
-            and model.descriptor(shorten_default_variant=True)
-            in TOGETHER_SUPPORTED_MODELS
-        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(TOGETHER_SUPPORTED_MODELS.keys())}"
-
-        return TOGETHER_SUPPORTED_MODELS.get(
-            model.descriptor(shorten_default_variant=True)
-        )
-
     def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -125,7 +118,7 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
 
         # accumulate sampling params and other options to pass to together
         options = self.get_together_chat_options(request)
-        together_model = self.resolve_together_model(request.model)
+        together_model = self.map_to_provider_model(request.model)
         messages = augment_messages_for_tools(request)
 
         if not request.stream:
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index e9b790dd5..5184b50f0 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -6,21 +6,13 @@
 
 import asyncio
 
-from typing import AsyncIterator, Union
+from typing import AsyncIterator, List, Union
 
-from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model
 
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.distribution.datatypes import RoutableProvider
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
@@ -28,15 +20,12 @@ from llama_stack.providers.utils.inference.augment_messages import (
 from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator
 
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
-
 # there's a single model parallel process running serving the model. for now,
 # we don't support multiple concurrent requests to this process.
 SEMAPHORE = asyncio.Semaphore(1)
 
 
-class MetaReferenceInferenceImpl(Inference):
+class MetaReferenceInferenceImpl(Inference, RoutableProvider):
     def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
@@ -49,6 +38,15 @@ class MetaReferenceInferenceImpl(Inference):
         self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()
 
+    async def register_routing_keys(self, routing_keys: List[str]) -> None:
+        assert (
+            len(routing_keys) == 1
+        ), f"Only one routing key is supported {routing_keys}"
+        assert routing_keys[0] == self.config.model
+
+    def get_routing_keys(self) -> List[str]:
+        return [self.config.model]
+
     async def shutdown(self) -> None:
         self.generator.stop()
 
diff --git a/llama_stack/providers/utils/inference/routable.py b/llama_stack/providers/utils/inference/routable.py
new file mode 100644
index 000000000..254e12d60
--- /dev/null
+++ b/llama_stack/providers/utils/inference/routable.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict, List
+
+from llama_models.sku_list import resolve_model
+
+from llama_stack.distribution.datatypes import RoutableProvider
+
+
+class RoutableProviderForModels(RoutableProvider):
+
+    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
+        self.stack_to_provider_models_map = stack_to_provider_models_map
+
+    async def register_routing_keys(self, routing_keys: List[str]):
+        for routing_key in routing_keys:
+            if routing_key not in self.stack_to_provider_models_map:
+                raise ValueError(
+                    f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}"
+                )
+        self.routing_keys = routing_keys
+
+    def get_routing_keys(self) -> List[str]:
+        return self.routing_keys
+
+    def map_to_provider_model(self, routing_key: str) -> str:
+        model = resolve_model(routing_key)
+        if not model:
+            raise ValueError(f"Unknown model: `{routing_key}`")
+
+        if routing_key not in self.stack_to_provider_models_map:
+            raise ValueError(
+                f"Model {routing_key} not found in map {self.stack_to_provider_models_map}"
+            )
+
+        return self.stack_to_provider_models_map[routing_key]
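For illustration only, and not part of the patch above: a minimal sketch of how an inference adapter might build on the new RoutableProviderForModels mixin from llama_stack/providers/utils/inference/routable.py. The adapter class, model map, and routing key below are hypothetical placeholders; map_to_provider_model is left out because it additionally requires a model descriptor that resolve_model recognizes.

# illustrative_adapter.py -- hypothetical example, assumes llama_stack with this patch is importable
import asyncio

from llama_stack.providers.utils.inference.routable import RoutableProviderForModels

# placeholder mapping, in the same shape as OLLAMA_SUPPORTED_SKUS / FIREWORKS_SUPPORTED_MODELS
EXAMPLE_SUPPORTED_MODELS = {
    "example-stack-model": "example-provider-model-id",
}


class ExampleInferenceAdapter(RoutableProviderForModels):
    def __init__(self) -> None:
        RoutableProviderForModels.__init__(
            self, stack_to_provider_models_map=EXAMPLE_SUPPORTED_MODELS
        )


async def main() -> None:
    adapter = ExampleInferenceAdapter()
    # CommonRoutingTableImpl.initialize() invokes this with the keys from the run config;
    # a key that is missing from the map raises ValueError.
    await adapter.register_routing_keys(["example-stack-model"])
    print(adapter.get_routing_keys())


asyncio.run(main())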