Merge branch 'main' into migrate-vector-store-helpers

Francisco Arceo 2025-07-23 09:57:33 -04:00 committed by GitHub
commit 07ae065aeb
26 changed files with 822 additions and 472 deletions


@@ -99,7 +99,7 @@ jobs:
          uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
            -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
            --text-model="ollama/llama3.2:3b-instruct-fp16" \
-           --embedding-model=all-MiniLM-L6-v2 \
+           --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \
            --safety-shield=$SAFETY_MODEL \
            --color=yes \
            --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log


@@ -114,7 +114,7 @@ jobs:
        run: |
          uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
            tests/integration/vector_io \
-           --embedding-model all-MiniLM-L6-v2
+           --embedding-model sentence-transformers/all-MiniLM-L6-v2
      - name: Check Storage and Memory Available After Tests
        if: ${{ always() }}


@@ -14,6 +14,41 @@ Here are some example PRs to help you get started:
- [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355)
- [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665)

## Inference Provider Patterns
When implementing Inference providers for OpenAI-compatible APIs, Llama Stack provides several mixin classes to simplify development and ensure consistent behavior across providers.
### OpenAIMixin
The `OpenAIMixin` class provides direct OpenAI API functionality for providers that work with OpenAI-compatible endpoints. It includes:
#### Direct API Methods
- **`openai_completion()`**: Legacy text completion API with full parameter support
- **`openai_chat_completion()`**: Chat completion API supporting streaming, tools, and function calling
- **`openai_embeddings()`**: Text embeddings generation with customizable encoding and dimensions
#### Model Management
- **`check_model_availability()`**: Queries the API endpoint to verify if a model exists and is accessible
#### Client Management
- **`client` property**: Automatically creates and configures AsyncOpenAI client instances using your provider's credentials
#### Required Implementation
To use `OpenAIMixin`, your provider must implement these abstract methods:
```python
@abstractmethod
def get_api_key(self) -> str:
    """Return the API key for authentication"""
    pass


@abstractmethod
def get_base_url(self) -> str:
    """Return the OpenAI-compatible API base URL"""
    pass
```
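
Putting the pieces together, a provider built on `OpenAIMixin` usually only needs to supply credentials and an endpoint; `openai_completion()`, `openai_chat_completion()`, `openai_embeddings()` and `check_model_availability()` all come from the mixin. The following is a minimal sketch only: the `ExampleInferenceAdapter` class and its `config.api_key` / `config.url` fields are hypothetical and not part of Llama Stack.

```python
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleInferenceAdapter(OpenAIMixin):
    """Hypothetical adapter for an OpenAI-compatible endpoint (illustration only)."""

    def __init__(self, config) -> None:
        self.config = config  # assumed to expose api_key and url

    def get_api_key(self) -> str:
        # Credential used by the mixin's AsyncOpenAI client.
        return self.config.api_key

    def get_base_url(self) -> str:
        # OpenAI-compatible endpoint the client should talk to.
        return self.config.url
```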

## Testing the Provider


@@ -385,6 +385,125 @@ And must respond with:
If no access attributes are returned, the token is used as a namespace.

### Access control

When authentication is enabled, access to resources is controlled
through the `access_policy` attribute of the auth config section under
server. The value for this is a list of access rules.

Each access rule defines a list of actions either to permit or to
forbid. It may specify a principal or a resource that must match for
the rule to take effect.

Valid actions are create, read, update, and delete. The resource to
match should be specified in the form of a type-qualified identifier,
e.g. model::my-model or vector_db::some-db, or a wildcard for all
resources of a type, e.g. model::*. If the principal or resource is
not specified, the rule matches all requests.

The valid resource types are model, shield, vector_db, dataset,
scoring_function, benchmark, tool, tool_group and session.

A rule may also specify a condition, either a 'when' or an 'unless',
that further constrains when the rule applies. The conditions
supported at present are:

- 'user with <attr-value> in <attr-name>'
- 'user with <attr-value> not in <attr-name>'
- 'user is owner'
- 'user is not owner'
- 'user in owners <attr-name>'
- 'user not in owners <attr-name>'

The attributes defined for a user depend on how the auth
configuration is defined.

When checking whether a particular action is allowed by the current
user for a resource, the defined rules are tested in order until a
match is found. If a match is found, the request is permitted or
forbidden depending on the type of rule. If no match is found, the
request is denied.

If no explicit rules are specified, a default policy applies: all
users can access all resources defined in the config, but resources
created dynamically can only be accessed by the user that created
them.
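
The evaluation order described above (the first matching rule decides; deny if nothing matches) can be sketched in a few lines of Python. This is an illustration of the semantics only, not the server's implementation; the plain-dict rule shape and the `conditions_hold` callback are assumptions.

```python
from collections.abc import Callable


def resource_matches(pattern: str, resource: str) -> bool:
    """'model::*' matches any model; otherwise require an exact match."""
    ptype, _, pid = pattern.partition("::")
    rtype, _, rid = resource.partition("::")
    return ptype == rtype and pid in ("*", rid)


def is_allowed(
    rules: list[dict],
    principal: str,
    action: str,
    resource: str,
    conditions_hold: Callable[[str], bool],
) -> bool:
    for rule in rules:
        verb = "permit" if "permit" in rule else "forbid"
        scope = rule[verb]
        if action not in scope.get("actions", []):
            continue
        if "principal" in scope and scope["principal"] != principal:
            continue
        if "resource" in scope and not resource_matches(scope["resource"], resource):
            continue
        if "when" in rule and not conditions_hold(rule["when"]):
            continue
        if "unless" in rule and conditions_hold(rule["unless"]):
            continue
        return verb == "permit"  # the first matching rule decides
    return False  # no rule matched: deny
```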
Examples:

The following restricts access to particular GitHub users:
```yaml
server:
  auth:
    provider_config:
      type: "github_token"
      github_api_base_url: "https://api.github.com"
    access_policy:
      - permit:
          principal: user-1
          actions: [create, read, delete]
        description: user-1 has full access to all resources
      - permit:
          principal: user-2
          actions: [read]
          resource: model::model-1
        description: user-2 has read access to model-1 only
```
Similarly, the following restricts access to particular kubernetes
service accounts:
```yaml
server:
  auth:
    provider_config:
      type: "oauth2_token"
      audience: https://kubernetes.default.svc.cluster.local
      issuer: https://kubernetes.default.svc.cluster.local
      tls_cafile: /home/gsim/.minikube/ca.crt
      jwks:
        uri: https://kubernetes.default.svc.cluster.local:8443/openid/v1/jwks
        token: ${env.TOKEN}
    access_policy:
      - permit:
          principal: system:serviceaccount:my-namespace:my-serviceaccount
          actions: [create, read, delete]
        description: specific serviceaccount has full access to all resources
      - permit:
          principal: system:serviceaccount:default:default
          actions: [read]
          resource: model::model-1
        description: default account has read access to model-1 only
```
The following policy, which assumes that users are defined with roles
and teams by whichever authentication system is in use, allows any
user with a valid token to use models, create resources other than
models, read and delete resources they created and read resources
created by users sharing a team with them:
```yaml
access_policy:
  - permit:
      actions: [read]
      resource: model::*
    description: all users have read access to models
  - forbid:
      actions: [create, delete]
      resource: model::*
    unless: user with admin in roles
    description: only user with admin role can create or delete models
  - permit:
      actions: [create, read, delete]
    when: user is owner
    description: users can create resources other than models and read and delete those they own
  - permit:
      actions: [read]
    when: user in owners teams
    description: any user has read access to any resource created by a user with the same team
```
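
Tying this back to the evaluation sketch shown earlier, the policy above could be exercised roughly as follows; the rule dictionaries and the condition stub are assumptions for illustration only.

```python
policy = [
    {"permit": {"actions": ["read"], "resource": "model::*"}},
    {"forbid": {"actions": ["create", "delete"], "resource": "model::*"},
     "unless": "user with admin in roles"},
    {"permit": {"actions": ["create", "read", "delete"]}, "when": "user is owner"},
    {"permit": {"actions": ["read"]}, "when": "user in owners teams"},
]


def conditions_hold(condition: str) -> bool:
    # Stand-in: pretend the caller has the admin role and owns the resource.
    return condition in ("user with admin in roles", "user is owner")


print(is_allowed(policy, "some-user", "delete", "model::model-1", conditions_hold))      # True
print(is_allowed(policy, "some-user", "create", "vector_db::some-db", conditions_hold))  # True
```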

### Quota Configuration

The `quota` section allows you to enable server-side request throttling for both


@@ -6,6 +6,10 @@
import argparse

+from llama_stack.log import get_logger
+
+logger = get_logger(name=__name__, category="cli")
+

def add_config_template_args(parser: argparse.ArgumentParser):
    """Add unified config/template arguments with backward compatibility."""
@@ -20,12 +24,25 @@ def add_config_template_args(parser: argparse.ArgumentParser):
    # Backward compatibility arguments (deprecated)
    group.add_argument(
        "--config",
-        dest="config",
+        dest="config_deprecated",
        help="(DEPRECATED) Use positional argument [config] instead. Configuration file path",
    )
    group.add_argument(
        "--template",
-        dest="config",
+        dest="template_deprecated",
        help="(DEPRECATED) Use positional argument [config] instead. Template name",
    )
+
+
+def get_config_from_args(args: argparse.Namespace) -> str | None:
+    """Extract config value from parsed arguments, handling both new and deprecated forms."""
+    if args.config is not None:
+        return str(args.config)
+    elif hasattr(args, "config_deprecated") and args.config_deprecated is not None:
+        logger.warning("Using deprecated --config argument. Use positional argument [config] instead.")
+        return str(args.config_deprecated)
+    elif hasattr(args, "template_deprecated") and args.template_deprecated is not None:
+        logger.warning("Using deprecated --template argument. Use positional argument [config] instead.")
+        return str(args.template_deprecated)
+    return None


@@ -19,6 +19,9 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}

+# Mount command for cache container .cache, can be overridden by the user if needed
+MOUNT_CACHE=${MOUNT_CACHE:-"--mount=type=cache,id=llama-stack-cache,target=/root/.cache"}
+
# Path to the run.yaml file in the container
RUN_CONFIG_PATH=/app/run.yaml

@@ -125,11 +128,16 @@ RUN pip install uv
EOF
fi

+# Set the link mode to copy so that uv doesn't attempt to symlink to the cache directory
+add_to_container << EOF
+ENV UV_LINK_MODE=copy
+EOF
+
# Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers.
if [ -n "$pip_dependencies" ]; then
  add_to_container << EOF
-RUN uv pip install --no-cache $pip_dependencies
+RUN $MOUNT_CACHE uv pip install $pip_dependencies
EOF
fi

@@ -137,7 +145,7 @@ if [ -n "$special_pip_deps" ]; then
  IFS='#' read -ra parts <<<"$special_pip_deps"
  for part in "${parts[@]}"; do
    add_to_container <<EOF
-RUN uv pip install --no-cache $part
+RUN $MOUNT_CACHE uv pip install $part
EOF
  done
fi

@@ -207,7 +215,7 @@ COPY $dir $mount_point
EOF
  fi
  add_to_container << EOF
-RUN uv pip install --no-cache -e $mount_point
+RUN $MOUNT_CACHE uv pip install -e $mount_point
EOF
}

@@ -222,10 +230,10 @@ else
  if [ -n "$TEST_PYPI_VERSION" ]; then
    # these packages are damaged in test-pypi, so install them first
    add_to_container << EOF
-RUN uv pip install fastapi libcst
+RUN $MOUNT_CACHE uv pip install fastapi libcst
EOF
    add_to_container << EOF
-RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
+RUN $MOUNT_CACHE uv pip install --extra-index-url https://test.pypi.org/simple/ \
  --index-strategy unsafe-best-match \
  llama-stack==$TEST_PYPI_VERSION

@@ -237,7 +245,7 @@ EOF
      SPEC_VERSION="llama-stack"
    fi
    add_to_container << EOF
-RUN uv pip install --no-cache $SPEC_VERSION
+RUN $MOUNT_CACHE uv pip install $SPEC_VERSION
EOF
  fi
fi


@@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO):
        logger.debug(
            f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
        )
-        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
+        provider = await self.routing_table.get_provider_impl(dataset_id)
+        return await provider.iterrows(
            dataset_id=dataset_id,
            start_index=start_index,
            limit=limit,

@@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO):
    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
-        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
+        provider = await self.routing_table.get_provider_impl(dataset_id)
+        return await provider.append_rows(
            dataset_id=dataset_id,
            rows=rows,
        )


@@ -44,7 +44,8 @@ class ScoringRouter(Scoring):
        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
        res = {}
        for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
+            provider = await self.routing_table.get_provider_impl(fn_identifier)
+            score_response = await provider.score_batch(
                dataset_id=dataset_id,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )

@@ -66,7 +67,8 @@ class ScoringRouter(Scoring):
        res = {}
        # look up and map each scoring function to its provider impl
        for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
+            provider = await self.routing_table.get_provider_impl(fn_identifier)
+            score_response = await provider.score(
                input_rows=input_rows,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )

@@ -97,7 +99,8 @@ class EvalRouter(Eval):
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.run_eval(
            benchmark_id=benchmark_id,
            benchmark_config=benchmark_config,
        )

@@ -110,7 +113,8 @@ class EvalRouter(Eval):
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
-        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.evaluate_rows(
            benchmark_id=benchmark_id,
            input_rows=input_rows,
            scoring_functions=scoring_functions,

@@ -123,7 +127,8 @@ class EvalRouter(Eval):
        job_id: str,
    ) -> Job:
        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.job_status(benchmark_id, job_id)

    async def job_cancel(
        self,
@@ -131,7 +136,8 @@ class EvalRouter(Eval):
        job_id: str,
    ) -> None:
        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
-        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        await provider.job_cancel(
            benchmark_id,
            job_id,
        )

@@ -142,7 +148,8 @@ class EvalRouter(Eval):
        job_id: str,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.job_result(
            benchmark_id,
            job_id,
        )


@@ -231,7 +231,7 @@ class InferenceRouter(Inference):
            logprobs=logprobs,
            tool_config=tool_config,
        )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)

        if stream:
@@ -292,7 +292,7 @@ class InferenceRouter(Inference):
        logger.debug(
            f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
        )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
        return await provider.batch_chat_completion(
            model_id=model_id,
            messages_batch=messages_batch,
@@ -322,7 +322,7 @@ class InferenceRouter(Inference):
            raise ValueError(f"Model '{model_id}' not found")
        if model.model_type == ModelType.embedding:
            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
        params = dict(
            model_id=model_id,
            content=content,
@@ -378,7 +378,7 @@ class InferenceRouter(Inference):
        logger.debug(
            f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
        )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
        return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)

    async def embeddings(
@@ -395,7 +395,8 @@ class InferenceRouter(Inference):
            raise ValueError(f"Model '{model_id}' not found")
        if model.model_type == ModelType.llm:
            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
-        return await self.routing_table.get_provider_impl(model_id).embeddings(
+        provider = await self.routing_table.get_provider_impl(model_id)
+        return await provider.embeddings(
            model_id=model_id,
            contents=contents,
            text_truncation=text_truncation,
@@ -458,7 +459,7 @@ class InferenceRouter(Inference):
            suffix=suffix,
        )
-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
        return await provider.openai_completion(**params)

    async def openai_chat_completion(
@@ -538,7 +539,7 @@ class InferenceRouter(Inference):
            user=user,
        )
-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
        if stream:
            response_stream = await provider.openai_chat_completion(**params)
            if self.store:
@@ -575,7 +576,7 @@ class InferenceRouter(Inference):
            user=user,
        )
-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
        return await provider.openai_embeddings(**params)

    async def list_chat_completions(


@@ -50,7 +50,8 @@ class SafetyRouter(Safety):
        params: dict[str, Any] = None,
    ) -> RunShieldResponse:
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
-        return await self.routing_table.get_provider_impl(shield_id).run_shield(
+        provider = await self.routing_table.get_provider_impl(shield_id)
+        return await provider.run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,


@@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime):
            query_config: RAGQueryConfig | None = None,
        ) -> RAGQueryResult:
            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
-            return await self.routing_table.get_provider_impl("knowledge_search").query(
-                content, vector_db_ids, query_config
-            )
+            provider = await self.routing_table.get_provider_impl("knowledge_search")
+            return await provider.query(content, vector_db_ids, query_config)

        async def insert(
            self,
@@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime):
            logger.debug(
                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
            )
-            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
-                documents, vector_db_id, chunk_size_in_tokens
-            )
+            provider = await self.routing_table.get_provider_impl("insert_into_memory")
+            return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)

    def __init__(
        self,
@@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime):
    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
-        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
+        provider = await self.routing_table.get_provider_impl(tool_name)
+        return await provider.invoke_tool(
            tool_name=tool_name,
            kwargs=kwargs,
        )


@@ -104,7 +104,8 @@ class VectorIORouter(VectorIO):
        logger.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
        )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
+        provider = await self.routing_table.get_provider_impl(vector_db_id)
+        return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)

    async def query_chunks(
        self,
@@ -113,7 +114,8 @@ class VectorIORouter(VectorIO):
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
+        provider = await self.routing_table.get_provider_impl(vector_db_id)
+        return await provider.query_chunks(vector_db_id, query, params)

    # OpenAI Vector Stores API endpoints
    async def openai_create_vector_store(
@@ -146,7 +148,8 @@ class VectorIORouter(VectorIO):
            provider_vector_db_id=vector_db_id,
            vector_db_name=name,
        )
-        return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
+        provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
+        return await provider.openai_create_vector_store(
            name=name,
            file_ids=file_ids,
            expires_after=expires_after,
@@ -172,9 +175,8 @@ class VectorIORouter(VectorIO):
        all_stores = []
        for vector_db in vector_dbs:
            try:
-                vector_store = await self.routing_table.get_provider_impl(
-                    vector_db.identifier
-                ).openai_retrieve_vector_store(vector_db.identifier)
+                provider = await self.routing_table.get_provider_impl(vector_db.identifier)
+                vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
                all_stores.append(vector_store)
            except Exception as e:
                logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")


@@ -6,6 +6,7 @@
from typing import Any

+from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
@@ -116,7 +117,7 @@ class CommonRoutingTableImpl(RoutingTable):
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

-    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
+    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        from .benchmarks import BenchmarksRoutingTable
        from .datasets import DatasetsRoutingTable
        from .models import ModelsRoutingTable
@@ -235,3 +236,28 @@ class CommonRoutingTableImpl(RoutingTable):
        ]
        return filtered_objs
+
+
+async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
+    # first try to get the model by identifier
+    # this works if model_id is an alias or is of the form provider_id/provider_model_id
+    model = await routing_table.get_object_by_identifier("model", model_id)
+    if model is not None:
+        return model
+
+    logger.warning(
+        f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
+        "searching in all providers. This is only for backwards compatibility and will stop working "
+        "soon. Migrate your calls to use fully scoped `provider_id/model_id` names."
+    )
+
+    # if not found, this means model_id is an unscoped provider_model_id, we need
+    # to iterate (given a lack of an efficient index on the KVStore)
+    models = await routing_table.get_all_with_type("model")
+    matching_models = [m for m in models if m.provider_resource_id == model_id]
+    if len(matching_models) == 0:
+        raise ValueError(f"Model '{model_id}' not found")
+
+    if len(matching_models) > 1:
+        raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
+
+    return matching_models[0]


@@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.log import get_logger

-from .common import CommonRoutingTableImpl
+from .common import CommonRoutingTableImpl, lookup_model

logger = get_logger(name=__name__, category="core")

@@ -36,10 +36,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
-        model = await self.get_object_by_identifier("model", model_id)
-        if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
-        return model
+        return await lookup_model(self, model_id)
+
+    async def get_provider_impl(self, model_id: str) -> Any:
+        model = await lookup_model(self, model_id)
+        return self.impls_by_provider_id[model.provider_id]

    async def register_model(
        self,
@@ -49,24 +50,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
-        if provider_model_id is None:
-            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
-                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
+                    f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
+                    "Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
                )
-        if metadata is None:
-            metadata = {}
-        if model_type is None:
-            model_type = ModelType.llm
+
+        provider_model_id = provider_model_id or model_id
+        metadata = metadata or {}
+        model_type = model_type or ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
+
+        # an identifier different than provider_model_id implies it is an alias, so that
+        # becomes the globally unique identifier. otherwise provider_model_ids can conflict,
+        # so as a general rule we must use the provider_id to disambiguate.
+        if model_id != provider_model_id:
+            identifier = model_id
+        else:
+            identifier = f"{provider_id}/{provider_model_id}"
+
        model = ModelWithOwner(
-            identifier=model_id,
+            identifier=identifier,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,


@@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    tool_to_toolgroup: dict[str, str] = {}

    # overridden
-    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
+    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
        # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?

@@ -40,7 +40,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
        if routing_key in self.tool_to_toolgroup:
            routing_key = self.tool_to_toolgroup[routing_key]
-        return super().get_provider_impl(routing_key, provider_id)
+        return await super().get_provider_impl(routing_key, provider_id)

    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
        if toolgroup_id:
@@ -59,7 +59,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
        return ListToolsResponse(data=all_tools)

    async def _index_tools(self, toolgroup: ToolGroup):
-        provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
+        provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
        tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)

        # TODO: kill this Tool vs ToolDef distinction


@@ -27,7 +27,7 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.log import get_logger

-from .common import CommonRoutingTableImpl
+from .common import CommonRoutingTableImpl, lookup_model

logger = get_logger(name=__name__, category="core")

@@ -51,8 +51,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        provider_vector_db_id: str | None = None,
        vector_db_name: str | None = None,
    ) -> VectorDB:
-        if provider_vector_db_id is None:
-            provider_vector_db_id = vector_db_id
+        provider_vector_db_id = provider_vector_db_id or vector_db_id
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -62,7 +61,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
                )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
-        model = await self.get_object_by_identifier("model", embedding_model)
+        model = await lookup_model(self, embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
@@ -93,7 +92,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        vector_store_id: str,
    ) -> VectorStoreObject:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store(vector_store_id)

    async def openai_update_vector_store(
        self,
@@ -103,7 +103,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_update_vector_store(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,
@@ -115,7 +116,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
-        result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        result = await provider.openai_delete_vector_store(vector_store_id)
        await self.unregister_vector_db(vector_store_id)
        return result

@@ -130,7 +132,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_search_vector_store(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            filters=filters,
@@ -148,7 +151,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_attach_file_to_vector_store(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
@@ -165,7 +169,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        filter: VectorStoreFileStatus | None = None,
    ) -> list[VectorStoreFileObject]:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_list_files_in_vector_store(
            vector_store_id=vector_store_id,
            limit=limit,
            order=order,
@@ -180,7 +185,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        file_id: str,
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )
@@ -191,7 +197,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        file_id: str,
    ) -> VectorStoreFileContentsResponse:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )
@@ -203,7 +210,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        attributes: dict[str, Any],
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_update_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
@@ -215,7 +223,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        file_id: str,
    ) -> VectorStoreFileDeleteResponse:
        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_delete_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )


@@ -32,7 +32,7 @@ from openai import BadRequestError
from pydantic import BaseModel, ValidationError

from llama_stack.apis.common.responses import PaginatedResponse
-from llama_stack.cli.utils import add_config_template_args
+from llama_stack.cli.utils import add_config_template_args, get_config_from_args
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import (
    AuthenticationRequiredError,
@@ -399,7 +399,8 @@ def main(args: argparse.Namespace | None = None):
    if args is None:
        args = parser.parse_args()

-    config_file = resolve_config_or_template(args.config, Mode.RUN)
+    config_or_template = get_config_from_args(args)
+    config_file = resolve_config_or_template(config_or_template, Mode.RUN)

    logger_config = None
    with open(config_file) as fp:

@@ -113,7 +113,7 @@ class ProviderSpec(BaseModel):

class RoutingTable(Protocol):
-    def get_provider_impl(self, routing_key: str) -> Any: ...
+    async def get_provider_impl(self, routing_key: str) -> Any: ...


# TODO: this can now be inlined into RemoteProviderSpec


@@ -5,17 +5,27 @@
# the root directory of this source tree.

import logging

-from llama_api_client import AsyncLlamaAPIClient, NotFoundError
-
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from .models import MODEL_ENTRIES

logger = logging.getLogger(__name__)


-class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
+class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
+    """
+    Llama API Inference Adapter for Llama Stack.
+
+    Note: The inheritance order is important here. OpenAIMixin must come before
+    LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
+    is used instead of ModelRegistryHelper.check_model_availability().
+
+    - OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
+    - ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
+    """
+
    _config: LlamaCompatConfig

    def __init__(self, config: LlamaCompatConfig):
@@ -28,32 +38,19 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
        )
        self.config = config

-    async def check_model_availability(self, model: str) -> bool:
-        """
-        Check if a specific model is available from Llama API.
-
-        :param model: The model identifier to check.
-        :return: True if the model is available dynamically, False otherwise.
-        """
-        try:
-            llama_api_client = self._get_llama_api_client()
-            retrieved_model = await llama_api_client.models.retrieve(model)
-            logger.info(f"Model {retrieved_model.id} is available from Llama API")
-            return True
-        except NotFoundError:
-            logger.error(f"Model {model} is not available from Llama API")
-            return False
-        except Exception as e:
-            logger.error(f"Failed to check model availability from Llama API: {e}")
-            return False
+    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
+    get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        """
+        Get the base URL for OpenAI mixin.
+
+        :return: The Llama API base URL
+        """
+        return self.config.openai_compat_api_base

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
-
-    def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
-        return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
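
The inheritance-order note in this adapter's docstring (and in the NVIDIA and OpenAI adapters below) is ordinary Python method resolution order at work. A toy sketch with stand-in classes, not the real mixins, shows why listing `OpenAIMixin` first makes its `check_model_availability()` the implementation that runs:

```python
class RegistryHelperLike:
    async def check_model_availability(self, model: str) -> bool:
        return False  # static-registry fallback, analogous to ModelRegistryHelper


class OpenAIMixinLike:
    async def check_model_availability(self, model: str) -> bool:
        return True  # pretend we asked the remote endpoint


class Adapter(OpenAIMixinLike, RegistryHelperLike):
    pass


# OpenAIMixinLike comes first in Adapter.__mro__, so its method wins.
assert Adapter.check_model_availability is OpenAIMixinLike.check_model_availability
```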


@@ -7,9 +7,8 @@
import logging
import warnings
from collections.abc import AsyncIterator
-from typing import Any

-from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
+from openai import APIConnectionError, BadRequestError

from llama_stack.apis.common.content_types import (
    InterleavedContent,
@@ -28,12 +27,6 @@ from llama_stack.apis.inference import (
    Inference,
    LogProbConfig,
    Message,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
@@ -47,8 +40,8 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,
-    prepare_openai_completion_params,
)
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

from . import NVIDIAConfig
@@ -64,7 +57,20 @@ from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__)


-class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
+class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
+    """
+    NVIDIA Inference Adapter for Llama Stack.
+
+    Note: The inheritance order is important here. OpenAIMixin must come before
+    ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
+    is used instead of ModelRegistryHelper.check_model_availability(). It also
+    must come before Inference to ensure that OpenAIMixin methods are available
+    in the Inference interface.
+
+    - OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
+    - ModelRegistryHelper.check_model_availability() just returns False and shows a warning
+    """
+
    def __init__(self, config: NVIDIAConfig) -> None:
        # TODO(mf): filter by available models
        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@@ -88,45 +94,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
        self._config = config

-    async def check_model_availability(self, model: str) -> bool:
-        """
-        Check if a specific model is available.
-
-        :param model: The model identifier to check.
-        :return: True if the model is available dynamically, False otherwise.
-        """
-        try:
-            await self._client.models.retrieve(model)
-            return True
-        except NotFoundError:
-            logger.error(f"Model {model} is not available")
-        except Exception as e:
-            logger.error(f"Failed to check model availability: {e}")
-        return False
-
-    @property
-    def _client(self) -> AsyncOpenAI:
-        """
-        Returns an OpenAI client for the configured NVIDIA API endpoint.
-
-        :return: An OpenAI client
-        """
-
-        base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
-
-        return AsyncOpenAI(
-            base_url=base_url,
-            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
-            timeout=self._config.timeout,
-        )
-
-    async def _get_provider_model_id(self, model_id: str) -> str:
-        if not self.model_store:
-            raise RuntimeError("Model store is not set")
-        model = await self.model_store.get_model(model_id)
-        if model is None:
-            raise ValueError(f"Model {model_id} is unknown")
-        return model.provider_model_id
+    def get_api_key(self) -> str:
+        """
+        Get the API key for OpenAI mixin.
+
+        :return: The NVIDIA API key
+        """
+        return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
+
+    def get_base_url(self) -> str:
+        """
+        Get the base URL for OpenAI mixin.
+
+        :return: The NVIDIA API base URL
+        """
+        return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url

    async def completion(
        self,
@@ -160,7 +142,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
        )

        try:
-            response = await self._client.completions.create(**request)
+            response = await self.client.completions.create(**request)
        except APIConnectionError as e:
            raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

@@ -213,7 +195,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
            extra_body["input_type"] = task_type_options[task_type]

        try:
-            response = await self._client.embeddings.create(
+            response = await self.client.embeddings.create(
                model=provider_model_id,
                input=input,
                extra_body=extra_body,
@@ -228,16 +210,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
        #
        return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])

-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-
    async def chat_completion(
        self,
        model_id: str,
@@ -274,7 +246,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
        )

        try:
-            response = await self._client.chat.completions.create(**request)
+            response = await self.client.chat.completions.create(**request)
        except APIConnectionError as e:
            raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -283,112 +255,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
else: else:
# we pass n=1 to get only one completion # we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0]) return convert_openai_chat_completion_choice(response.choices[0])
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
try:
return await self._client.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
try:
return await self._client.chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

View file

@@ -5,23 +5,9 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI, NotFoundError
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig from .config import OpenAIConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
@@ -30,7 +16,7 @@ logger = logging.getLogger(__name__)
# #
# This OpenAI adapter implements Inference methods using two clients - # This OpenAI adapter implements Inference methods using two mixins -
# #
# | Inference Method | Implementation Source | # | Inference Method | Implementation Source |
# |----------------------------|--------------------------| # |----------------------------|--------------------------|
@@ -39,11 +25,22 @@ logger = logging.getLogger(__name__)
# | embedding | LiteLLMOpenAIMixin | # | embedding | LiteLLMOpenAIMixin |
# | batch_completion | LiteLLMOpenAIMixin | # | batch_completion | LiteLLMOpenAIMixin |
# | batch_chat_completion | LiteLLMOpenAIMixin | # | batch_chat_completion | LiteLLMOpenAIMixin |
# | openai_completion | AsyncOpenAI | # | openai_completion | OpenAIMixin |
# | openai_chat_completion | AsyncOpenAI | # | openai_chat_completion | OpenAIMixin |
# | openai_embeddings | AsyncOpenAI | # | openai_embeddings | OpenAIMixin |
# #
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
OpenAI Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
def __init__(self, config: OpenAIConfig) -> None: def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
@@ -60,191 +57,19 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# litellm specific model names, an abstraction leak. # litellm specific model names, an abstraction leak.
self.is_openai_compat = True self.is_openai_compat = True
async def check_model_availability(self, model: str) -> bool: # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
""" """
Check if a specific model is available from OpenAI. Get the OpenAI API base URL.
:param model: The model identifier to check. Returns the standard OpenAI API base URL for direct OpenAI API calls.
:return: True if the model is available dynamically, False otherwise.
""" """
try: return "https://api.openai.com/v1"
openai_client = self._get_openai_client()
retrieved_model = await openai_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from OpenAI")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from OpenAI")
return False
except Exception as e:
logger.error(f"Failed to check model availability from OpenAI: {e}")
return False
async def initialize(self) -> None: async def initialize(self) -> None:
await super().initialize() await super().initialize()
async def shutdown(self) -> None: async def shutdown(self) -> None:
await super().shutdown() await super().shutdown()
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(
api_key=self.get_api_key(),
)
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if guided_choice is not None:
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self._get_openai_client().chat.completions.create(**params)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
# Prepare parameters for OpenAI embeddings API
params = {
"model": model_id,
"input": input,
}
if encoding_format is not None:
params["encoding_format"] = encoding_format
if dimensions is not None:
params["dimensions"] = dimensions
if user is not None:
params["user"] = user
# Call OpenAI embeddings API
response = await self._get_openai_client().embeddings.create(**params)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)
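The docstring above hinges on Python's method resolution order: because OpenAIMixin is listed before LiteLLMOpenAIMixin, its check_model_availability() shadows the ModelRegistryHelper default. A stripped-down sketch, where the stub classes stand in for the real ones purely to show the MRO:

```python
import inspect


# Stand-in stubs; the real classes live under llama_stack.providers.utils.inference.
class OpenAIMixin:
    async def check_model_availability(self, model: str) -> bool:
        return True  # the real mixin queries the provider's models endpoint


class ModelRegistryHelper:
    async def check_model_availability(self, model: str) -> bool:
        return False  # the real helper logs a message and returns False


class LiteLLMOpenAIMixin(ModelRegistryHelper):
    pass


class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
    pass


print([c.__name__ for c in inspect.getmro(OpenAIInferenceAdapter)])
# ['OpenAIInferenceAdapter', 'OpenAIMixin', 'LiteLLMOpenAIMixin',
#  'ModelRegistryHelper', 'object'] -> OpenAIMixin's check_model_availability wins
```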

View file

@@ -88,7 +88,7 @@ class SentenceTransformerEmbeddingMixin:
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1) usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return OpenAIEmbeddingsResponse( return OpenAIEmbeddingsResponse(
data=data, data=data,
model=model_obj.provider_resource_id, model=model,
usage=usage, usage=usage,
) )

View file

@@ -10,12 +10,15 @@ from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import UnsupportedModelError from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import ( from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
) )
logger = get_logger(name=__name__, category="core")
# TODO: this class is more confusing than useful right now. We need to make it # TODO: this class is more confusing than useful right now. We need to make it
# more closer to the Model class. # more closer to the Model class.
@@ -98,6 +101,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
:param model: The model identifier to check. :param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise. :return: True if the model is available dynamically, False otherwise.
""" """
logger.info(
f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default."
)
return False return False
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:

View file

@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
import openai
from openai import NOT_GIVEN, AsyncOpenAI
from llama_stack.apis.inference import (
Model,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="core")
class OpenAIMixin(ABC):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
This is an abstract base class that requires child classes to implement:
- get_api_key(): Method to retrieve the API key
- get_base_url(): Method to retrieve the OpenAI-compatible API base URL
Expected Dependencies:
- self.model_store: Injected by the Llama Stack distribution system at runtime.
This provides model registry functionality for looking up registered models.
The model_store is set in routing_tables/common.py during provider initialization.
"""
@abstractmethod
def get_api_key(self) -> str:
"""
Get the API key.
This method must be implemented by child classes to provide the API key
for authenticating with the OpenAI API or compatible endpoints.
:return: The API key as a string
"""
pass
@abstractmethod
def get_base_url(self) -> str:
"""
Get the OpenAI-compatible API base URL.
This method must be implemented by child classes to provide the base URL
for the OpenAI API or compatible endpoints (e.g., "https://api.openai.com/v1").
:return: The base URL as a string
"""
pass
@property
def client(self) -> AsyncOpenAI:
"""
Get an AsyncOpenAI client instance.
Uses the abstract methods get_api_key() and get_base_url() which must be
implemented by child classes.
"""
return AsyncOpenAI(
api_key=self.get_api_key(),
base_url=self.get_base_url(),
)
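# Note (editorial, illustrative comment, not part of the committed file): a fresh
# AsyncOpenAI client is built on every access rather than cached, so whatever
# get_api_key() returns at call time (e.g. a per-request key carried in provider
# data, as the client-caching tests later in this diff exercise) is always picked up.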
async def _get_provider_model_id(self, model: str) -> str:
"""
Get the provider-specific model ID from the model store.
This is a utility method that looks up the registered model and returns
the provider_resource_id that should be used for actual API calls.
:param model: The registered model name/identifier
:return: The provider-specific model ID (e.g., "gpt-4")
"""
# Look up the registered model to get the provider-specific model ID
# self.model_store is injected by the distribution system at runtime
model_obj: Model = await self.model_store.get_model(model) # type: ignore[attr-defined]
# provider_resource_id is str | None, but we expect it to be str for OpenAI calls
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id")
return model_obj.provider_resource_id
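# Example (illustrative comment, not part of the committed file): a model registered
# as "test_provider/test-model" with provider_resource_id "test-model" resolves here
# to "test-model", which is what gets sent to the OpenAI-compatible endpoint.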
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
"""
Direct OpenAI completion API call.
"""
if guided_choice is not None:
logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
# TODO: fix openai_completion to return type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return]
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""
Direct OpenAI chat completion API call.
"""
# Type ignore because return types are compatible
return await self.client.chat.completions.create( # type: ignore[no-any-return]
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
"""
Direct OpenAI embeddings API call.
"""
# Call OpenAI embeddings API with properly typed parameters
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
user=user if user is not None else NOT_GIVEN,
)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from OpenAI.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
# Direct model lookup - returns model or raises NotFoundError
await self.client.models.retrieve(model)
return True
except openai.NotFoundError:
# Model doesn't exist - this is expected for unavailable models
pass
except Exception as e:
# All other errors (auth, rate limit, network, etc.)
logger.warning(f"Failed to check model availability for {model}: {e}")
return False
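To build on the mixin, a provider only has to satisfy the two abstract accessors; the completion, chat completion, embeddings, and availability-check methods then come along for free. A minimal sketch, assuming a hypothetical adapter and config that are not part of this diff:

```python
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleCompatAdapter(OpenAIMixin):
    """Hypothetical adapter for an OpenAI-compatible endpoint (illustrative only)."""

    def __init__(self, api_key: str, base_url: str) -> None:
        self._api_key = api_key
        self._base_url = base_url

    def get_api_key(self) -> str:
        return self._api_key

    def get_base_url(self) -> str:
        # e.g. "https://api.example.com/v1" for an OpenAI-compatible server
        return self._base_url


# openai_completion(), openai_chat_completion(), openai_embeddings(), and
# check_model_availability() are inherited from OpenAIMixin; model_store is
# injected by the distribution at runtime, so the adapter is not usable standalone.
```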

View file

@@ -11,15 +11,17 @@ from unittest.mock import AsyncMock
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
class Impl: class Impl:
@@ -104,6 +106,17 @@ class ToolGroupsImpl(Impl):
) )
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def test_models_routing_table(cached_disk_dist_registry): async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
@@ -115,27 +128,27 @@ async def test_models_routing_table(cached_disk_dist_registry):
models = await table.list_models() models = await table.list_models()
assert len(models.data) == 2 assert len(models.data) == 2
model_ids = {m.identifier for m in models.data} model_ids = {m.identifier for m in models.data}
assert "test-model" in model_ids assert "test_provider/test-model" in model_ids
assert "test-model-2" in model_ids assert "test_provider/test-model-2" in model_ids
# Test openai list models # Test openai list models
openai_models = await table.openai_list_models() openai_models = await table.openai_list_models()
assert len(openai_models.data) == 2 assert len(openai_models.data) == 2
openai_model_ids = {m.id for m in openai_models.data} openai_model_ids = {m.id for m in openai_models.data}
assert "test-model" in openai_model_ids assert "test_provider/test-model" in openai_model_ids
assert "test-model-2" in openai_model_ids assert "test_provider/test-model-2" in openai_model_ids
# Test get_object_by_identifier # Test get_object_by_identifier
model = await table.get_object_by_identifier("model", "test-model") model = await table.get_object_by_identifier("model", "test_provider/test-model")
assert model is not None assert model is not None
assert model.identifier == "test-model" assert model.identifier == "test_provider/test-model"
# Test get_object_by_identifier on non-existent object # Test get_object_by_identifier on non-existent object
non_existent = await table.get_object_by_identifier("model", "non-existent-model") non_existent = await table.get_object_by_identifier("model", "non-existent-model")
assert non_existent is None assert non_existent is None
await table.unregister_model(model_id="test-model") await table.unregister_model(model_id="test_provider/test-model")
await table.unregister_model(model_id="test-model-2") await table.unregister_model(model_id="test_provider/test-model-2")
models = await table.list_models() models = await table.list_models()
assert len(models.data) == 0 assert len(models.data) == 0
@@ -160,6 +173,36 @@ async def test_shields_routing_table(cached_disk_dist_registry):
assert "test-shield-2" in shield_ids assert "test-shield-2" in shield_ids
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_datasets_routing_table(cached_disk_dist_registry): async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {}) table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
@@ -245,3 +288,93 @@ async def test_tool_groups_routing_table(cached_disk_dist_registry):
await table.unregister_toolgroup(toolgroup_id="test-toolgroup") await table.unregister_toolgroup(toolgroup_id="test-toolgroup")
tool_groups = await table.list_tool_groups() tool_groups = await table.list_tool_groups()
assert len(tool_groups.data) == 0 assert len(tool_groups.data) == 0
async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
"""Test alias registration (model_id != provider_model_id) and lookup behavior."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register model with alias (model_id different from provider_model_id)
await table.register_model(
model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider"
)
# Verify the model was registered with alias as identifier (not namespaced)
models = await table.list_models()
assert len(models.data) == 1
model = models.data[0]
assert model.identifier == "my-alias" # Uses alias as identifier
assert model.provider_resource_id == "actual-provider-model"
# Test lookup by alias works
retrieved_model = await table.get_model("my-alias")
assert retrieved_model.identifier == "my-alias"
assert retrieved_model.provider_resource_id == "actual-provider-model"
async def test_models_multi_provider_disambiguation(cached_disk_dist_registry):
"""Test registration and lookup with multiple providers having same provider_model_id."""
table = ModelsRoutingTable(
{"provider1": InferenceImpl(), "provider2": InferenceImpl()}, cached_disk_dist_registry, {}
)
await table.initialize()
# Register same provider_model_id on both providers (no aliases)
await table.register_model(model_id="common-model", provider_id="provider1")
await table.register_model(model_id="common-model", provider_id="provider2")
# Verify both models get namespaced identifiers
models = await table.list_models()
assert len(models.data) == 2
identifiers = {m.identifier for m in models.data}
assert identifiers == {"provider1/common-model", "provider2/common-model"}
# Test lookup by full namespaced identifier works
model1 = await table.get_model("provider1/common-model")
assert model1.provider_id == "provider1"
assert model1.provider_resource_id == "common-model"
model2 = await table.get_model("provider2/common-model")
assert model2.provider_id == "provider2"
assert model2.provider_resource_id == "common-model"
# Test lookup by unscoped provider_model_id fails with multiple providers error
try:
await table.get_model("common-model")
raise AssertionError("Should have raised ValueError for multiple providers")
except ValueError as e:
assert "Multiple providers found" in str(e)
assert "provider1" in str(e) and "provider2" in str(e)
async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
"""Test two-stage lookup: direct identifier hit vs fallback to provider_resource_id."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register model without alias (gets namespaced identifier)
await table.register_model(model_id="test-model", provider_id="test_provider")
# Verify namespaced identifier was created
models = await table.list_models()
assert len(models.data) == 1
model = models.data[0]
assert model.identifier == "test_provider/test-model"
assert model.provider_resource_id == "test-model"
# Test lookup by full namespaced identifier (direct hit via get_object_by_identifier)
retrieved_model = await table.get_model("test_provider/test-model")
assert retrieved_model.identifier == "test_provider/test-model"
# Test lookup by unscoped provider_model_id (fallback via iteration)
retrieved_model = await table.get_model("test-model")
assert retrieved_model.identifier == "test_provider/test-model"
assert retrieved_model.provider_resource_id == "test-model"
# Test lookup of non-existent model fails
try:
await table.get_model("non-existent")
raise AssertionError("Should have raised ValueError for non-existent model")
except ValueError as e:
assert "not found" in str(e)

View file

@@ -10,6 +10,8 @@ from unittest.mock import MagicMock
from llama_stack.distribution.request_headers import request_provider_data_context from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
@@ -50,7 +52,7 @@ def test_openai_provider_openai_client_caching():
with request_provider_data_context( with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
): ):
openai_client = inference_adapter._get_openai_client() openai_client = inference_adapter.client
assert openai_client.api_key == api_key assert openai_client.api_key == api_key
@@ -71,3 +73,18 @@ def test_together_provider_openai_client_caching():
assert together_client.client.api_key == api_key assert together_client.client.api_key == api_key
openai_client = inference_adapter._get_openai_client() openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key assert openai_client.api_key == api_key
def test_llama_compat_provider_openai_client_caching():
"""Ensure the LlamaCompat provider does not cache api keys across client requests"""
config = LlamaCompatConfig()
inference_adapter = LlamaCompatInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
assert inference_adapter.client.api_key == api_key
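The test above simulates the per-request provider-data header that makes the client property pick up a request-scoped key. A small illustrative sketch of that header, using the header and field names from the test (the key value is a placeholder):

```python
import json

# The header a caller would attach to a Llama Stack request so the adapter's
# `client` property uses this key instead of anything cached on the adapter.
headers = {
    "x-llamastack-provider-data": json.dumps({"llama_api_key": "sk-example-key"}),
}
print(headers)
```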