From 3b83032555d83695c01c8d40cd0750401160b03f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 22 Jul 2025 15:22:48 -0700 Subject: [PATCH 1/5] feat(registry): more flexible model lookup (#2859) This PR updates model registration and lookup behavior to be slightly more general / flexible. See https://github.com/meta-llama/llama-stack/issues/2843 for more details. Note that this change is backwards compatible given the design of the `lookup_model()` method. ## Test Plan Added unit tests --- .github/workflows/integration-tests.yml | 2 +- .../workflows/integration-vector-io-tests.yml | 2 +- llama_stack/distribution/routers/datasets.py | 6 +- .../distribution/routers/eval_scoring.py | 21 ++- llama_stack/distribution/routers/inference.py | 17 +- llama_stack/distribution/routers/safety.py | 3 +- .../distribution/routers/tool_runtime.py | 13 +- llama_stack/distribution/routers/vector_io.py | 14 +- .../distribution/routing_tables/common.py | 28 +++- .../distribution/routing_tables/models.py | 36 +++-- .../distribution/routing_tables/toolgroups.py | 6 +- .../distribution/routing_tables/vector_dbs.py | 37 +++-- llama_stack/providers/datatypes.py | 2 +- .../utils/inference/embedding_mixin.py | 2 +- .../routers/test_routing_tables.py | 151 ++++++++++++++++-- 15 files changed, 265 insertions(+), 75 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index f8f01756d..082f1e204 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -99,7 +99,7 @@ jobs: uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \ --text-model="ollama/llama3.2:3b-instruct-fp16" \ - --embedding-model=all-MiniLM-L6-v2 \ + --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ --safety-shield=$SAFETY_MODEL \ --color=yes \ --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index ec236b33b..525c17d46 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -114,7 +114,7 @@ jobs: run: | uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ tests/integration/vector_io \ - --embedding-model all-MiniLM-L6-v2 + --embedding-model sentence-transformers/all-MiniLM-L6-v2 - name: Check Storage and Memory Available After Tests if: ${{ always() }} diff --git a/llama_stack/distribution/routers/datasets.py b/llama_stack/distribution/routers/datasets.py index 6f28756c9..d7984f729 100644 --- a/llama_stack/distribution/routers/datasets.py +++ b/llama_stack/distribution/routers/datasets.py @@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO): logger.debug( f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", ) - return await self.routing_table.get_provider_impl(dataset_id).iterrows( + provider = await self.routing_table.get_provider_impl(dataset_id) + return await provider.iterrows( dataset_id=dataset_id, start_index=start_index, limit=limit, @@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO): async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") - return await self.routing_table.get_provider_impl(dataset_id).append_rows( + provider = await 
self.routing_table.get_provider_impl(dataset_id) + return await provider.append_rows( dataset_id=dataset_id, rows=rows, ) diff --git a/llama_stack/distribution/routers/eval_scoring.py b/llama_stack/distribution/routers/eval_scoring.py index fd0bb90a7..f7a17eecf 100644 --- a/llama_stack/distribution/routers/eval_scoring.py +++ b/llama_stack/distribution/routers/eval_scoring.py @@ -44,7 +44,8 @@ class ScoringRouter(Scoring): logger.debug(f"ScoringRouter.score_batch: {dataset_id}") res = {} for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( + provider = await self.routing_table.get_provider_impl(fn_identifier) + score_response = await provider.score_batch( dataset_id=dataset_id, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -66,7 +67,8 @@ class ScoringRouter(Scoring): res = {} # look up and map each scoring function to its provider impl for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score( + provider = await self.routing_table.get_provider_impl(fn_identifier) + score_response = await provider.score( input_rows=input_rows, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -97,7 +99,8 @@ class EvalRouter(Eval): benchmark_config: BenchmarkConfig, ) -> Job: logger.debug(f"EvalRouter.run_eval: {benchmark_id}") - return await self.routing_table.get_provider_impl(benchmark_id).run_eval( + provider = await self.routing_table.get_provider_impl(benchmark_id) + return await provider.run_eval( benchmark_id=benchmark_id, benchmark_config=benchmark_config, ) @@ -110,7 +113,8 @@ class EvalRouter(Eval): benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") - return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( + provider = await self.routing_table.get_provider_impl(benchmark_id) + return await provider.evaluate_rows( benchmark_id=benchmark_id, input_rows=input_rows, scoring_functions=scoring_functions, @@ -123,7 +127,8 @@ class EvalRouter(Eval): job_id: str, ) -> Job: logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) + provider = await self.routing_table.get_provider_impl(benchmark_id) + return await provider.job_status(benchmark_id, job_id) async def job_cancel( self, @@ -131,7 +136,8 @@ class EvalRouter(Eval): job_id: str, ) -> None: logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") - await self.routing_table.get_provider_impl(benchmark_id).job_cancel( + provider = await self.routing_table.get_provider_impl(benchmark_id) + await provider.job_cancel( benchmark_id, job_id, ) @@ -142,7 +148,8 @@ class EvalRouter(Eval): job_id: str, ) -> EvaluateResponse: logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_result( + provider = await self.routing_table.get_provider_impl(benchmark_id) + return await provider.job_result( benchmark_id, job_id, ) diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/inference.py index b39da7810..a5cc8c4b5 100644 --- a/llama_stack/distribution/routers/inference.py +++ b/llama_stack/distribution/routers/inference.py @@ -231,7 +231,7 @@ class InferenceRouter(Inference): logprobs=logprobs, tool_config=tool_config, ) - 
provider = self.routing_table.get_provider_impl(model_id) + provider = await self.routing_table.get_provider_impl(model_id) prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) if stream: @@ -292,7 +292,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", ) - provider = self.routing_table.get_provider_impl(model_id) + provider = await self.routing_table.get_provider_impl(model_id) return await provider.batch_chat_completion( model_id=model_id, messages_batch=messages_batch, @@ -322,7 +322,7 @@ class InferenceRouter(Inference): raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") - provider = self.routing_table.get_provider_impl(model_id) + provider = await self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, content=content, @@ -378,7 +378,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", ) - provider = self.routing_table.get_provider_impl(model_id) + provider = await self.routing_table.get_provider_impl(model_id) return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs) async def embeddings( @@ -395,7 +395,8 @@ class InferenceRouter(Inference): raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.llm: raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") - return await self.routing_table.get_provider_impl(model_id).embeddings( + provider = await self.routing_table.get_provider_impl(model_id) + return await provider.embeddings( model_id=model_id, contents=contents, text_truncation=text_truncation, @@ -458,7 +459,7 @@ class InferenceRouter(Inference): suffix=suffix, ) - provider = self.routing_table.get_provider_impl(model_obj.identifier) + provider = await self.routing_table.get_provider_impl(model_obj.identifier) return await provider.openai_completion(**params) async def openai_chat_completion( @@ -538,7 +539,7 @@ class InferenceRouter(Inference): user=user, ) - provider = self.routing_table.get_provider_impl(model_obj.identifier) + provider = await self.routing_table.get_provider_impl(model_obj.identifier) if stream: response_stream = await provider.openai_chat_completion(**params) if self.store: @@ -575,7 +576,7 @@ class InferenceRouter(Inference): user=user, ) - provider = self.routing_table.get_provider_impl(model_obj.identifier) + provider = await self.routing_table.get_provider_impl(model_obj.identifier) return await provider.openai_embeddings(**params) async def list_chat_completions( diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py index 9761d2db0..26ee8e722 100644 --- a/llama_stack/distribution/routers/safety.py +++ b/llama_stack/distribution/routers/safety.py @@ -50,7 +50,8 @@ class SafetyRouter(Safety): params: dict[str, Any] = None, ) -> RunShieldResponse: logger.debug(f"SafetyRouter.run_shield: {shield_id}") - return await self.routing_table.get_provider_impl(shield_id).run_shield( + provider = await self.routing_table.get_provider_impl(shield_id) + return await provider.run_shield( shield_id=shield_id, messages=messages, params=params, diff --git 
a/llama_stack/distribution/routers/tool_runtime.py b/llama_stack/distribution/routers/tool_runtime.py index 285843dbc..5a40bc0c5 100644 --- a/llama_stack/distribution/routers/tool_runtime.py +++ b/llama_stack/distribution/routers/tool_runtime.py @@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime): query_config: RAGQueryConfig | None = None, ) -> RAGQueryResult: logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") - return await self.routing_table.get_provider_impl("knowledge_search").query( - content, vector_db_ids, query_config - ) + provider = await self.routing_table.get_provider_impl("knowledge_search") + return await provider.query(content, vector_db_ids, query_config) async def insert( self, @@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime): logger.debug( f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" ) - return await self.routing_table.get_provider_impl("insert_into_memory").insert( - documents, vector_db_id, chunk_size_in_tokens - ) + provider = await self.routing_table.get_provider_impl("insert_into_memory") + return await provider.insert(documents, vector_db_id, chunk_size_in_tokens) def __init__( self, @@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime): async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") - return await self.routing_table.get_provider_impl(tool_name).invoke_tool( + provider = await self.routing_table.get_provider_impl(tool_name) + return await provider.invoke_tool( tool_name=tool_name, kwargs=kwargs, ) diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index a1dd66060..3d0996c49 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -104,7 +104,8 @@ class VectorIORouter(VectorIO): logger.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' 
if len(chunks) > 3 else ''}", ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) + provider = await self.routing_table.get_provider_impl(vector_db_id) + return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds) async def query_chunks( self, @@ -113,7 +114,8 @@ class VectorIORouter(VectorIO): params: dict[str, Any] | None = None, ) -> QueryChunksResponse: logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) + provider = await self.routing_table.get_provider_impl(vector_db_id) + return await provider.query_chunks(vector_db_id, query, params) # OpenAI Vector Stores API endpoints async def openai_create_vector_store( @@ -146,7 +148,8 @@ class VectorIORouter(VectorIO): provider_vector_db_id=vector_db_id, vector_db_name=name, ) - return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store( + provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier) + return await provider.openai_create_vector_store( name=name, file_ids=file_ids, expires_after=expires_after, @@ -172,9 +175,8 @@ class VectorIORouter(VectorIO): all_stores = [] for vector_db in vector_dbs: try: - vector_store = await self.routing_table.get_provider_impl( - vector_db.identifier - ).openai_retrieve_vector_store(vector_db.identifier) + provider = await self.routing_table.get_provider_impl(vector_db.identifier) + vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier) all_stores.append(vector_store) except Exception as e: logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}") diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index bbe0113e9..2f6ac90bb 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -6,6 +6,7 @@ from typing import Any +from llama_stack.apis.models import Model from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed @@ -116,7 +117,7 @@ class CommonRoutingTableImpl(RoutingTable): for p in self.impls_by_provider_id.values(): await p.shutdown() - def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: from .benchmarks import BenchmarksRoutingTable from .datasets import DatasetsRoutingTable from .models import ModelsRoutingTable @@ -235,3 +236,28 @@ class CommonRoutingTableImpl(RoutingTable): ] return filtered_objs + + +async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model: + # first try to get the model by identifier + # this works if model_id is an alias or is of the form provider_id/provider_model_id + model = await routing_table.get_object_by_identifier("model", model_id) + if model is not None: + return model + + logger.warning( + f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to " + "searching in all providers. This is only for backwards compatibility and will stop working " + "soon. Migrate your calls to use fully scoped `provider_id/model_id` names." 
+ ) + # if not found, this means model_id is an unscoped provider_model_id, we need + # to iterate (given a lack of an efficient index on the KVStore) + models = await routing_table.get_all_with_type("model") + matching_models = [m for m in models if m.provider_resource_id == model_id] + if len(matching_models) == 0: + raise ValueError(f"Model '{model_id}' not found") + + if len(matching_models) > 1: + raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}") + + return matching_models[0] diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py index 9a9db7257..f2787b308 100644 --- a/llama_stack/distribution/routing_tables/models.py +++ b/llama_stack/distribution/routing_tables/models.py @@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import ( ) from llama_stack.log import get_logger -from .common import CommonRoutingTableImpl +from .common import CommonRoutingTableImpl, lookup_model logger = get_logger(name=__name__, category="core") @@ -36,10 +36,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return OpenAIListModelsResponse(data=openai_models) async def get_model(self, model_id: str) -> Model: - model = await self.get_object_by_identifier("model", model_id) - if model is None: - raise ValueError(f"Model '{model_id}' not found") - return model + return await lookup_model(self, model_id) + + async def get_provider_impl(self, model_id: str) -> Any: + model = await lookup_model(self, model_id) + return self.impls_by_provider_id[model.provider_id] async def register_model( self, @@ -49,24 +50,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): metadata: dict[str, Any] | None = None, model_type: ModelType | None = None, ) -> Model: - if provider_model_id is None: - provider_model_id = model_id if provider_id is None: # If provider_id not specified, use the only provider if it supports this model if len(self.impls_by_provider_id) == 1: provider_id = list(self.impls_by_provider_id.keys())[0] else: raise ValueError( - f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" + f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n" + "Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'." ) - if metadata is None: - metadata = {} - if model_type is None: - model_type = ModelType.llm + + provider_model_id = provider_model_id or model_id + metadata = metadata or {} + model_type = model_type or ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: raise ValueError("Embedding model must have an embedding dimension in its metadata") + + # an identifier different than provider_model_id implies it is an alias, so that + # becomes the globally unique identifier. otherwise provider_model_ids can conflict, + # so as a general rule we must use the provider_id to disambiguate. 
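+        # e.g. an alias registration (model_id="my-alias", provider_model_id="some-model") is stored
+        # under the identifier "my-alias", whereas a plain registration (model_id="some-model",
+        # provider_id="p") is stored under "p/some-model".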
+ + if model_id != provider_model_id: + identifier = model_id + else: + identifier = f"{provider_id}/{provider_model_id}" + model = ModelWithOwner( - identifier=model_id, + identifier=identifier, provider_resource_id=provider_model_id, provider_id=provider_id, metadata=metadata, diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index b86f057bd..22c4e109a 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): tool_to_toolgroup: dict[str, str] = {} # overridden - def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while? @@ -40,7 +40,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): if routing_key in self.tool_to_toolgroup: routing_key = self.tool_to_toolgroup[routing_key] - return super().get_provider_impl(routing_key, provider_id) + return await super().get_provider_impl(routing_key, provider_id) async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: if toolgroup_id: @@ -59,7 +59,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return ListToolsResponse(data=all_tools) async def _index_tools(self, toolgroup: ToolGroup): - provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) + provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint) # TODO: kill this Tool vs ToolDef distinction diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py index b4e60c625..58ecf24da 100644 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -27,7 +27,7 @@ from llama_stack.distribution.datatypes import ( ) from llama_stack.log import get_logger -from .common import CommonRoutingTableImpl +from .common import CommonRoutingTableImpl, lookup_model logger = get_logger(name=__name__, category="core") @@ -51,8 +51,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): provider_vector_db_id: str | None = None, vector_db_name: str | None = None, ) -> VectorDB: - if provider_vector_db_id is None: - provider_vector_db_id = vector_db_id + provider_vector_db_id = provider_vector_db_id or vector_db_id if provider_id is None: if len(self.impls_by_provider_id) > 0: provider_id = list(self.impls_by_provider_id.keys())[0] @@ -62,7 +61,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): ) else: raise ValueError("No provider available. 
Please configure a vector_io provider.") - model = await self.get_object_by_identifier("model", embedding_model) + model = await lookup_model(self, embedding_model) if model is None: raise ValueError(f"Model {embedding_model} not found") if model.model_type != ModelType.embedding: @@ -93,7 +92,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): vector_store_id: str, ) -> VectorStoreObject: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store(vector_store_id) async def openai_update_vector_store( self, @@ -103,7 +103,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): metadata: dict[str, Any] | None = None, ) -> VectorStoreObject: await self.assert_action_allowed("update", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_update_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_update_vector_store( vector_store_id=vector_store_id, name=name, expires_after=expires_after, @@ -115,7 +116,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): vector_store_id: str, ) -> VectorStoreDeleteResponse: await self.assert_action_allowed("delete", "vector_db", vector_store_id) - result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + result = await provider.openai_delete_vector_store(vector_store_id) await self.unregister_vector_db(vector_store_id) return result @@ -130,7 +132,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): search_mode: str | None = "vector", ) -> VectorStoreSearchResponsePage: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_search_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_search_vector_store( vector_store_id=vector_store_id, query=query, filters=filters, @@ -148,7 +151,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileObject: await self.assert_action_allowed("update", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_attach_file_to_vector_store( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -165,7 +169,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): filter: VectorStoreFileStatus | None = None, ) -> list[VectorStoreFileObject]: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_list_files_in_vector_store( vector_store_id=vector_store_id, limit=limit, order=order, @@ -180,7 +185,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): file_id: str, ) -> VectorStoreFileObject: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file( + 
provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) @@ -191,7 +197,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): file_id: str, ) -> VectorStoreFileContentsResponse: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, ) @@ -203,7 +210,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): attributes: dict[str, Any], ) -> VectorStoreFileObject: await self.assert_action_allowed("update", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_update_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -215,7 +223,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): file_id: str, ) -> VectorStoreFileDeleteResponse: await self.assert_action_allowed("delete", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_delete_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index efe8a98fe..424380324 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -113,7 +113,7 @@ class ProviderSpec(BaseModel): class RoutingTable(Protocol): - def get_provider_impl(self, routing_key: str) -> Any: ... + async def get_provider_impl(self, routing_key: str) -> Any: ... 
# TODO: this can now be inlined into RemoteProviderSpec diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 97cf87360..32e89f987 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -88,7 +88,7 @@ class SentenceTransformerEmbeddingMixin: usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1) return OpenAIEmbeddingsResponse( data=data, - model=model_obj.provider_resource_id, + model=model, usage=usage, ) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 30f795d33..12b05ebff 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -11,15 +11,17 @@ from unittest.mock import AsyncMock from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api -from llama_stack.apis.models import Model +from llama_stack.apis.models import Model, ModelType from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter +from llama_stack.apis.vector_dbs import VectorDB from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable +from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable class Impl: @@ -104,6 +106,17 @@ class ToolGroupsImpl(Impl): ) +class VectorDBImpl(Impl): + def __init__(self): + super().__init__(Api.vector_io) + + async def register_vector_db(self, vector_db: VectorDB): + return vector_db + + async def unregister_vector_db(self, vector_db_id: str): + return vector_db_id + + async def test_models_routing_table(cached_disk_dist_registry): table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await table.initialize() @@ -115,27 +128,27 @@ async def test_models_routing_table(cached_disk_dist_registry): models = await table.list_models() assert len(models.data) == 2 model_ids = {m.identifier for m in models.data} - assert "test-model" in model_ids - assert "test-model-2" in model_ids + assert "test_provider/test-model" in model_ids + assert "test_provider/test-model-2" in model_ids # Test openai list models openai_models = await table.openai_list_models() assert len(openai_models.data) == 2 openai_model_ids = {m.id for m in openai_models.data} - assert "test-model" in openai_model_ids - assert "test-model-2" in openai_model_ids + assert "test_provider/test-model" in openai_model_ids + assert "test_provider/test-model-2" in openai_model_ids # Test get_object_by_identifier - model = await table.get_object_by_identifier("model", "test-model") + model = await table.get_object_by_identifier("model", "test_provider/test-model") assert model is not None - assert model.identifier == "test-model" + assert model.identifier == "test_provider/test-model" # Test 
get_object_by_identifier on non-existent object non_existent = await table.get_object_by_identifier("model", "non-existent-model") assert non_existent is None - await table.unregister_model(model_id="test-model") - await table.unregister_model(model_id="test-model-2") + await table.unregister_model(model_id="test_provider/test-model") + await table.unregister_model(model_id="test_provider/test-model-2") models = await table.list_models() assert len(models.data) == 0 @@ -160,6 +173,36 @@ async def test_shields_routing_table(cached_disk_dist_registry): assert "test-shield-2" in shield_ids +async def test_vectordbs_routing_table(cached_disk_dist_registry): + table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + # Register multiple vector databases and verify listing + await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model") + await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model") + vector_dbs = await table.list_vector_dbs() + + assert len(vector_dbs.data) == 2 + vector_db_ids = {v.identifier for v in vector_dbs.data} + assert "test-vectordb" in vector_db_ids + assert "test-vectordb-2" in vector_db_ids + + await table.unregister_vector_db(vector_db_id="test-vectordb") + await table.unregister_vector_db(vector_db_id="test-vectordb-2") + + vector_dbs = await table.list_vector_dbs() + assert len(vector_dbs.data) == 0 + + async def test_datasets_routing_table(cached_disk_dist_registry): table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {}) await table.initialize() @@ -245,3 +288,93 @@ async def test_tool_groups_routing_table(cached_disk_dist_registry): await table.unregister_toolgroup(toolgroup_id="test-toolgroup") tool_groups = await table.list_tool_groups() assert len(tool_groups.data) == 0 + + +async def test_models_alias_registration_and_lookup(cached_disk_dist_registry): + """Test alias registration (model_id != provider_model_id) and lookup behavior.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register model with alias (model_id different from provider_model_id) + await table.register_model( + model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider" + ) + + # Verify the model was registered with alias as identifier (not namespaced) + models = await table.list_models() + assert len(models.data) == 1 + model = models.data[0] + assert model.identifier == "my-alias" # Uses alias as identifier + assert model.provider_resource_id == "actual-provider-model" + + # Test lookup by alias works + retrieved_model = await table.get_model("my-alias") + assert retrieved_model.identifier == "my-alias" + assert retrieved_model.provider_resource_id == "actual-provider-model" + + +async def test_models_multi_provider_disambiguation(cached_disk_dist_registry): + """Test registration and lookup with multiple providers having same provider_model_id.""" + table = ModelsRoutingTable( + {"provider1": InferenceImpl(), "provider2": InferenceImpl()}, cached_disk_dist_registry, {} + ) + await 
table.initialize() + + # Register same provider_model_id on both providers (no aliases) + await table.register_model(model_id="common-model", provider_id="provider1") + await table.register_model(model_id="common-model", provider_id="provider2") + + # Verify both models get namespaced identifiers + models = await table.list_models() + assert len(models.data) == 2 + identifiers = {m.identifier for m in models.data} + assert identifiers == {"provider1/common-model", "provider2/common-model"} + + # Test lookup by full namespaced identifier works + model1 = await table.get_model("provider1/common-model") + assert model1.provider_id == "provider1" + assert model1.provider_resource_id == "common-model" + + model2 = await table.get_model("provider2/common-model") + assert model2.provider_id == "provider2" + assert model2.provider_resource_id == "common-model" + + # Test lookup by unscoped provider_model_id fails with multiple providers error + try: + await table.get_model("common-model") + raise AssertionError("Should have raised ValueError for multiple providers") + except ValueError as e: + assert "Multiple providers found" in str(e) + assert "provider1" in str(e) and "provider2" in str(e) + + +async def test_models_fallback_lookup_behavior(cached_disk_dist_registry): + """Test two-stage lookup: direct identifier hit vs fallback to provider_resource_id.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register model without alias (gets namespaced identifier) + await table.register_model(model_id="test-model", provider_id="test_provider") + + # Verify namespaced identifier was created + models = await table.list_models() + assert len(models.data) == 1 + model = models.data[0] + assert model.identifier == "test_provider/test-model" + assert model.provider_resource_id == "test-model" + + # Test lookup by full namespaced identifier (direct hit via get_object_by_identifier) + retrieved_model = await table.get_model("test_provider/test-model") + assert retrieved_model.identifier == "test_provider/test-model" + + # Test lookup by unscoped provider_model_id (fallback via iteration) + retrieved_model = await table.get_model("test-model") + assert retrieved_model.identifier == "test_provider/test-model" + assert retrieved_model.provider_resource_id == "test-model" + + # Test lookup of non-existent model fails + try: + await table.get_model("non-existent") + raise AssertionError("Should have raised ValueError for non-existent model") + except ValueError as e: + assert "not found" in str(e) From 340448e0aa0540f824ce0476505594d55f1088e6 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Wed, 23 Jul 2025 00:51:52 +0100 Subject: [PATCH 2/5] fix: optimize container build by enabling uv cache (#2855) - Remove --no-cache flags from uv pip install commands to enable caching - Mount host uv cache directory to container for persistent caching - Set UV_LINK_MODE=copy to prevent uv using hardlinks - When building the starter image o Build time reduced from ~4:45 to ~3:05 on subsequent builds (environment specific) o Eliminates re-downloading of 3G+ of data on each build o Cache size: ~6.2G (when building starter image) Fixes excessive data downloads during distro container builds. 
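For reference, after variable expansion the generated Containerfile contains roughly the fragment below. This is a sketch only: the cache id and target come from the new MOUNT_CACHE default (overridable), and `<pip_dependencies>` stands for whatever package list the build resolves.

```dockerfile
# Sketch of the emitted install step with the default MOUNT_CACHE expanded
ENV UV_LINK_MODE=copy
RUN --mount=type=cache,id=llama-stack-cache,target=/root/.cache \
    uv pip install <pip_dependencies>
```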
Signed-off-by: Derek Higgins --- llama_stack/distribution/build_container.sh | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index 6e794b36f..74776dd7d 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -19,6 +19,9 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500} # mounting is not supported by docker buildx, so we use COPY instead USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-} +# Mount command for cache container .cache, can be overridden by the user if needed +MOUNT_CACHE=${MOUNT_CACHE:-"--mount=type=cache,id=llama-stack-cache,target=/root/.cache"} + # Path to the run.yaml file in the container RUN_CONFIG_PATH=/app/run.yaml @@ -125,11 +128,16 @@ RUN pip install uv EOF fi +# Set the link mode to copy so that uv doesn't attempt to symlink to the cache directory +add_to_container << EOF +ENV UV_LINK_MODE=copy +EOF + # Add pip dependencies first since llama-stack is what will change most often # so we can reuse layers. if [ -n "$pip_dependencies" ]; then add_to_container << EOF -RUN uv pip install --no-cache $pip_dependencies +RUN $MOUNT_CACHE uv pip install $pip_dependencies EOF fi @@ -137,7 +145,7 @@ if [ -n "$special_pip_deps" ]; then IFS='#' read -ra parts <<<"$special_pip_deps" for part in "${parts[@]}"; do add_to_container < Date: Wed, 23 Jul 2025 05:48:23 +0200 Subject: [PATCH 3/5] fix: honour deprecation of --config and --template (#2856) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? https://github.com/meta-llama/llama-stack/pull/2716/ broke commands like: ``` python -m llama_stack.distribution.server.server --config llama_stack/templates/starter/run.yaml ``` And will fail with: ``` Traceback (most recent call last): File "", line 198, in _run_module_as_main File "", line 88, in _run_code File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 626, in main() File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 402, in main config_file = resolve_config_or_template(args.config, Mode.RUN) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/utils/config_resolution.py", line 43, in resolve_config_or_template config_path = Path(config_or_template) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/homebrew/Cellar/python@3.12/3.12.8/Frameworks/Python.framework/Versions/3.12/lib/python3.12/pathlib.py", line 1162, in __init__ super().__init__(*args) File "/opt/homebrew/Cellar/python@3.12/3.12.8/Frameworks/Python.framework/Versions/3.12/lib/python3.12/pathlib.py", line 373, in __init__ raise TypeError( TypeError: argument should be a str or an os.PathLike object where __fspath__ returns a str, not 'NoneType' ``` Complaining that no positional arguments are present. We now honour the deprecation until --config and --template are removed completely. ## Test Plan Both ` python -m llama_stack.distribution.server.server --config llama_stack/templates/starter/run.yaml` and ` python -m llama_stack.distribution.server.server llama_stack/templates/starter/run.yaml` should run the server. Same for `--template starter`. 
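For illustration, a minimal sketch of how the restored behaviour resolves both forms. It assumes the positional `[config]` argument that `add_config_template_args` also registers (that part sits outside the hunk shown below):

```python
# Sketch only: exercises the new get_config_from_args() precedence.
import argparse

from llama_stack.cli.utils import add_config_template_args, get_config_from_args

parser = argparse.ArgumentParser()
add_config_template_args(parser)

# preferred positional form
args = parser.parse_args(["llama_stack/templates/starter/run.yaml"])
print(get_config_from_args(args))  # llama_stack/templates/starter/run.yaml

# deprecated --config form is still honoured, with a deprecation warning
args = parser.parse_args(["--config", "llama_stack/templates/starter/run.yaml"])
print(get_config_from_args(args))  # llama_stack/templates/starter/run.yaml
```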
Signed-off-by: Sébastien Han --- llama_stack/cli/utils.py | 21 +++++++++++++++++++-- llama_stack/distribution/server/server.py | 5 +++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/llama_stack/cli/utils.py b/llama_stack/cli/utils.py index 433627cc0..94cff42e8 100644 --- a/llama_stack/cli/utils.py +++ b/llama_stack/cli/utils.py @@ -6,6 +6,10 @@ import argparse +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="cli") + def add_config_template_args(parser: argparse.ArgumentParser): """Add unified config/template arguments with backward compatibility.""" @@ -20,12 +24,25 @@ def add_config_template_args(parser: argparse.ArgumentParser): # Backward compatibility arguments (deprecated) group.add_argument( "--config", - dest="config", + dest="config_deprecated", help="(DEPRECATED) Use positional argument [config] instead. Configuration file path", ) group.add_argument( "--template", - dest="config", + dest="template_deprecated", help="(DEPRECATED) Use positional argument [config] instead. Template name", ) + + +def get_config_from_args(args: argparse.Namespace) -> str | None: + """Extract config value from parsed arguments, handling both new and deprecated forms.""" + if args.config is not None: + return str(args.config) + elif hasattr(args, "config_deprecated") and args.config_deprecated is not None: + logger.warning("Using deprecated --config argument. Use positional argument [config] instead.") + return str(args.config_deprecated) + elif hasattr(args, "template_deprecated") and args.template_deprecated is not None: + logger.warning("Using deprecated --template argument. Use positional argument [config] instead.") + return str(args.template_deprecated) + return None diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index e58c28f2e..f05c4ad83 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -32,7 +32,7 @@ from openai import BadRequestError from pydantic import BaseModel, ValidationError from llama_stack.apis.common.responses import PaginatedResponse -from llama_stack.cli.utils import add_config_template_args +from llama_stack.cli.utils import add_config_template_args, get_config_from_args from llama_stack.distribution.access_control.access_control import AccessDeniedError from llama_stack.distribution.datatypes import ( AuthenticationRequiredError, @@ -399,7 +399,8 @@ def main(args: argparse.Namespace | None = None): if args is None: args = parser.parse_args() - config_file = resolve_config_or_template(args.config, Mode.RUN) + config_or_template = get_config_from_args(args) + config_file = resolve_config_or_template(config_or_template, Mode.RUN) logger_config = None with open(config_file) as fp: From fc67ad408a253e367be22a7c67602799fa70a38b Mon Sep 17 00:00:00 2001 From: grs Date: Wed, 23 Jul 2025 09:27:27 +0100 Subject: [PATCH 4/5] chore: add some documentation for access policy rules (#2785) # What does this PR do? Adds some documentation on setting explicit access_policy rules in config. 
--- docs/source/distributions/configuration.md | 119 +++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 9548780c6..6362effe8 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -385,6 +385,125 @@ And must respond with: If no access attributes are returned, the token is used as a namespace. +### Access control + +When authentication is enabled, access to resources is controlled +through the `access_policy` attribute of the auth config section under +server. The value for this is a list of access rules. + +Each access rule defines a list of actions either to permit or to +forbid. It may specify a principal or a resource that must match for +the rule to take effect. + +Valid actions are create, read, update, and delete. The resource to +match should be specified in the form of a type qualified identifier, +e.g. model::my-model or vector_db::some-db, or a wildcard for all +resources of a type, e.g. model::*. If the principal or resource are +not specified, they will match all requests. + +The valid resource types are model, shield, vector_db, dataset, +scoring_function, benchmark, tool, tool_group and session. + +A rule may also specify a condition, either a 'when' or an 'unless', +with additional constraints as to where the rule applies. The +constraints supported at present are: + + - 'user with in ' + - 'user with not in ' + - 'user is owner' + - 'user is not owner' + - 'user in owners ' + - 'user not in owners ' + +The attributes defined for a user will depend on how the auth +configuration is defined. + +When checking whether a particular action is allowed by the current +user for a resource, all the defined rules are tested in order to find +a match. If a match is found, the request is permitted or forbidden +depending on the type of rule. If no match is found, the request is +denied. + +If no explicit rules are specified, a default policy is defined with +which all users can access all resources defined in config but +resources created dynamically can only be accessed by the user that +created them. 
+ +Examples: + +The following restricts access to particular github users: + +```yaml +server: + auth: + provider_config: + type: "github_token" + github_api_base_url: "https://api.github.com" + access_policy: + - permit: + principal: user-1 + actions: [create, read, delete] + description: user-1 has full access to all resources + - permit: + principal: user-2 + actions: [read] + resource: model::model-1 + description: user-2 has read access to model-1 only +``` + +Similarly, the following restricts access to particular kubernetes +service accounts: + +```yaml +server: + auth: + provider_config: + type: "oauth2_token" + audience: https://kubernetes.default.svc.cluster.local + issuer: https://kubernetes.default.svc.cluster.local + tls_cafile: /home/gsim/.minikube/ca.crt + jwks: + uri: https://kubernetes.default.svc.cluster.local:8443/openid/v1/jwks + token: ${env.TOKEN} + access_policy: + - permit: + principal: system:serviceaccount:my-namespace:my-serviceaccount + actions: [create, read, delete] + description: specific serviceaccount has full access to all resources + - permit: + principal: system:serviceaccount:default:default + actions: [read] + resource: model::model-1 + description: default account has read access to model-1 only +``` + +The following policy, which assumes that users are defined with roles +and teams by whichever authentication system is in use, allows any +user with a valid token to use models, create resources other than +models, read and delete resources they created and read resources +created by users sharing a team with them: + +``` + access_policy: + - permit: + actions: [read] + resource: model::* + description: all users have read access to models + - forbid: + actions: [create, delete] + resource: model::* + unless: user with admin in roles + description: only user with admin role can create or delete models + - permit: + actions: [create, read, delete] + when: user is owner + description: users can create resources other than models and read and delete those they own + - permit: + actions: [read] + when: user in owner teams + description: any user has read access to any resource created by a user with the same team +``` + ### Quota Configuration The `quota` section allows you to enable server-side request throttling for both From e1ed1527795170c9f14eb43d1ba163926eb148a8 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 23 Jul 2025 06:49:40 -0400 Subject: [PATCH 5/5] chore: create OpenAIMixin for inference providers with an OpenAI-compat API that need to implement openai_* methods (#2835) # What does this PR do? add an `OpenAIMixin` for use by inference providers who remote endpoints support an OpenAI compatible API. 
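As a sketch of intended usage (the class name and config fields below are illustrative, not part of this PR; only `get_api_key()` and `get_base_url()` must be implemented):

```python
# Illustrative adapter built on OpenAIMixin; class and config names are made up.
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleOpenAICompatAdapter(OpenAIMixin):
    def __init__(self, config) -> None:
        self.config = config  # assumed to carry api_key and base_url

    def get_api_key(self) -> str:
        return self.config.api_key

    def get_base_url(self) -> str:
        return self.config.base_url
```

With those two methods in place, the mixin supplies the `client` property plus `openai_completion()`, `openai_chat_completion()` and `openai_embeddings()`.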
use is demonstrated by refactoring - OpenAIInferenceAdapter - NVIDIAInferenceAdapter (adds embedding support) - LlamaCompatInferenceAdapter ## Test Plan existing unit and integration tests --- docs/source/contributing/new_api_provider.md | 35 +++ .../inference/llama_openai_compat/llama.py | 43 ++- .../remote/inference/nvidia/nvidia.py | 191 ++---------- .../remote/inference/openai/openai.py | 223 ++------------ .../utils/inference/model_registry.py | 6 + .../providers/utils/inference/openai_mixin.py | 272 ++++++++++++++++++ .../test_inference_client_caching.py | 19 +- 7 files changed, 402 insertions(+), 387 deletions(-) create mode 100644 llama_stack/providers/utils/inference/openai_mixin.py diff --git a/docs/source/contributing/new_api_provider.md b/docs/source/contributing/new_api_provider.md index 83058896a..01a8ec093 100644 --- a/docs/source/contributing/new_api_provider.md +++ b/docs/source/contributing/new_api_provider.md @@ -14,6 +14,41 @@ Here are some example PRs to help you get started: - [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355) - [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665) +## Inference Provider Patterns + +When implementing Inference providers for OpenAI-compatible APIs, Llama Stack provides several mixin classes to simplify development and ensure consistent behavior across providers. + +### OpenAIMixin + +The `OpenAIMixin` class provides direct OpenAI API functionality for providers that work with OpenAI-compatible endpoints. It includes: + +#### Direct API Methods +- **`openai_completion()`**: Legacy text completion API with full parameter support +- **`openai_chat_completion()`**: Chat completion API supporting streaming, tools, and function calling +- **`openai_embeddings()`**: Text embeddings generation with customizable encoding and dimensions + +#### Model Management +- **`check_model_availability()`**: Queries the API endpoint to verify if a model exists and is accessible + +#### Client Management +- **`client` property**: Automatically creates and configures AsyncOpenAI client instances using your provider's credentials + +#### Required Implementation + +To use `OpenAIMixin`, your provider must implement these abstract methods: + +```python +@abstractmethod +def get_api_key(self) -> str: + """Return the API key for authentication""" + pass + + +@abstractmethod +def get_base_url(self) -> str: + """Return the OpenAI-compatible API base URL""" + pass +``` ## Testing the Provider diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index 5f9cb20b2..576080d99 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py @@ -5,17 +5,27 @@ # the root directory of this source tree. import logging -from llama_api_client import AsyncLlamaAPIClient, NotFoundError - from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .models import MODEL_ENTRIES logger = logging.getLogger(__name__) -class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin): +class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): + """ + Llama API Inference Adapter for Llama Stack. 
+ + Note: The inheritance order is important here. OpenAIMixin must come before + LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability() + is used instead of ModelRegistryHelper.check_model_availability(). + + - OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists + - ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning + """ + _config: LlamaCompatConfig def __init__(self, config: LlamaCompatConfig): @@ -28,32 +38,19 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin): ) self.config = config - async def check_model_availability(self, model: str) -> bool: + # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self) -> str: """ - Check if a specific model is available from Llama API. + Get the base URL for OpenAI mixin. - :param model: The model identifier to check. - :return: True if the model is available dynamically, False otherwise. + :return: The Llama API base URL """ - try: - llama_api_client = self._get_llama_api_client() - retrieved_model = await llama_api_client.models.retrieve(model) - logger.info(f"Model {retrieved_model.id} is available from Llama API") - return True - - except NotFoundError: - logger.error(f"Model {model} is not available from Llama API") - return False - - except Exception as e: - logger.error(f"Failed to check model availability from Llama API: {e}") - return False + return self.config.openai_compat_api_base async def initialize(self): await super().initialize() async def shutdown(self): await super().shutdown() - - def _get_llama_api_client(self) -> AsyncLlamaAPIClient: - return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index cb7554523..7bc3fd0c9 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -7,9 +7,8 @@ import logging import warnings from collections.abc import AsyncIterator -from typing import Any -from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError +from openai import APIConnectionError, BadRequestError from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -28,12 +27,6 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAICompletion, - OpenAIEmbeddingsResponse, - OpenAIMessageParam, - OpenAIResponseFormatParam, ResponseFormat, SamplingParams, TextTruncation, @@ -47,8 +40,8 @@ from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.openai_compat import ( convert_openai_chat_completion_choice, convert_openai_chat_completion_stream, - prepare_openai_completion_params, ) +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.prompt_adapter import content_has_media from . import NVIDIAConfig @@ -64,7 +57,20 @@ from .utils import _is_nvidia_hosted logger = logging.getLogger(__name__) -class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): +class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): + """ + NVIDIA Inference Adapter for Llama Stack. + + Note: The inheritance order is important here. 
OpenAIMixin must come before + ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability() + is used instead of ModelRegistryHelper.check_model_availability(). It also + must come before Inference to ensure that OpenAIMixin methods are available + in the Inference interface. + + - OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists + - ModelRegistryHelper.check_model_availability() just returns False and shows a warning + """ + def __init__(self, config: NVIDIAConfig) -> None: # TODO(mf): filter by available models ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) @@ -88,45 +94,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): self._config = config - async def check_model_availability(self, model: str) -> bool: + def get_api_key(self) -> str: """ - Check if a specific model is available. + Get the API key for OpenAI mixin. - :param model: The model identifier to check. - :return: True if the model is available dynamically, False otherwise. + :return: The NVIDIA API key """ - try: - await self._client.models.retrieve(model) - return True - except NotFoundError: - logger.error(f"Model {model} is not available") - except Exception as e: - logger.error(f"Failed to check model availability: {e}") - return False + return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY" - @property - def _client(self) -> AsyncOpenAI: + def get_base_url(self) -> str: """ - Returns an OpenAI client for the configured NVIDIA API endpoint. + Get the base URL for OpenAI mixin. - :return: An OpenAI client + :return: The NVIDIA API base URL """ - - base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url - - return AsyncOpenAI( - base_url=base_url, - api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), - timeout=self._config.timeout, - ) - - async def _get_provider_model_id(self, model_id: str) -> str: - if not self.model_store: - raise RuntimeError("Model store is not set") - model = await self.model_store.get_model(model_id) - if model is None: - raise ValueError(f"Model {model_id} is unknown") - return model.provider_model_id + return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url async def completion( self, @@ -160,7 +142,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await self._client.completions.create(**request) + response = await self.client.completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -213,7 +195,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): extra_body["input_type"] = task_type_options[task_type] try: - response = await self._client.embeddings.create( + response = await self.client.embeddings.create( model=provider_model_id, input=input, extra_body=extra_body, @@ -228,16 +210,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): # return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data]) - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - async def chat_completion( self, model_id: str, @@ -274,7 +246,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await 
self._client.chat.completions.create(**request) + response = await self.client.chat.completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -283,112 +255,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): else: # we pass n=1 to get only one completion return convert_openai_chat_completion_choice(response.choices[0]) - - async def openai_completion( - self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, - ) -> OpenAICompletion: - provider_model_id = await self._get_provider_model_id(model) - - params = await prepare_openai_completion_params( - model=provider_model_id, - prompt=prompt, - best_of=best_of, - echo=echo, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - logprobs=logprobs, - max_tokens=max_tokens, - n=n, - presence_penalty=presence_penalty, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - top_p=top_p, - user=user, - ) - - try: - return await self._client.completions.create(**params) - except APIConnectionError as e: - raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e - - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - provider_model_id = await self._get_provider_model_id(model) - - params = await prepare_openai_completion_params( - model=provider_model_id, - messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - 
top_p=top_p, - user=user, - ) - - try: - return await self._client.chat.completions.create(**params) - except APIConnectionError as e: - raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 7e167f621..9e1b77bde 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -5,23 +5,9 @@ # the root directory of this source tree. import logging -from collections.abc import AsyncIterator -from typing import Any -from openai import AsyncOpenAI, NotFoundError - -from llama_stack.apis.inference import ( - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAICompletion, - OpenAIEmbeddingData, - OpenAIEmbeddingsResponse, - OpenAIEmbeddingUsage, - OpenAIMessageParam, - OpenAIResponseFormatParam, -) from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin -from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import OpenAIConfig from .models import MODEL_ENTRIES @@ -30,7 +16,7 @@ logger = logging.getLogger(__name__) # -# This OpenAI adapter implements Inference methods using two clients - +# This OpenAI adapter implements Inference methods using two mixins - # # | Inference Method | Implementation Source | # |----------------------------|--------------------------| @@ -39,11 +25,22 @@ logger = logging.getLogger(__name__) # | embedding | LiteLLMOpenAIMixin | # | batch_completion | LiteLLMOpenAIMixin | # | batch_chat_completion | LiteLLMOpenAIMixin | -# | openai_completion | AsyncOpenAI | -# | openai_chat_completion | AsyncOpenAI | -# | openai_embeddings | AsyncOpenAI | +# | openai_completion | OpenAIMixin | +# | openai_chat_completion | OpenAIMixin | +# | openai_embeddings | OpenAIMixin | # -class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): +class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): + """ + OpenAI Inference Adapter for Llama Stack. + + Note: The inheritance order is important here. OpenAIMixin must come before + LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability() + is used instead of ModelRegistryHelper.check_model_availability(). + + - OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists + - ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning + """ + def __init__(self, config: OpenAIConfig) -> None: LiteLLMOpenAIMixin.__init__( self, @@ -60,191 +57,19 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): # litellm specific model names, an abstraction leak. self.is_openai_compat = True - async def check_model_availability(self, model: str) -> bool: + # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self) -> str: """ - Check if a specific model is available from OpenAI. + Get the OpenAI API base URL. - :param model: The model identifier to check. - :return: True if the model is available dynamically, False otherwise. + Returns the standard OpenAI API base URL for direct OpenAI API calls. 
""" - try: - openai_client = self._get_openai_client() - retrieved_model = await openai_client.models.retrieve(model) - logger.info(f"Model {retrieved_model.id} is available from OpenAI") - return True - - except NotFoundError: - logger.error(f"Model {model} is not available from OpenAI") - return False - - except Exception as e: - logger.error(f"Failed to check model availability from OpenAI: {e}") - return False + return "https://api.openai.com/v1" async def initialize(self) -> None: await super().initialize() async def shutdown(self) -> None: await super().shutdown() - - def _get_openai_client(self) -> AsyncOpenAI: - return AsyncOpenAI( - api_key=self.get_api_key(), - ) - - async def openai_completion( - self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, - ) -> OpenAICompletion: - if guided_choice is not None: - logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.") - if prompt_logprobs is not None: - logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.") - - model_id = (await self.model_store.get_model(model)).provider_resource_id - if model_id.startswith("openai/"): - model_id = model_id[len("openai/") :] - params = await prepare_openai_completion_params( - model=model_id, - prompt=prompt, - best_of=best_of, - echo=echo, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - logprobs=logprobs, - max_tokens=max_tokens, - n=n, - presence_penalty=presence_penalty, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - top_p=top_p, - user=user, - suffix=suffix, - ) - return await self._get_openai_client().completions.create(**params) - - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - model_id = (await self.model_store.get_model(model)).provider_resource_id - if model_id.startswith("openai/"): - model_id = model_id[len("openai/") :] - params = await prepare_openai_completion_params( - model=model_id, - 
messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, - ) - return await self._get_openai_client().chat.completions.create(**params) - - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - model_id = (await self.model_store.get_model(model)).provider_resource_id - if model_id.startswith("openai/"): - model_id = model_id[len("openai/") :] - - # Prepare parameters for OpenAI embeddings API - params = { - "model": model_id, - "input": input, - } - - if encoding_format is not None: - params["encoding_format"] = encoding_format - if dimensions is not None: - params["dimensions"] = dimensions - if user is not None: - params["user"] = user - - # Call OpenAI embeddings API - response = await self._get_openai_client().embeddings.create(**params) - - data = [] - for i, embedding_data in enumerate(response.data): - data.append( - OpenAIEmbeddingData( - embedding=embedding_data.embedding, - index=i, - ) - ) - - usage = OpenAIEmbeddingUsage( - prompt_tokens=response.usage.prompt_tokens, - total_tokens=response.usage.total_tokens, - ) - - return OpenAIEmbeddingsResponse( - data=data, - model=response.model, - usage=usage, - ) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 801b8ea06..651d58e2a 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -10,12 +10,15 @@ from pydantic import BaseModel, Field from llama_stack.apis.common.errors import UnsupportedModelError from llama_stack.apis.models import ModelType +from llama_stack.log import get_logger from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference import ( ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ) +logger = get_logger(name=__name__, category="core") + # TODO: this class is more confusing than useful right now. We need to make it # more closer to the Model class. @@ -98,6 +101,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate): :param model: The model identifier to check. :return: True if the model is available dynamically, False otherwise. """ + logger.info( + f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default." + ) return False async def register_model(self, model: Model) -> Model: diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py new file mode 100644 index 000000000..72286dffb --- /dev/null +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from typing import Any + +import openai +from openai import NOT_GIVEN, AsyncOpenAI + +from llama_stack.apis.inference import ( + Model, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAICompletion, + OpenAIEmbeddingData, + OpenAIEmbeddingsResponse, + OpenAIEmbeddingUsage, + OpenAIMessageParam, + OpenAIResponseFormatParam, +) +from llama_stack.log import get_logger +from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params + +logger = get_logger(name=__name__, category="core") + + +class OpenAIMixin(ABC): + """ + Mixin class that provides OpenAI-specific functionality for inference providers. + This class handles direct OpenAI API calls using the AsyncOpenAI client. + + This is an abstract base class that requires child classes to implement: + - get_api_key(): Method to retrieve the API key + - get_base_url(): Method to retrieve the OpenAI-compatible API base URL + + Expected Dependencies: + - self.model_store: Injected by the Llama Stack distribution system at runtime. + This provides model registry functionality for looking up registered models. + The model_store is set in routing_tables/common.py during provider initialization. + """ + + @abstractmethod + def get_api_key(self) -> str: + """ + Get the API key. + + This method must be implemented by child classes to provide the API key + for authenticating with the OpenAI API or compatible endpoints. + + :return: The API key as a string + """ + pass + + @abstractmethod + def get_base_url(self) -> str: + """ + Get the OpenAI-compatible API base URL. + + This method must be implemented by child classes to provide the base URL + for the OpenAI API or compatible endpoints (e.g., "https://api.openai.com/v1"). + + :return: The base URL as a string + """ + pass + + @property + def client(self) -> AsyncOpenAI: + """ + Get an AsyncOpenAI client instance. + + Uses the abstract methods get_api_key() and get_base_url() which must be + implemented by child classes. + """ + return AsyncOpenAI( + api_key=self.get_api_key(), + base_url=self.get_base_url(), + ) + + async def _get_provider_model_id(self, model: str) -> str: + """ + Get the provider-specific model ID from the model store. + + This is a utility method that looks up the registered model and returns + the provider_resource_id that should be used for actual API calls. 
+ + :param model: The registered model name/identifier + :return: The provider-specific model ID (e.g., "gpt-4") + """ + # Look up the registered model to get the provider-specific model ID + # self.model_store is injected by the distribution system at runtime + model_obj: Model = await self.model_store.get_model(model) # type: ignore[attr-defined] + # provider_resource_id is str | None, but we expect it to be str for OpenAI calls + if model_obj.provider_resource_id is None: + raise ValueError(f"Model {model} has no provider_resource_id") + return model_obj.provider_resource_id + + async def openai_completion( + self, + model: str, + prompt: str | list[str] | list[int] | list[list[int]], + best_of: int | None = None, + echo: bool | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_tokens: int | None = None, + n: int | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + top_p: float | None = None, + user: str | None = None, + guided_choice: list[str] | None = None, + prompt_logprobs: int | None = None, + suffix: str | None = None, + ) -> OpenAICompletion: + """ + Direct OpenAI completion API call. + """ + if guided_choice is not None: + logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.") + if prompt_logprobs is not None: + logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.") + + # TODO: fix openai_completion to return type compatible with OpenAI's API response + return await self.client.completions.create( # type: ignore[no-any-return] + **await prepare_openai_completion_params( + model=await self._get_provider_model_id(model), + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + suffix=suffix, + ) + ) + + async def openai_chat_completion( + self, + model: str, + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + """ + Direct OpenAI chat completion API call. 
+ """ + # Type ignore because return types are compatible + return await self.client.chat.completions.create( # type: ignore[no-any-return] + **await prepare_openai_completion_params( + model=await self._get_provider_model_id(model), + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + ) + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + """ + Direct OpenAI embeddings API call. + """ + # Call OpenAI embeddings API with properly typed parameters + response = await self.client.embeddings.create( + model=await self._get_provider_model_id(model), + input=input, + encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN, + dimensions=dimensions if dimensions is not None else NOT_GIVEN, + user=user if user is not None else NOT_GIVEN, + ) + + data = [] + for i, embedding_data in enumerate(response.data): + data.append( + OpenAIEmbeddingData( + embedding=embedding_data.embedding, + index=i, + ) + ) + + usage = OpenAIEmbeddingUsage( + prompt_tokens=response.usage.prompt_tokens, + total_tokens=response.usage.total_tokens, + ) + + return OpenAIEmbeddingsResponse( + data=data, + model=response.model, + usage=usage, + ) + + async def check_model_availability(self, model: str) -> bool: + """ + Check if a specific model is available from OpenAI. + + :param model: The model identifier to check. + :return: True if the model is available dynamically, False otherwise. + """ + try: + # Direct model lookup - returns model or raises NotFoundError + await self.client.models.retrieve(model) + return True + except openai.NotFoundError: + # Model doesn't exist - this is expected for unavailable models + pass + except Exception as e: + # All other errors (auth, rate limit, network, etc.) 
+ logger.warning(f"Failed to check model availability for {model}: {e}") + + return False diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py index c9a931d47..ba36a3e3d 100644 --- a/tests/unit/providers/inference/test_inference_client_caching.py +++ b/tests/unit/providers/inference/test_inference_client_caching.py @@ -10,6 +10,8 @@ from unittest.mock import MagicMock from llama_stack.distribution.request_headers import request_provider_data_context from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter +from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig +from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter from llama_stack.providers.remote.inference.together.config import TogetherImplConfig @@ -50,7 +52,7 @@ def test_openai_provider_openai_client_caching(): with request_provider_data_context( {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} ): - openai_client = inference_adapter._get_openai_client() + openai_client = inference_adapter.client assert openai_client.api_key == api_key @@ -71,3 +73,18 @@ def test_together_provider_openai_client_caching(): assert together_client.client.api_key == api_key openai_client = inference_adapter._get_openai_client() assert openai_client.api_key == api_key + + +def test_llama_compat_provider_openai_client_caching(): + """Ensure the LlamaCompat provider does not cache api keys across client requests""" + config = LlamaCompatConfig() + inference_adapter = LlamaCompatInferenceAdapter(config) + + inference_adapter.__provider_spec__ = MagicMock() + inference_adapter.__provider_spec__.provider_data_validator = ( + "llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator" + ) + + for api_key in ["test1", "test2"]: + with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}): + assert inference_adapter.client.api_key == api_key
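
For reference, a minimal sketch (illustrative only, not part of this patch; the ExampleConfig and ExampleInferenceAdapter names are hypothetical) of how a provider adapter can adopt the new OpenAIMixin, following the same pattern the NVIDIA and OpenAI adapters use above: implement the two abstract methods and inherit `client`, the `openai_*` methods, and `check_model_availability()` from the mixin.

    from pydantic import BaseModel, SecretStr

    from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


    class ExampleConfig(BaseModel):
        # Hypothetical config: a base URL plus an optional API key,
        # mirroring the NVIDIA and OpenAI configs in this patch.
        url: str = "https://api.example.com"
        api_key: SecretStr | None = None


    class ExampleInferenceAdapter(OpenAIMixin):
        def __init__(self, config: ExampleConfig) -> None:
            self._config = config

        def get_api_key(self) -> str:
            # Fall back to a placeholder when no key is configured,
            # as the NVIDIA adapter does.
            return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"

        def get_base_url(self) -> str:
            # OpenAI-compatible endpoints typically live under /v1.
            return f"{self._config.url}/v1"

With only these two methods implemented, the adapter's `openai_completion`, `openai_chat_completion`, `openai_embeddings`, and `check_model_availability` all come from the mixin; the `self.model_store` dependency is injected by the distribution at runtime, as noted in the OpenAIMixin docstring.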