Merge branch 'main' into migrate-vector-store-helpers

This commit is contained in:
Francisco Arceo 2025-07-23 09:57:33 -04:00 committed by GitHub
commit 07ae065aeb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 822 additions and 472 deletions

View file

@ -99,7 +99,7 @@ jobs:
uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
-k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
--text-model="ollama/llama3.2:3b-instruct-fp16" \
--embedding-model=all-MiniLM-L6-v2 \
--embedding-model=sentence-transformers/all-MiniLM-L6-v2 \
--safety-shield=$SAFETY_MODEL \
--color=yes \
--capture=tee-sys | tee pytest-${{ matrix.test-type }}.log

View file

@ -114,7 +114,7 @@ jobs:
run: |
uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
tests/integration/vector_io \
--embedding-model all-MiniLM-L6-v2
--embedding-model sentence-transformers/all-MiniLM-L6-v2
- name: Check Storage and Memory Available After Tests
if: ${{ always() }}

View file

@ -14,6 +14,41 @@ Here are some example PRs to help you get started:
- [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355)
- [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665)
## Inference Provider Patterns
When implementing Inference providers for OpenAI-compatible APIs, Llama Stack provides several mixin classes to simplify development and ensure consistent behavior across providers.
### OpenAIMixin
The `OpenAIMixin` class provides direct OpenAI API functionality for providers that work with OpenAI-compatible endpoints. It includes:
#### Direct API Methods
- **`openai_completion()`**: Legacy text completion API with full parameter support
- **`openai_chat_completion()`**: Chat completion API supporting streaming, tools, and function calling
- **`openai_embeddings()`**: Text embeddings generation with customizable encoding and dimensions
#### Model Management
- **`check_model_availability()`**: Queries the API endpoint to verify if a model exists and is accessible
#### Client Management
- **`client` property**: Automatically creates and configures AsyncOpenAI client instances using your provider's credentials
#### Required Implementation
To use `OpenAIMixin`, your provider must implement these abstract methods:
```python
@abstractmethod
def get_api_key(self) -> str:
"""Return the API key for authentication"""
pass
@abstractmethod
def get_base_url(self) -> str:
"""Return the OpenAI-compatible API base URL"""
pass
```
## Testing the Provider

View file

@ -385,6 +385,125 @@ And must respond with:
If no access attributes are returned, the token is used as a namespace.
### Access control
When authentication is enabled, access to resources is controlled
through the `access_policy` attribute of the auth config section under
server. The value for this is a list of access rules.
Each access rule defines a list of actions either to permit or to
forbid. It may specify a principal or a resource that must match for
the rule to take effect.
Valid actions are create, read, update, and delete. The resource to
match should be specified in the form of a type qualified identifier,
e.g. model::my-model or vector_db::some-db, or a wildcard for all
resources of a type, e.g. model::*. If the principal or resource are
not specified, they will match all requests.
The valid resource types are model, shield, vector_db, dataset,
scoring_function, benchmark, tool, tool_group and session.
A rule may also specify a condition, either a 'when' or an 'unless',
with additional constraints as to where the rule applies. The
constraints supported at present are:
- 'user with <attr-value> in <attr-name>'
- 'user with <attr-value> not in <attr-name>'
- 'user is owner'
- 'user is not owner'
- 'user in owners <attr-name>'
- 'user not in owners <attr-name>'
The attributes defined for a user will depend on how the auth
configuration is defined.
When checking whether a particular action is allowed by the current
user for a resource, all the defined rules are tested in order to find
a match. If a match is found, the request is permitted or forbidden
depending on the type of rule. If no match is found, the request is
denied.
If no explicit rules are specified, a default policy is defined with
which all users can access all resources defined in config but
resources created dynamically can only be accessed by the user that
created them.
Examples:
The following restricts access to particular github users:
```yaml
server:
auth:
provider_config:
type: "github_token"
github_api_base_url: "https://api.github.com"
access_policy:
- permit:
principal: user-1
actions: [create, read, delete]
description: user-1 has full access to all resources
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
```
Similarly, the following restricts access to particular kubernetes
service accounts:
```yaml
server:
auth:
provider_config:
type: "oauth2_token"
audience: https://kubernetes.default.svc.cluster.local
issuer: https://kubernetes.default.svc.cluster.local
tls_cafile: /home/gsim/.minikube/ca.crt
jwks:
uri: https://kubernetes.default.svc.cluster.local:8443/openid/v1/jwks
token: ${env.TOKEN}
access_policy:
- permit:
principal: system:serviceaccount:my-namespace:my-serviceaccount
actions: [create, read, delete]
description: specific serviceaccount has full access to all resources
- permit:
principal: system:serviceaccount:default:default
actions: [read]
resource: model::model-1
description: default account has read access to model-1 only
```
The following policy, which assumes that users are defined with roles
and teams by whichever authentication system is in use, allows any
user with a valid token to use models, create resources other than
models, read and delete resources they created and read resources
created by users sharing a team with them:
```
access_policy:
- permit:
actions: [read]
resource: model::*
description: all users have read access to models
- forbid:
actions: [create, delete]
resource: model::*
unless: user with admin in roles
description: only user with admin role can create or delete models
- permit:
actions: [create, read, delete]
when: user is owner
description: users can create resources other than models and read and delete those they own
- permit:
actions: [read]
when: user in owner teams
description: any user has read access to any resource created by a user with the same team
```
### Quota Configuration
The `quota` section allows you to enable server-side request throttling for both

View file

@ -6,6 +6,10 @@
import argparse
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="cli")
def add_config_template_args(parser: argparse.ArgumentParser):
"""Add unified config/template arguments with backward compatibility."""
@ -20,12 +24,25 @@ def add_config_template_args(parser: argparse.ArgumentParser):
# Backward compatibility arguments (deprecated)
group.add_argument(
"--config",
dest="config",
dest="config_deprecated",
help="(DEPRECATED) Use positional argument [config] instead. Configuration file path",
)
group.add_argument(
"--template",
dest="config",
dest="template_deprecated",
help="(DEPRECATED) Use positional argument [config] instead. Template name",
)
def get_config_from_args(args: argparse.Namespace) -> str | None:
"""Extract config value from parsed arguments, handling both new and deprecated forms."""
if args.config is not None:
return str(args.config)
elif hasattr(args, "config_deprecated") and args.config_deprecated is not None:
logger.warning("Using deprecated --config argument. Use positional argument [config] instead.")
return str(args.config_deprecated)
elif hasattr(args, "template_deprecated") and args.template_deprecated is not None:
logger.warning("Using deprecated --template argument. Use positional argument [config] instead.")
return str(args.template_deprecated)
return None

View file

@ -19,6 +19,9 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
# Mount command for cache container .cache, can be overridden by the user if needed
MOUNT_CACHE=${MOUNT_CACHE:-"--mount=type=cache,id=llama-stack-cache,target=/root/.cache"}
# Path to the run.yaml file in the container
RUN_CONFIG_PATH=/app/run.yaml
@ -125,11 +128,16 @@ RUN pip install uv
EOF
fi
# Set the link mode to copy so that uv doesn't attempt to symlink to the cache directory
add_to_container << EOF
ENV UV_LINK_MODE=copy
EOF
# Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers.
if [ -n "$pip_dependencies" ]; then
add_to_container << EOF
RUN uv pip install --no-cache $pip_dependencies
RUN $MOUNT_CACHE uv pip install $pip_dependencies
EOF
fi
@ -137,7 +145,7 @@ if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
for part in "${parts[@]}"; do
add_to_container <<EOF
RUN uv pip install --no-cache $part
RUN $MOUNT_CACHE uv pip install $part
EOF
done
fi
@ -207,7 +215,7 @@ COPY $dir $mount_point
EOF
fi
add_to_container << EOF
RUN uv pip install --no-cache -e $mount_point
RUN $MOUNT_CACHE uv pip install -e $mount_point
EOF
}
@ -222,10 +230,10 @@ else
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
add_to_container << EOF
RUN uv pip install fastapi libcst
RUN $MOUNT_CACHE uv pip install fastapi libcst
EOF
add_to_container << EOF
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
RUN $MOUNT_CACHE uv pip install --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-stack==$TEST_PYPI_VERSION
@ -237,7 +245,7 @@ EOF
SPEC_VERSION="llama-stack"
fi
add_to_container << EOF
RUN uv pip install --no-cache $SPEC_VERSION
RUN $MOUNT_CACHE uv pip install $SPEC_VERSION
EOF
fi
fi

View file

@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO):
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.iterrows(
dataset_id=dataset_id,
start_index=start_index,
limit=limit,
@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO):
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.append_rows(
dataset_id=dataset_id,
rows=rows,
)

View file

@ -44,7 +44,8 @@ class ScoringRouter(Scoring):
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
provider = await self.routing_table.get_provider_impl(fn_identifier)
score_response = await provider.score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@ -66,7 +67,8 @@ class ScoringRouter(Scoring):
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
provider = await self.routing_table.get_provider_impl(fn_identifier)
score_response = await provider.score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@ -97,7 +99,8 @@ class EvalRouter(Eval):
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
)
@ -110,7 +113,8 @@ class EvalRouter(Eval):
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
@ -123,7 +127,8 @@ class EvalRouter(Eval):
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_status(benchmark_id, job_id)
async def job_cancel(
self,
@ -131,7 +136,8 @@ class EvalRouter(Eval):
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
provider = await self.routing_table.get_provider_impl(benchmark_id)
await provider.job_cancel(
benchmark_id,
job_id,
)
@ -142,7 +148,8 @@ class EvalRouter(Eval):
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_result(
benchmark_id,
job_id,
)

View file

@ -231,7 +231,7 @@ class InferenceRouter(Inference):
logprobs=logprobs,
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
@ -292,7 +292,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_chat_completion(
model_id=model_id,
messages_batch=messages_batch,
@ -322,7 +322,7 @@ class InferenceRouter(Inference):
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
content=content,
@ -378,7 +378,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
async def embeddings(
@ -395,7 +395,8 @@ class InferenceRouter(Inference):
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
return await self.routing_table.get_provider_impl(model_id).embeddings(
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.embeddings(
model_id=model_id,
contents=contents,
text_truncation=text_truncation,
@ -458,7 +459,7 @@ class InferenceRouter(Inference):
suffix=suffix,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_completion(**params)
async def openai_chat_completion(
@ -538,7 +539,7 @@ class InferenceRouter(Inference):
user=user,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
response_stream = await provider.openai_chat_completion(**params)
if self.store:
@ -575,7 +576,7 @@ class InferenceRouter(Inference):
user=user,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_embeddings(**params)
async def list_chat_completions(

View file

@ -50,7 +50,8 @@ class SafetyRouter(Safety):
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_shield(
shield_id=shield_id,
messages=messages,
params=params,

View file

@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime):
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_db_ids, query_config)
async def insert(
self,
@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
def __init__(
self,
@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime):
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
provider = await self.routing_table.get_provider_impl(tool_name)
return await provider.invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
)

View file

@ -104,7 +104,8 @@ class VectorIORouter(VectorIO):
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
@ -113,7 +114,8 @@ class VectorIORouter(VectorIO):
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.query_chunks(vector_db_id, query, params)
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
@ -146,7 +148,8 @@ class VectorIORouter(VectorIO):
provider_vector_db_id=vector_db_id,
vector_db_name=name,
)
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
return await provider.openai_create_vector_store(
name=name,
file_ids=file_ids,
expires_after=expires_after,
@ -172,9 +175,8 @@ class VectorIORouter(VectorIO):
all_stores = []
for vector_db in vector_dbs:
try:
vector_store = await self.routing_table.get_provider_impl(
vector_db.identifier
).openai_retrieve_vector_store(vector_db.identifier)
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
all_stores.append(vector_store)
except Exception as e:
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")

View file

@ -6,6 +6,7 @@
from typing import Any
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
@ -116,7 +117,7 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
from .benchmarks import BenchmarksRoutingTable
from .datasets import DatasetsRoutingTable
from .models import ModelsRoutingTable
@ -235,3 +236,28 @@ class CommonRoutingTableImpl(RoutingTable):
]
return filtered_objs
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
# first try to get the model by identifier
# this works if model_id is an alias or is of the form provider_id/provider_model_id
model = await routing_table.get_object_by_identifier("model", model_id)
if model is not None:
return model
logger.warning(
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
"searching in all providers. This is only for backwards compatibility and will stop working "
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names."
)
# if not found, this means model_id is an unscoped provider_model_id, we need
# to iterate (given a lack of an efficient index on the KVStore)
models = await routing_table.get_all_with_type("model")
matching_models = [m for m in models if m.provider_resource_id == model_id]
if len(matching_models) == 0:
raise ValueError(f"Model '{model_id}' not found")
if len(matching_models) > 1:
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
return matching_models[0]

View file

@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
@ -36,10 +36,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
model = await self.get_object_by_identifier("model", model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
return model
return await lookup_model(self, model_id)
async def get_provider_impl(self, model_id: str) -> Any:
model = await lookup_model(self, model_id)
return self.impls_by_provider_id[model.provider_id]
async def register_model(
self,
@ -49,24 +50,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
)
if metadata is None:
metadata = {}
if model_type is None:
model_type = ModelType.llm
provider_model_id = provider_model_id or model_id
metadata = metadata or {}
model_type = model_type or ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
# an identifier different than provider_model_id implies it is an alias, so that
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
# so as a general rule we must use the provider_id to disambiguate.
if model_id != provider_model_id:
identifier = model_id
else:
identifier = f"{provider_id}/{provider_model_id}"
model = ModelWithOwner(
identifier=model_id,
identifier=identifier,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,

View file

@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_to_toolgroup: dict[str, str] = {}
# overridden
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
@ -40,7 +40,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
if routing_key in self.tool_to_toolgroup:
routing_key = self.tool_to_toolgroup[routing_key]
return super().get_provider_impl(routing_key, provider_id)
return await super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
if toolgroup_id:
@ -59,7 +59,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
# TODO: kill this Tool vs ToolDef distinction

View file

@ -27,7 +27,7 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
@ -51,8 +51,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
provider_vector_db_id = provider_vector_db_id or vector_db_id
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
@ -62,7 +61,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await self.get_object_by_identifier("model", embedding_model)
model = await lookup_model(self, embedding_model)
if model is None:
raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding:
@ -93,7 +92,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
@ -103,7 +103,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_update_vector_store(
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
@ -115,7 +116,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id)
return result
@ -130,7 +132,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_search_vector_store(
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
@ -148,7 +151,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store(
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -165,7 +169,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store(
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
@ -180,7 +185,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file(
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -191,7 +197,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents(
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -203,7 +210,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file(
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -215,7 +223,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file(
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)

View file

@ -32,7 +32,7 @@ from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_template_args
from llama_stack.cli.utils import add_config_template_args, get_config_from_args
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import (
AuthenticationRequiredError,
@ -399,7 +399,8 @@ def main(args: argparse.Namespace | None = None):
if args is None:
args = parser.parse_args()
config_file = resolve_config_or_template(args.config, Mode.RUN)
config_or_template = get_config_from_args(args)
config_file = resolve_config_or_template(config_or_template, Mode.RUN)
logger_config = None
with open(config_file) as fp:

View file

@ -113,7 +113,7 @@ class ProviderSpec(BaseModel):
class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ...
async def get_provider_impl(self, routing_key: str) -> Any: ...
# TODO: this can now be inlined into RemoteProviderSpec

View file

@ -5,17 +5,27 @@
# the root directory of this source tree.
import logging
from llama_api_client import AsyncLlamaAPIClient, NotFoundError
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
Llama API Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
_config: LlamaCompatConfig
def __init__(self, config: LlamaCompatConfig):
@ -28,32 +38,19 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
)
self.config = config
async def check_model_availability(self, model: str) -> bool:
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Check if a specific model is available from Llama API.
Get the base URL for OpenAI mixin.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
:return: The Llama API base URL
"""
try:
llama_api_client = self._get_llama_api_client()
retrieved_model = await llama_api_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from Llama API")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from Llama API")
return False
except Exception as e:
logger.error(f"Failed to check model availability from Llama API: {e}")
return False
return self.config.openai_compat_api_base
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)

View file

@ -7,9 +7,8 @@
import logging
import warnings
from collections.abc import AsyncIterator
from typing import Any
from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
from openai import APIConnectionError, BadRequestError
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -28,12 +27,6 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -47,8 +40,8 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
from . import NVIDIAConfig
@ -64,7 +57,20 @@ from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__)
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
"""
NVIDIA Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability(). It also
must come before Inference to ensure that OpenAIMixin methods are available
in the Inference interface.
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
"""
def __init__(self, config: NVIDIAConfig) -> None:
# TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@ -88,45 +94,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self._config = config
async def check_model_availability(self, model: str) -> bool:
def get_api_key(self) -> str:
"""
Check if a specific model is available.
Get the API key for OpenAI mixin.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
:return: The NVIDIA API key
"""
try:
await self._client.models.retrieve(model)
return True
except NotFoundError:
logger.error(f"Model {model} is not available")
except Exception as e:
logger.error(f"Failed to check model availability: {e}")
return False
return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
@property
def _client(self) -> AsyncOpenAI:
def get_base_url(self) -> str:
"""
Returns an OpenAI client for the configured NVIDIA API endpoint.
Get the base URL for OpenAI mixin.
:return: An OpenAI client
:return: The NVIDIA API base URL
"""
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
async def _get_provider_model_id(self, model_id: str) -> str:
if not self.model_store:
raise RuntimeError("Model store is not set")
model = await self.model_store.get_model(model_id)
if model is None:
raise ValueError(f"Model {model_id} is unknown")
return model.provider_model_id
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
async def completion(
self,
@ -160,7 +142,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._client.completions.create(**request)
response = await self.client.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -213,7 +195,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self._client.embeddings.create(
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
@ -228,16 +210,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
#
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
@ -274,7 +246,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._client.chat.completions.create(**request)
response = await self.client.chat.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -283,112 +255,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
else:
# we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0])
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
try:
return await self._client.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
try:
return await self._client.chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

View file

@ -5,23 +5,9 @@
# the root directory of this source tree.
import logging
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI, NotFoundError
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
@ -30,7 +16,7 @@ logger = logging.getLogger(__name__)
#
# This OpenAI adapter implements Inference methods using two clients -
# This OpenAI adapter implements Inference methods using two mixins -
#
# | Inference Method | Implementation Source |
# |----------------------------|--------------------------|
@ -39,11 +25,22 @@ logger = logging.getLogger(__name__)
# | embedding | LiteLLMOpenAIMixin |
# | batch_completion | LiteLLMOpenAIMixin |
# | batch_chat_completion | LiteLLMOpenAIMixin |
# | openai_completion | AsyncOpenAI |
# | openai_chat_completion | AsyncOpenAI |
# | openai_embeddings | AsyncOpenAI |
# | openai_completion | OpenAIMixin |
# | openai_chat_completion | OpenAIMixin |
# | openai_embeddings | OpenAIMixin |
#
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
OpenAI Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
@ -60,191 +57,19 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
async def check_model_availability(self, model: str) -> bool:
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Check if a specific model is available from OpenAI.
Get the OpenAI API base URL.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
Returns the standard OpenAI API base URL for direct OpenAI API calls.
"""
try:
openai_client = self._get_openai_client()
retrieved_model = await openai_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from OpenAI")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from OpenAI")
return False
except Exception as e:
logger.error(f"Failed to check model availability from OpenAI: {e}")
return False
return "https://api.openai.com/v1"
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(
api_key=self.get_api_key(),
)
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if guided_choice is not None:
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self._get_openai_client().chat.completions.create(**params)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
# Prepare parameters for OpenAI embeddings API
params = {
"model": model_id,
"input": input,
}
if encoding_format is not None:
params["encoding_format"] = encoding_format
if dimensions is not None:
params["dimensions"] = dimensions
if user is not None:
params["user"] = user
# Call OpenAI embeddings API
response = await self._get_openai_client().embeddings.create(**params)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)

View file

@ -88,7 +88,7 @@ class SentenceTransformerEmbeddingMixin:
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
model=model,
usage=usage,
)

View file

@ -10,12 +10,15 @@ from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
logger = get_logger(name=__name__, category="core")
# TODO: this class is more confusing than useful right now. We need to make it
# more closer to the Model class.
@ -98,6 +101,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
logger.info(
f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default."
)
return False
async def register_model(self, model: Model) -> Model:

View file

@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
import openai
from openai import NOT_GIVEN, AsyncOpenAI
from llama_stack.apis.inference import (
Model,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="core")
class OpenAIMixin(ABC):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
This is an abstract base class that requires child classes to implement:
- get_api_key(): Method to retrieve the API key
- get_base_url(): Method to retrieve the OpenAI-compatible API base URL
Expected Dependencies:
- self.model_store: Injected by the Llama Stack distribution system at runtime.
This provides model registry functionality for looking up registered models.
The model_store is set in routing_tables/common.py during provider initialization.
"""
@abstractmethod
def get_api_key(self) -> str:
"""
Get the API key.
This method must be implemented by child classes to provide the API key
for authenticating with the OpenAI API or compatible endpoints.
:return: The API key as a string
"""
pass
@abstractmethod
def get_base_url(self) -> str:
"""
Get the OpenAI-compatible API base URL.
This method must be implemented by child classes to provide the base URL
for the OpenAI API or compatible endpoints (e.g., "https://api.openai.com/v1").
:return: The base URL as a string
"""
pass
@property
def client(self) -> AsyncOpenAI:
"""
Get an AsyncOpenAI client instance.
Uses the abstract methods get_api_key() and get_base_url() which must be
implemented by child classes.
"""
return AsyncOpenAI(
api_key=self.get_api_key(),
base_url=self.get_base_url(),
)
async def _get_provider_model_id(self, model: str) -> str:
"""
Get the provider-specific model ID from the model store.
This is a utility method that looks up the registered model and returns
the provider_resource_id that should be used for actual API calls.
:param model: The registered model name/identifier
:return: The provider-specific model ID (e.g., "gpt-4")
"""
# Look up the registered model to get the provider-specific model ID
# self.model_store is injected by the distribution system at runtime
model_obj: Model = await self.model_store.get_model(model) # type: ignore[attr-defined]
# provider_resource_id is str | None, but we expect it to be str for OpenAI calls
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id")
return model_obj.provider_resource_id
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
"""
Direct OpenAI completion API call.
"""
if guided_choice is not None:
logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
# TODO: fix openai_completion to return type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return]
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""
Direct OpenAI chat completion API call.
"""
# Type ignore because return types are compatible
return await self.client.chat.completions.create( # type: ignore[no-any-return]
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
"""
Direct OpenAI embeddings API call.
"""
# Call OpenAI embeddings API with properly typed parameters
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
user=user if user is not None else NOT_GIVEN,
)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from OpenAI.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
# Direct model lookup - returns model or raises NotFoundError
await self.client.models.retrieve(model)
return True
except openai.NotFoundError:
# Model doesn't exist - this is expected for unavailable models
pass
except Exception as e:
# All other errors (auth, rate limit, network, etc.)
logger.warning(f"Failed to check model availability for {model}: {e}")
return False

View file

@ -11,15 +11,17 @@ from unittest.mock import AsyncMock
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
class Impl:
@ -104,6 +106,17 @@ class ToolGroupsImpl(Impl):
)
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -115,27 +128,27 @@ async def test_models_routing_table(cached_disk_dist_registry):
models = await table.list_models()
assert len(models.data) == 2
model_ids = {m.identifier for m in models.data}
assert "test-model" in model_ids
assert "test-model-2" in model_ids
assert "test_provider/test-model" in model_ids
assert "test_provider/test-model-2" in model_ids
# Test openai list models
openai_models = await table.openai_list_models()
assert len(openai_models.data) == 2
openai_model_ids = {m.id for m in openai_models.data}
assert "test-model" in openai_model_ids
assert "test-model-2" in openai_model_ids
assert "test_provider/test-model" in openai_model_ids
assert "test_provider/test-model-2" in openai_model_ids
# Test get_object_by_identifier
model = await table.get_object_by_identifier("model", "test-model")
model = await table.get_object_by_identifier("model", "test_provider/test-model")
assert model is not None
assert model.identifier == "test-model"
assert model.identifier == "test_provider/test-model"
# Test get_object_by_identifier on non-existent object
non_existent = await table.get_object_by_identifier("model", "non-existent-model")
assert non_existent is None
await table.unregister_model(model_id="test-model")
await table.unregister_model(model_id="test-model-2")
await table.unregister_model(model_id="test_provider/test-model")
await table.unregister_model(model_id="test_provider/test-model-2")
models = await table.list_models()
assert len(models.data) == 0
@ -160,6 +173,36 @@ async def test_shields_routing_table(cached_disk_dist_registry):
assert "test-shield-2" in shield_ids
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -245,3 +288,93 @@ async def test_tool_groups_routing_table(cached_disk_dist_registry):
await table.unregister_toolgroup(toolgroup_id="test-toolgroup")
tool_groups = await table.list_tool_groups()
assert len(tool_groups.data) == 0
async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
"""Test alias registration (model_id != provider_model_id) and lookup behavior."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register model with alias (model_id different from provider_model_id)
await table.register_model(
model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider"
)
# Verify the model was registered with alias as identifier (not namespaced)
models = await table.list_models()
assert len(models.data) == 1
model = models.data[0]
assert model.identifier == "my-alias" # Uses alias as identifier
assert model.provider_resource_id == "actual-provider-model"
# Test lookup by alias works
retrieved_model = await table.get_model("my-alias")
assert retrieved_model.identifier == "my-alias"
assert retrieved_model.provider_resource_id == "actual-provider-model"
async def test_models_multi_provider_disambiguation(cached_disk_dist_registry):
"""Test registration and lookup with multiple providers having same provider_model_id."""
table = ModelsRoutingTable(
{"provider1": InferenceImpl(), "provider2": InferenceImpl()}, cached_disk_dist_registry, {}
)
await table.initialize()
# Register same provider_model_id on both providers (no aliases)
await table.register_model(model_id="common-model", provider_id="provider1")
await table.register_model(model_id="common-model", provider_id="provider2")
# Verify both models get namespaced identifiers
models = await table.list_models()
assert len(models.data) == 2
identifiers = {m.identifier for m in models.data}
assert identifiers == {"provider1/common-model", "provider2/common-model"}
# Test lookup by full namespaced identifier works
model1 = await table.get_model("provider1/common-model")
assert model1.provider_id == "provider1"
assert model1.provider_resource_id == "common-model"
model2 = await table.get_model("provider2/common-model")
assert model2.provider_id == "provider2"
assert model2.provider_resource_id == "common-model"
# Test lookup by unscoped provider_model_id fails with multiple providers error
try:
await table.get_model("common-model")
raise AssertionError("Should have raised ValueError for multiple providers")
except ValueError as e:
assert "Multiple providers found" in str(e)
assert "provider1" in str(e) and "provider2" in str(e)
async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
"""Test two-stage lookup: direct identifier hit vs fallback to provider_resource_id."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register model without alias (gets namespaced identifier)
await table.register_model(model_id="test-model", provider_id="test_provider")
# Verify namespaced identifier was created
models = await table.list_models()
assert len(models.data) == 1
model = models.data[0]
assert model.identifier == "test_provider/test-model"
assert model.provider_resource_id == "test-model"
# Test lookup by full namespaced identifier (direct hit via get_object_by_identifier)
retrieved_model = await table.get_model("test_provider/test-model")
assert retrieved_model.identifier == "test_provider/test-model"
# Test lookup by unscoped provider_model_id (fallback via iteration)
retrieved_model = await table.get_model("test-model")
assert retrieved_model.identifier == "test_provider/test-model"
assert retrieved_model.provider_resource_id == "test-model"
# Test lookup of non-existent model fails
try:
await table.get_model("non-existent")
raise AssertionError("Should have raised ValueError for non-existent model")
except ValueError as e:
assert "not found" in str(e)

View file

@ -10,6 +10,8 @@ from unittest.mock import MagicMock
from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
@ -50,7 +52,7 @@ def test_openai_provider_openai_client_caching():
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter._get_openai_client()
openai_client = inference_adapter.client
assert openai_client.api_key == api_key
@ -71,3 +73,18 @@ def test_together_provider_openai_client_caching():
assert together_client.client.api_key == api_key
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
def test_llama_compat_provider_openai_client_caching():
"""Ensure the LlamaCompat provider does not cache api keys across client requests"""
config = LlamaCompatConfig()
inference_adapter = LlamaCompatInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
assert inference_adapter.client.api_key == api_key