Merge branch 'main' into acl-vector-stores

This commit is contained in:
Francisco Arceo 2025-07-20 07:37:04 -04:00 committed by GitHub
commit bc835c723c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 1615 additions and 1417 deletions

View file

@ -71,7 +71,7 @@ jobs:
- name: Build Llama Stack - name: Build Llama Stack
run: | run: |
uv run llama stack build --template starter --image-type venv uv run llama stack build --template ci-tests --image-type venv
- name: Check Storage and Memory Available Before Tests - name: Check Storage and Memory Available Before Tests
if: ${{ always() }} if: ${{ always() }}
@ -92,9 +92,9 @@ jobs:
shell: bash shell: bash
run: | run: |
if [ "${{ matrix.client-type }}" == "library" ]; then if [ "${{ matrix.client-type }}" == "library" ]; then
stack_config="starter" stack_config="ci-tests"
else else
stack_config="server:starter" stack_config="server:ci-tests"
fi fi
uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
-k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \ -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \

View file

@ -93,7 +93,7 @@ jobs:
- name: Build Llama Stack - name: Build Llama Stack
run: | run: |
uv run llama stack build --template starter --image-type venv uv run llama stack build --template ci-tests --image-type venv
- name: Check Storage and Memory Available Before Tests - name: Check Storage and Memory Available Before Tests
if: ${{ always() }} if: ${{ always() }}

View file

@ -97,9 +97,9 @@ jobs:
- name: Build a single provider - name: Build a single provider
run: | run: |
yq -i '.image_type = "container"' llama_stack/templates/starter/build.yaml yq -i '.image_type = "container"' llama_stack/templates/ci-tests/build.yaml
yq -i '.image_name = "test"' llama_stack/templates/starter/build.yaml yq -i '.image_name = "test"' llama_stack/templates/ci-tests/build.yaml
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/starter/build.yaml USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/ci-tests/build.yaml
- name: Inspect the container image entrypoint - name: Inspect the container image entrypoint
run: | run: |
@ -126,14 +126,14 @@ jobs:
.image_type = "container" | .image_type = "container" |
.image_name = "ubi9-test" | .image_name = "ubi9-test" |
.distribution_spec.container_image = "registry.access.redhat.com/ubi9:latest" .distribution_spec.container_image = "registry.access.redhat.com/ubi9:latest"
' llama_stack/templates/starter/build.yaml ' llama_stack/templates/ci-tests/build.yaml
- name: Build dev container (UBI9) - name: Build dev container (UBI9)
env: env:
USE_COPY_NOT_MOUNT: "true" USE_COPY_NOT_MOUNT: "true"
LLAMA_STACK_DIR: "." LLAMA_STACK_DIR: "."
run: | run: |
uv run llama stack build --config llama_stack/templates/starter/build.yaml uv run llama stack build --config llama_stack/templates/ci-tests/build.yaml
- name: Inspect UBI9 image - name: Inspect UBI9 image
run: | run: |

View file

@ -20,7 +20,7 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install uv - name: Install uv
uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1 uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f # v6.4.1
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
activate-environment: true activate-environment: true

View file

@ -4,7 +4,6 @@ This section contains documentation for all available providers for the **infere
- [inline::meta-reference](inline_meta-reference.md) - [inline::meta-reference](inline_meta-reference.md)
- [inline::sentence-transformers](inline_sentence-transformers.md) - [inline::sentence-transformers](inline_sentence-transformers.md)
- [inline::vllm](inline_vllm.md)
- [remote::anthropic](remote_anthropic.md) - [remote::anthropic](remote_anthropic.md)
- [remote::bedrock](remote_bedrock.md) - [remote::bedrock](remote_bedrock.md)
- [remote::cerebras](remote_cerebras.md) - [remote::cerebras](remote_cerebras.md)

View file

@ -1,29 +0,0 @@
# inline::vllm
## Description
vLLM inference provider for high-performance model serving with PagedAttention and continuous batching.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `tensor_parallel_size` | `<class 'int'>` | No | 1 | Number of tensor parallel replicas (number of GPUs to use). |
| `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
| `max_model_len` | `<class 'int'>` | No | 4096 | Maximum context length to use during serving. |
| `max_num_seqs` | `<class 'int'>` | No | 4 | Maximum parallel batch size for generation. |
| `enforce_eager` | `<class 'bool'>` | No | False | Whether to use eager mode for inference (otherwise cuda graphs are used). |
| `gpu_memory_utilization` | `<class 'float'>` | No | 0.3 | How much GPU memory will be allocated when this provider has finished loading, including memory that was already allocated before loading. |
## Sample Configuration
```yaml
tensor_parallel_size: ${env.TENSOR_PARALLEL_SIZE:=1}
max_tokens: ${env.MAX_TOKENS:=4096}
max_model_len: ${env.MAX_MODEL_LEN:=4096}
max_num_seqs: ${env.MAX_NUM_SEQS:=4}
enforce_eager: ${env.ENFORCE_EAGER:=False}
gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:=0.3}
```

View file

@ -12,11 +12,13 @@ Remote vLLM inference provider for connecting to vLLM servers.
| `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. | | `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
| `api_token` | `str \| None` | No | fake | The API token | | `api_token` | `str \| None` | No | fake | The API token |
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. | | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |
## Sample Configuration ## Sample Configuration
```yaml ```yaml
url: ${env.VLLM_URL} url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096} max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake} api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true} tls_verify: ${env.VLLM_TLS_VERIFY:=true}

View file

@ -819,7 +819,7 @@ class OpenAIEmbeddingsResponse(BaseModel):
class ModelStore(Protocol): class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ... async def get_model(self, identifier: str) -> Model: ...
async def update_registered_models( async def update_registered_llm_models(
self, self,
provider_id: str, provider_id: str,
models: list[Model], models: list[Model],

View file

@ -81,7 +81,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
raise ValueError(f"Model {model_id} not found") raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model) await self.unregister_object(existing_model)
async def update_registered_models( async def update_registered_llm_models(
self, self,
provider_id: str, provider_id: str,
models: list[Model], models: list[Model],
@ -92,12 +92,16 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
# from run.yaml) that we need to keep track of # from run.yaml) that we need to keep track of
model_ids = {} model_ids = {}
for model in existing_models: for model in existing_models:
if model.provider_id == provider_id: # we leave embeddings models alone because often we don't get metadata
# (embedding dimension, etc.) from the provider
if model.provider_id == provider_id and model.model_type == ModelType.llm:
model_ids[model.provider_resource_id] = model.identifier model_ids[model.provider_resource_id] = model.identifier
logger.debug(f"unregistering model {model.identifier}") logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model) await self.unregister_object(model)
for model in models: for model in models:
if model.model_type != ModelType.llm:
continue
if model.provider_resource_id in model_ids: if model.provider_resource_id in model_ids:
model.identifier = model_ids[model.provider_resource_id] model.identifier = model_ids[model.provider_resource_id]

View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import AccessRule, Api
from .config import LocalfsFilesImplConfig from .config import LocalfsFilesImplConfig
from .files import LocalfsFilesImpl from .files import LocalfsFilesImpl
@ -14,7 +14,7 @@ from .files import LocalfsFilesImpl
__all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"] __all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"]
async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any]): async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
impl = LocalfsFilesImpl(config) impl = LocalfsFilesImpl(config, policy)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -19,16 +19,19 @@ from llama_stack.apis.files import (
OpenAIFileObject, OpenAIFileObject,
OpenAIFilePurpose, OpenAIFilePurpose,
) )
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from .config import LocalfsFilesImplConfig from .config import LocalfsFilesImplConfig
class LocalfsFilesImpl(Files): class LocalfsFilesImpl(Files):
def __init__(self, config: LocalfsFilesImplConfig) -> None: def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
self.config = config self.config = config
self.sql_store: SqlStore | None = None self.policy = policy
self.sql_store: AuthorizedSqlStore | None = None
async def initialize(self) -> None: async def initialize(self) -> None:
"""Initialize the files provider by setting up storage directory and metadata database.""" """Initialize the files provider by setting up storage directory and metadata database."""
@ -37,7 +40,7 @@ class LocalfsFilesImpl(Files):
storage_path.mkdir(parents=True, exist_ok=True) storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata # Initialize SQL store for metadata
self.sql_store = sqlstore_impl(self.config.metadata_store) self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
await self.sql_store.create_table( await self.sql_store.create_table(
"openai_files", "openai_files",
{ {
@ -126,6 +129,7 @@ class LocalfsFilesImpl(Files):
paginated_result = await self.sql_store.fetch_all( paginated_result = await self.sql_store.fetch_all(
table="openai_files", table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None, where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
@ -156,7 +160,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store: if not self.sql_store:
raise RuntimeError("Files provider not initialized") raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row: if not row:
raise ValueError(f"File with id {file_id} not found") raise ValueError(f"File with id {file_id} not found")
@ -174,7 +178,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store: if not self.sql_store:
raise RuntimeError("Files provider not initialized") raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row: if not row:
raise ValueError(f"File with id {file_id} not found") raise ValueError(f"File with id {file_id} not found")
@ -197,7 +201,7 @@ class LocalfsFilesImpl(Files):
raise RuntimeError("Files provider not initialized") raise RuntimeError("Files provider not initialized")
# Get file metadata # Get file metadata
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row: if not row:
raise ValueError(f"File with id {file_id} not found") raise ValueError(f"File with id {file_id} not found")

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -1,53 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class VLLMConfig(BaseModel):
"""Configuration for the vLLM inference provider.
Note that the model name is no longer part of this static configuration.
You can bind an instance of this provider to a specific model with the
``models.register()`` API call."""
tensor_parallel_size: int = Field(
default=1,
description="Number of tensor parallel replicas (number of GPUs to use).",
)
max_tokens: int = Field(
default=4096,
description="Maximum number of tokens to generate.",
)
max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.")
max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation.")
enforce_eager: bool = Field(
default=False,
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
)
gpu_memory_utilization: float = Field(
default=0.3,
description=(
"How much GPU memory will be allocated when this provider has finished "
"loading, including memory that was already allocated before loading."
),
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:=1}",
"max_tokens": "${env.MAX_TOKENS:=4096}",
"max_model_len": "${env.MAX_MODEL_LEN:=4096}",
"max_num_seqs": "${env.MAX_NUM_SEQS:=4}",
"enforce_eager": "${env.ENFORCE_EAGER:=False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:=0.3}",
}

View file

@ -1,170 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import vllm
from llama_stack.apis.inference import (
ChatCompletionRequest,
GrammarResponseFormat,
JsonSchemaResponseFormat,
Message,
ToolChoice,
ToolDefinition,
UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
)
###############################################################################
# This file contains OpenAI compatibility code that is currently only used
# by the inline vLLM connector. Some or all of this code may be moved to a
# central location at a later date.
def _merge_context_into_content(message: Message) -> Message: # type: ignore
"""
Merge the ``context`` field of a Llama Stack ``Message`` object into
the content field for compabilitiy with OpenAI-style APIs.
Generates a content string that emulates the current behavior
of ``llama_models.llama3.api.chat_format.encode_message()``.
:param message: Message that may include ``context`` field
:returns: A version of ``message`` with any context merged into the
``content`` field.
"""
if not isinstance(message, UserMessage): # Separate type check for linter
return message
if message.context is None:
return message
return UserMessage(
role=message.role,
# Emumate llama_models.llama3.api.chat_format.encode_message()
content=message.content + "\n\n" + message.context,
context=None,
)
def _llama_stack_tools_to_openai_tools(
tools: list[ToolDefinition] | None = None,
) -> list[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
"""
Convert the list of available tools from Llama Stack's format to vLLM's
version of OpenAI's format.
"""
if tools is None:
return []
result = []
for t in tools:
if isinstance(t.tool_name, BuiltinTool):
raise NotImplementedError("Built-in tools not yet implemented")
if t.parameters is None:
parameters = None
else: # if t.parameters is not None
# Convert the "required" flags to a list of required params
required_params = [k for k, v in t.parameters.items() if v.required]
parameters = {
"type": "object", # Mystery value that shows up in OpenAI docs
"properties": {
k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
},
"required": required_params,
}
function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
name=t.tool_name, description=t.description, parameters=parameters
)
# Every tool definition is double-boxed in a ChatCompletionToolsParam
result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
return result
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
request: ChatCompletionRequest,
) -> dict:
"""
Convert a chat completion request in Llama Stack format into an
equivalent set of arguments to pass to an OpenAI-compatible
chat completions API.
:param request: Bundled request parameters in Llama Stack format.
:returns: Dictionary of key-value pairs to use as an initializer
for a dataclass or to be converted directly to JSON and sent
over the wire.
"""
converted_messages = [
# This mystery async call makes the parent function also be async
await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
for m in request.messages
]
converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
# Llama will try to use built-in tools with no tool catalog, so don't enable
# tool choice unless at least one tool is enabled.
converted_tool_choice = "none"
if (
request.tool_config is not None
and request.tool_config.tool_choice == ToolChoice.auto
and request.tools is not None
and len(request.tools) > 0
):
converted_tool_choice = "auto"
# TODO: Figure out what to do with the tool_prompt_format argument.
# Other connectors appear to drop it quietly.
# Use Llama Stack shared code to translate sampling parameters.
sampling_options = get_sampling_options(request.sampling_params)
# get_sampling_options() translates repetition penalties to an option that
# OpenAI's APIs don't know about.
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
# For now, translate repetition penalties into a format that vLLM's broken
# API will handle correctly. Two wrongs make a right...
if "repeat_penalty" in sampling_options:
del sampling_options["repeat_penalty"]
if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0:
sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty
# Convert a single response format into four different parameters, per
# the OpenAI spec
guided_decoding_options = dict()
if request.response_format is None:
# Use defaults
pass
elif isinstance(request.response_format, JsonSchemaResponseFormat):
guided_decoding_options["guided_json"] = request.response_format.json_schema
elif isinstance(request.response_format, GrammarResponseFormat):
guided_decoding_options["guided_grammar"] = request.response_format.bnf
else:
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'")
logprob_options = dict()
if request.logprobs is not None:
logprob_options["logprobs"] = request.logprobs.top_k
# Marshall together all the arguments for a ChatCompletionRequest
request_options = {
"model": request.model,
"messages": converted_messages,
"tools": converted_tools,
"tool_choice": converted_tool_choice,
"stream": request.stream,
**sampling_options,
**guided_decoding_options,
**logprob_options,
}
return request_options

View file

@ -1,811 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
# These vLLM modules contain names that overlap with Llama Stack names, so we import
# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextDelta,
ToolCallDelta,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
TokenLogProbs,
ToolChoice,
ToolConfig,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama import sku_list
from llama_stack.models.llama.datatypes import (
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionToLlamaStackMixin,
get_stop_reason,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import VLLMConfig
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
# Map from Hugging Face model architecture name to appropriate tool parser.
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
# available parsers.
# TODO: Expand this list
CONFIG_TYPE_TO_TOOL_PARSER = {
"GraniteConfig": "granite",
"MllamaConfig": "llama3_json",
"LlamaConfig": "llama3_json",
}
DEFAULT_TOOL_PARSER = "pythonic"
logger = get_logger(__name__, category="inference")
def _random_uuid_str() -> str:
return str(uuid.uuid4().hex)
def _response_format_to_guided_decoding_params(
response_format: ResponseFormat | None, # type: ignore
) -> vllm.sampling_params.GuidedDecodingParams:
"""
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
:param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
indicating no constraints.
:returns: The equivalent dataclass object for the low-level inference layer of vLLM.
"""
if response_format is None:
# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
# value that crashes the executor on some code paths. Use ``None`` instead.
return None
# Llama Stack currently implements fewer types of constrained decoding than vLLM does.
# Translate the types that exist and detect if Llama Stack adds new ones.
if isinstance(response_format, JsonSchemaResponseFormat):
return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
elif isinstance(response_format, GrammarResponseFormat):
# BNF grammar.
# Llama Stack uses the parse tree of the grammar, while vLLM uses the string
# representation of the grammar.
raise TypeError(
"Constrained decoding with BNF grammars is not currently implemented, because the "
"reference implementation does not implement it."
)
else:
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
def _convert_sampling_params(
sampling_params: SamplingParams | None,
response_format: ResponseFormat | None, # type: ignore
log_prob_config: LogProbConfig | None,
) -> vllm.SamplingParams:
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
format."""
# In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
# Stack dataclasses. These defaults are different from vLLM's defaults.
if sampling_params is None:
sampling_params = SamplingParams()
if log_prob_config is None:
log_prob_config = LogProbConfig()
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
if sampling_params.strategy.top_k == 0:
# vLLM treats "k" differently for top-k sampling
vllm_top_k = -1
else:
vllm_top_k = sampling_params.strategy.top_k
else:
vllm_top_k = -1
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
vllm_top_p = sampling_params.strategy.top_p
# Llama Stack only allows temperature with top-P.
vllm_temperature = sampling_params.strategy.temperature
else:
vllm_top_p = 1.0
vllm_temperature = 0.0
# vLLM allows top-p and top-k at the same time.
vllm_sampling_params = vllm.SamplingParams.from_optional(
max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
temperature=vllm_temperature,
top_p=vllm_top_p,
top_k=vllm_top_k,
repetition_penalty=sampling_params.repetition_penalty,
guided_decoding=_response_format_to_guided_decoding_params(response_format),
logprobs=log_prob_config.top_k,
)
return vllm_sampling_params
class VLLMInferenceImpl(
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
"""
vLLM-based inference model adapter for Llama Stack with support for multiple models.
Requires the configuration parameters documented in the :class:`VllmConfig2` class.
"""
config: VLLMConfig
register_helper: ModelRegistryHelper
model_ids: set[str]
resolved_model_id: str | None
engine: AsyncLLMEngine | None
chat: OpenAIServingChat | None
is_meta_llama_model: bool
def __init__(self, config: VLLMConfig):
self.config = config
logger.info(f"Config is: {self.config}")
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.formatter = ChatFormat(Tokenizer.get_instance())
# The following are initialized when paths are bound to this provider
self.resolved_model_id = None
self.model_ids = set()
self.engine = None
self.chat = None
self.is_meta_llama_model = False
###########################################################################
# METHODS INHERITED FROM IMPLICIT BASE CLASS.
# TODO: Make this class inherit from the new base class ProviderBase once that class exists.
async def initialize(self) -> None:
"""
Callback that is invoked through many levels of indirection during provider class
instantiation, sometime after when __init__() is called and before any model registration
methods or methods connected to a REST API are called.
It's not clear what assumptions the class can make about the platform's initialization
state here that can't be made during __init__(), and vLLM can't be started until we know
what model it's supposed to be serving, so nothing happens here currently.
"""
pass
async def shutdown(self) -> None:
logger.info(f"Shutting down inline vLLM inference provider {self}.")
if self.engine is not None:
self.engine.shutdown_background_loop()
self.engine = None
self.chat = None
self.model_ids = set()
self.resolved_model_id = None
###########################################################################
# METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE
# Note that the return type of the superclass method is WRONG
async def register_model(self, model: Model) -> Model:
"""
Callback that is called when the server associates an inference endpoint with an
inference provider.
:param model: Object that encapsulates parameters necessary for identifying a specific
LLM.
:returns: The input ``Model`` object. It may or may not be permissible to change fields
before returning this object.
"""
logger.debug(f"In register_model({model})")
# First attempt to interpret the model coordinates as a Llama model name
resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
if resolved_llama_model is not None:
# Load from Hugging Face repo into default local cache dir
model_id_for_vllm = resolved_llama_model.huggingface_repo
# Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
# Don't set self.is_meta_llama_model until we actually load the model.
is_meta_llama_model = True
else: # if resolved_llama_model is None
# Not a Llama model name. Pass the model id through to vLLM's loader
model_id_for_vllm = model.provider_model_id
is_meta_llama_model = False
if self.resolved_model_id is not None:
if model_id_for_vllm != self.resolved_model_id:
raise ValueError(
f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
f"copies of the provider instead."
)
else:
# Model already loaded
logger.info(
f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
)
self.model_ids.add(model.model_id)
return model
logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
if is_meta_llama_model:
logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
self.is_meta_llama_model = is_meta_llama_model
# If we get here, this is the first time registering a model.
# Preload so that the first inference request won't time out.
engine_args = AsyncEngineArgs(
model=model_id_for_vllm,
tokenizer=model_id_for_vllm,
tensor_parallel_size=self.config.tensor_parallel_size,
enforce_eager=self.config.enforce_eager,
gpu_memory_utilization=self.config.gpu_memory_utilization,
max_num_seqs=self.config.max_num_seqs,
max_model_len=self.config.max_model_len,
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
# vLLM currently requires the user to specify the tool parser manually. To choose a tool
# parser, we need to determine what model architecture is being used. For now, we infer
# that information from what config class the model uses.
low_level_model_config = self.engine.engine.get_model_config()
hf_config = low_level_model_config.hf_config
hf_config_class_name = hf_config.__class__.__name__
if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
else:
# No info -- choose a default so we can at least attempt tool
# use.
tool_parser = DEFAULT_TOOL_PARSER
logger.debug(f"{hf_config_class_name=}")
logger.debug(f"{tool_parser=}")
# Wrap the lower-level engine in an OpenAI-compatible chat API
model_config = await self.engine.get_model_config()
self.chat = OpenAIServingChat(
engine_client=self.engine,
model_config=model_config,
models=OpenAIServingModels(
engine_client=self.engine,
model_config=model_config,
base_model_paths=[
# The layer below us will only see resolved model IDs
BaseModelPath(model_id_for_vllm, model_id_for_vllm)
],
),
response_role="assistant",
request_logger=None, # Use default logging
chat_template=None, # Use default template from model checkpoint
enable_auto_tools=True,
tool_parser=tool_parser,
chat_template_content_format="auto",
)
self.resolved_model_id = model_id_for_vllm
self.model_ids.add(model.model_id)
logger.info(f"Finished preloading model: {model_id_for_vllm}")
return model
async def unregister_model(self, model_id: str) -> None:
"""
Callback that is called when the server removes an inference endpoint from an inference
provider.
:param model_id: The same external ID that the higher layers of the stack previously passed
to :func:`register_model()`
"""
if model_id not in self.model_ids:
raise ValueError(
f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
)
self.model_ids.remove(model_id)
if len(self.model_ids) == 0:
# Last model was just unregistered. Shut down the connection to vLLM and free up
# resources.
# Note that this operation may cause in-flight chat completion requests on the
# now-unregistered model to return errors.
self.resolved_model_id = None
self.chat = None
self.engine.shutdown_background_loop()
self.engine = None
###########################################################################
# METHODS INHERITED FROM Inference INTERFACE
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
if not isinstance(content, str):
raise NotImplementedError("Multimodal input not currently supported")
if sampling_params is None:
sampling_params = SamplingParams()
converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
logger.debug(f"{converted_sampling_params=}")
if stream:
return self._streaming_completion(content, converted_sampling_params)
else:
streaming_result = None
async for _ in self._streaming_completion(content, converted_sampling_params):
pass
return CompletionResponse(
content=streaming_result.delta,
stop_reason=streaming_result.stop_reason,
logprobs=streaming_result.logprobs,
)
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: list[Message], # type: ignore
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None, # type: ignore
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
sampling_params = sampling_params or SamplingParams()
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
# Convert to Llama Stack internal format for consistency
request = ChatCompletionRequest(
model=self.resolved_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
if self.is_meta_llama_model:
# Bypass vLLM chat templating layer for Meta Llama models, because the
# templating layer in Llama Stack currently produces better results.
logger.debug(
f"Routing {self.resolved_model_id} chat completion through "
f"Llama Stack's templating layer instead of vLLM's."
)
return await self._chat_completion_for_meta_llama(request)
logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
logger.debug(f"Converted request: {chat_completion_request}")
vllm_result = await self.chat.create_chat_completion(chat_completion_request)
logger.debug(f"Result from vLLM: {vllm_result}")
if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
raise ValueError(f"Error from vLLM layer: {vllm_result}")
# Return type depends on "stream" argument
if stream:
if not isinstance(vllm_result, AsyncGenerator):
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
# vLLM client returns a stream of strings, which need to be parsed.
# Stream comes in the form of an async generator.
return self._convert_streaming_results(vllm_result)
else:
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
return self._convert_non_streaming_results(vllm_result)
###########################################################################
# INTERNAL METHODS
async def _streaming_completion(
self, content: str, sampling_params: vllm.SamplingParams
) -> AsyncIterator[CompletionResponseStreamChunk]:
"""Internal implementation of :func:`completion()` API for the streaming case. Assumes
that arguments have been validated upstream.
:param content: Must be a string
:param sampling_params: Paramters from public API's ``response_format``
and ``sampling_params`` arguments, converted to VLLM format
"""
# We run agains the vLLM generate() call directly instead of using the OpenAI-compatible
# layer, because doing so simplifies the code here.
# The vLLM engine requires a unique identifier for each call to generate()
request_id = _random_uuid_str()
# The vLLM generate() API is streaming-only and returns an async generator.
# The generator returns objects of type vllm.RequestOutput.
results_generator = self.engine.generate(content, sampling_params, request_id)
# Need to know the model's EOS token ID for the conversion code below.
# AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
# we drill down to the LLMEngine inside the AsyncLLMEngine.
# Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
# and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
llm_engine = self.engine.engine
tokenizer_group = llm_engine.tokenizer
eos_token_id = tokenizer_group.tokenizer.eos_token_id
request_output: vllm.RequestOutput = None
async for request_output in results_generator:
# Check for weird inference failures
if request_output.outputs is None or len(request_output.outputs) == 0:
# This case also should never happen
raise ValueError("Inference produced empty result")
# If we get here, then request_output contains the final output of the generate() call.
# The result may include multiple alternate outputs, but Llama Stack APIs only allow
# us to return one.
output: vllm.CompletionOutput = request_output.outputs[0]
completion_string = output.text
# Convert logprobs from vLLM's format to Llama Stack's format
logprobs = [
TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
for logprob_dict in output.logprobs
]
# The final output chunk should be labeled with the reason that the overall generate()
# call completed.
logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
if output.stop_reason is None:
stop_reason = None # Still going
elif output.stop_reason == "stop":
stop_reason = StopReason.end_of_turn
elif output.stop_reason == "length":
stop_reason = StopReason.out_of_tokens
elif isinstance(output.stop_reason, int):
# If the model config specifies multiple end-of-sequence tokens, then vLLM
# will return the token ID of the EOS token in the stop_reason field.
stop_reason = StopReason.end_of_turn
else:
raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")
# vLLM's protocol outputs the stop token, then sets end of message on the next step for
# some reason.
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
stop_reason = StopReason.end_of_message
yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
# Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
# provide one if it runs out of tokens.
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta=completion_string,
stop_reason=StopReason.out_of_tokens,
logprobs=logprobs,
)
def _convert_non_streaming_results(
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
) -> ChatCompletionResponse:
"""
Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
equivalent Llama Stack object.
The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
the fields that aren't currently present in the Llama Stack dataclass.
"""
# There may be multiple responses, but we can only pass through the first one.
if len(vllm_result.choices) == 0:
raise ValueError("Don't know how to convert response object without any responses")
vllm_message = vllm_result.choices[0].message
vllm_finish_reason = vllm_result.choices[0].finish_reason
converted_message = CompletionMessage(
role=vllm_message.role,
# Llama Stack API won't accept None for content field.
content=("" if vllm_message.content is None else vllm_message.content),
stop_reason=get_stop_reason(vllm_finish_reason),
tool_calls=[
ToolCall(
call_id=t.id,
tool_name=t.function.name,
# vLLM function args come back as a string. Llama Stack expects JSON.
arguments=json.loads(t.function.arguments),
arguments_json=t.function.arguments,
)
for t in vllm_message.tool_calls
],
)
# TODO: Convert logprobs
logger.debug(f"Converted message: {converted_message}")
return ChatCompletionResponse(
completion_message=converted_message,
)
async def _chat_completion_for_meta_llama(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
"""
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
chat template instead of using vLLM's version of that template. The Llama Stack version
of the chat template currently produces more reliable outputs.
Once vLLM's support for Meta Llama models has matured more, we should consider routing
Meta Llama requests through the vLLM chat completions API instead of using this method.
"""
formatter = ChatFormat(Tokenizer.get_instance())
# Note that this function call modifies `request` in place.
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
model_id = list(self.model_ids)[0] # Any model ID will do here
completion_response_or_iterator = await self.completion(
model_id=model_id,
content=prompt,
sampling_params=request.sampling_params,
response_format=request.response_format,
stream=request.stream,
logprobs=request.logprobs,
)
if request.stream:
if not isinstance(completion_response_or_iterator, AsyncIterator):
raise TypeError(
f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
)
return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
# elsif not request.stream:
if not isinstance(completion_response_or_iterator, CompletionResponse):
raise TypeError(
f"Received unexpected result type {type(completion_response_or_iterator)}for non-streaming request."
)
completion_response: CompletionResponse = completion_response_or_iterator
raw_message = formatter.decode_assistant_message_from_content(
completion_response.content, completion_response.stop_reason
)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=completion_response.logprobs,
)
async def _chat_completion_for_meta_llama_streaming(
self, results_iterator: AsyncIterator, request: ChatCompletionRequest
) -> AsyncIterator:
"""
Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
method to keep asyncio happy.
"""
# Convert to OpenAI format, then use shared code to convert to Llama Stack format.
async def _generate_and_convert_to_openai_compat():
chunk: CompletionResponseStreamChunk # Make Pylance happy
last_text_len = 0
async for chunk in results_iterator:
if chunk.stop_reason == StopReason.end_of_turn:
finish_reason = "stop"
elif chunk.stop_reason == StopReason.end_of_message:
finish_reason = "eos"
elif chunk.stop_reason == StopReason.out_of_tokens:
finish_reason = "length"
else:
finish_reason = None
# Convert delta back to an actual delta
text_delta = chunk.delta[last_text_len:]
last_text_len = len(chunk.delta)
logger.debug(f"{text_delta=}; {finish_reason=}")
yield OpenAICompatCompletionResponse(
choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
logger.debug(f"Returning chunk: {chunk}")
yield chunk
async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
"""
Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
API into a second async iterator that returns Llama Stack objects.
:param vllm_result: Stream of strings that need to be parsed
"""
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
# those chunks and output them at the end.
# This data structure holds the current set of partial tool calls.
index_to_tool_call: dict[int, dict] = dict()
# The Llama Stack event stream must always start with a start event. Use an empty one to
# simplify logic below
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
stop_reason=None,
)
)
converted_stop_reason = None
async for chunk_str in vllm_result:
# Due to OpenAI compatibility, each event in the stream will start with "data: " and
# end with "\n\n".
_prefix = "data: "
_suffix = "\n\n"
if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")
# In between the "data: " and newlines is an event record
data_str = chunk_str[len(_prefix) : -len(_suffix)]
# The end of the stream is indicated with "[DONE]"
if data_str == "[DONE]":
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=converted_stop_reason,
)
)
return
# Anything that is not "[DONE]" should be a JSON record
parsed_chunk = json.loads(data_str)
logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
# The result may contain multiple completions, but Llama Stack APIs only support
# returning one.
first_choice = parsed_chunk["choices"][0]
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
delta_record = first_choice["delta"]
if "content" in delta_record:
# Text delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=delta_record["content"]),
stop_reason=converted_stop_reason,
)
)
elif "tool_calls" in delta_record:
# Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
# calls, so buffer until we get a "tool calls" stop reason
for tc in delta_record["tool_calls"]:
index = tc["index"]
if index not in index_to_tool_call:
# First time this tool call is showing up
index_to_tool_call[index] = dict()
tool_call = index_to_tool_call[index]
if "id" in tc:
tool_call["call_id"] = tc["id"]
if "function" in tc:
if "name" in tc["function"]:
tool_call["tool_name"] = tc["function"]["name"]
if "arguments" in tc["function"]:
# Arguments comes in as pieces of a string
if "arguments_str" not in tool_call:
tool_call["arguments_str"] = ""
tool_call["arguments_str"] += tc["function"]["arguments"]
else:
raise ValueError(f"Don't know how to parse event delta: {delta_record}")
if first_choice["finish_reason"] == "tool_calls":
# Special OpenAI code for "tool calls complete".
# Output the buffered tool calls. Llama Stack requires a separate event per tool
# call.
for tool_call_record in index_to_tool_call.values():
# Arguments come in as a string. Parse the completed string.
tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
del tool_call_record["arguments_str"]
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
stop_reason=converted_stop_reason,
)
)
# If we get here, we've lost the connection with the vLLM event stream before it ended
# normally.
raise ValueError("vLLM event stream ended without [DONE] message.")

View file

@ -37,16 +37,6 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig", config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
description="Meta's reference implementation of inference with support for various model formats and optimization techniques.", description="Meta's reference implementation of inference with support for various model formats and optimization techniques.",
), ),
InlineProviderSpec(
api=Api.inference,
provider_type="inline::vllm",
pip_packages=[
"vllm",
],
module="llama_stack.providers.inline.inference.vllm",
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
description="vLLM inference provider for high-performance model serving with PagedAttention and continuous batching.",
),
InlineProviderSpec( InlineProviderSpec(
api=Api.inference, api=Api.inference,
provider_type="inline::sentence-transformers", provider_type="inline::sentence-transformers",

View file

@ -159,18 +159,18 @@ class OllamaInferenceAdapter(
models = [] models = []
for m in response.models: for m in response.models:
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
# unfortunately, ollama does not provide embedding dimension in the model list :( if model_type == ModelType.embedding:
# we should likely add a hard-coded mapping of model name to embedding dimension continue
models.append( models.append(
Model( Model(
identifier=m.model, identifier=m.model,
provider_resource_id=m.model, provider_resource_id=m.model,
provider_id=provider_id, provider_id=provider_id,
metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {}, metadata={},
model_type=model_type, model_type=model_type,
) )
) )
await self.model_store.update_registered_models(provider_id, models) await self.model_store.update_registered_llm_models(provider_id, models)
logger.debug(f"ollama refreshed model list ({len(models)} models)") logger.debug(f"ollama refreshed model list ({len(models)} models)")
await asyncio.sleep(self.config.refresh_models_interval) await asyncio.sleep(self.config.refresh_models_interval)

View file

@ -29,6 +29,14 @@ class VLLMInferenceAdapterConfig(BaseModel):
default=True, default=True,
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.", description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
) )
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically",
)
refresh_models_interval: int = Field(
default=300,
description="Interval in seconds to refresh models",
)
@field_validator("tls_verify") @field_validator("tls_verify")
@classmethod @classmethod
@ -46,7 +54,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
@classmethod @classmethod
def sample_run_config( def sample_run_config(
cls, cls,
url: str = "${env.VLLM_URL}", url: str = "${env.VLLM_URL:=}",
**kwargs, **kwargs,
): ):
return { return {

View file

@ -3,8 +3,8 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import json import json
import logging
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any from typing import Any
@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat, JsonSchemaResponseFormat,
LogProbConfig, LogProbConfig,
Message, Message,
ModelStore,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAICompletion, OpenAICompletion,
OpenAIEmbeddingData, OpenAIEmbeddingData,
@ -54,6 +55,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
@ -84,7 +86,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig from .config import VLLMInferenceAdapterConfig
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="inference")
def build_hf_repo_model_entries(): def build_hf_repo_model_entries():
@ -288,16 +290,76 @@ async def _process_vllm_chat_completion_stream_response(
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
_refresh_task: asyncio.Task | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None: def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries()) self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config self.config = config
self.client = None self.client = None
async def initialize(self) -> None: async def initialize(self) -> None:
pass if not self.config.url:
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
# or available in distributions like "starter" without causing a ruckus
return
if self.config.refresh_models:
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
import traceback
if task.cancelled():
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
# print the stack trace for the exception
exc = task.exception()
log.error(f"vLLM background refresh task died: {exc}")
traceback.print_exception(exc)
else:
log.error("vLLM background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
provider_id = self.__provider_id__
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
self._lazy_initialize_client()
assert self.client is not None # mypy
while True:
try:
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
)
await self.model_store.update_registered_llm_models(provider_id, models)
log.debug(f"vLLM refreshed model list ({len(models)} models)")
except Exception as e:
log.error(f"vLLM background refresh task failed: {e}")
await asyncio.sleep(self.config.refresh_models_interval)
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass if self._refresh_task:
self._refresh_task.cancel()
self._refresh_task = None
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass
@ -312,6 +374,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
HealthResponse: A dictionary containing the health status. HealthResponse: A dictionary containing the health status.
""" """
try: try:
if not self.config.url:
return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
client = self._create_client() if self.client is None else self.client client = self._create_client() if self.client is None else self.client
_ = [m async for m in client.models.list()] # Ensure the client is initialized _ = [m async for m in client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK) return HealthResponse(status=HealthStatus.OK)
@ -327,6 +392,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if self.client is not None: if self.client is not None:
return return
if not self.config.url:
raise ValueError(
"You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
)
log.info(f"Initializing vLLM client with base_url={self.config.url}") log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client() self.client = self._create_client()

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .vllm import get_distribution_template # noqa: F401 from .ci_tests import get_distribution_template # noqa: F401

View file

@ -0,0 +1,65 @@
version: 2
distribution_spec:
description: CI tests for Llama Stack
providers:
inference:
- remote::cerebras
- remote::ollama
- remote::vllm
- remote::tgi
- remote::hf::serverless
- remote::hf::endpoint
- remote::fireworks
- remote::together
- remote::bedrock
- remote::databricks
- remote::nvidia
- remote::runpod
- remote::openai
- remote::anthropic
- remote::gemini
- remote::groq
- remote::fireworks-openai-compat
- remote::llama-openai-compat
- remote::together-openai-compat
- remote::groq-openai-compat
- remote::sambanova-openai-compat
- remote::cerebras-openai-compat
- remote::sambanova
- remote::passthrough
- inline::sentence-transformers
vector_io:
- inline::faiss
- inline::sqlite-vec
- inline::milvus
- remote::chromadb
- remote::pgvector
files:
- inline::localfs
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
post_training:
- inline::huggingface
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::rag-runtime
- remote::model-context-protocol
image_type: conda
additional_pip_packages:
- aiosqlite
- asyncpg
- sqlalchemy[asyncio]

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.templates.template import DistributionTemplate
from ..starter.starter import get_distribution_template as get_starter_distribution_template
def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template()
name = "ci-tests"
template.name = name
template.description = "CI tests for Llama Stack"
return template

File diff suppressed because it is too large Load diff

View file

@ -26,7 +26,7 @@ providers:
- provider_id: ${env.ENABLE_VLLM:=__disabled__} - provider_id: ${env.ENABLE_VLLM:=__disabled__}
provider_type: remote::vllm provider_type: remote::vllm
config: config:
url: ${env.VLLM_URL} url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096} max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake} api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true} tls_verify: ${env.VLLM_TLS_VERIFY:=true}

View file

@ -1,35 +0,0 @@
version: 2
distribution_spec:
description: Use a built-in vLLM engine for running LLM inference
providers:
inference:
- inline::vllm
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::rag-runtime
- remote::model-context-protocol
image_type: conda
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]

View file

@ -1,132 +0,0 @@
version: 2
image_name: vllm-gpu
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: vllm
provider_type: inline::vllm
config:
tensor_parallel_size: ${env.TENSOR_PARALLEL_SIZE:=1}
max_tokens: ${env.MAX_TOKENS:=4096}
max_model_len: ${env.MAX_MODEL_LEN:=4096}
max_num_seqs: ${env.MAX_NUM_SEQS:=4}
enforce_eager: ${env.ENFORCE_EAGER:=False}
gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:=0.3}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: vllm
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -1,122 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.inference.vllm import VLLMConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
ToolGroupInput,
)
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["inline::vllm", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
name = "vllm-gpu"
inference_provider = Provider(
provider_id="vllm",
provider_type="inline::vllm",
config=VLLMConfig.sample_run_config(),
)
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="vllm",
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Use a built-in vLLM engine for running LLM inference",
container_image=None,
template_path=None,
providers=providers,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
"vector_io": [vector_io_provider],
},
default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"8321",
"Port for the Llama Stack distribution server",
),
"INFERENCE_MODEL": (
"meta-llama/Llama-3.2-3B-Instruct",
"Inference model loaded into the vLLM engine",
),
"TENSOR_PARALLEL_SIZE": (
"1",
"Number of tensor parallel replicas (number of GPUs to use).",
),
"MAX_TOKENS": (
"4096",
"Maximum number of tokens to generate.",
),
"ENFORCE_EAGER": (
"False",
"Whether to use eager mode for inference (otherwise cuda graphs are used).",
),
"GPU_MEMORY_UTILIZATION": (
"0.7",
"GPU memory utilization for the vLLM engine.",
),
},
)

View file

@ -257,7 +257,6 @@ exclude = [
"^llama_stack/models/llama/llama4/", "^llama_stack/models/llama/llama4/",
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$", "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
"^llama_stack/providers/inline/inference/vllm/",
"^llama_stack/providers/inline/post_training/common/validator\\.py$", "^llama_stack/providers/inline/post_training/common/validator\\.py$",
"^llama_stack/providers/inline/safety/code_scanner/", "^llama_stack/providers/inline/safety/code_scanner/",
"^llama_stack/providers/inline/safety/llama_guard/", "^llama_stack/providers/inline/safety/llama_guard/",

View file

@ -5,10 +5,12 @@
# the root directory of this source tree. # the root directory of this source tree.
from io import BytesIO from io import BytesIO
from unittest.mock import patch
import pytest import pytest
from openai import OpenAI from openai import OpenAI
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
@ -61,3 +63,218 @@ def test_openai_client_basic_operations(compat_client, client_with_models):
except Exception: except Exception:
pass pass
raise e raise e
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
def test_files_authentication_isolation(mock_get_authenticated_user, compat_client, client_with_models):
"""Test that users can only access their own files."""
if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
if not isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)")
client = compat_client
# Create two test users
user1 = User("user1", {"roles": ["user"], "teams": ["team-a"]})
user2 = User("user2", {"roles": ["user"], "teams": ["team-b"]})
# User 1 uploads a file
mock_get_authenticated_user.return_value = user1
test_content_1 = b"User 1's private file content"
with BytesIO(test_content_1) as file_buffer:
file_buffer.name = "user1_file.txt"
user1_file = client.files.create(file=file_buffer, purpose="assistants")
# User 2 uploads a file
mock_get_authenticated_user.return_value = user2
test_content_2 = b"User 2's private file content"
with BytesIO(test_content_2) as file_buffer:
file_buffer.name = "user2_file.txt"
user2_file = client.files.create(file=file_buffer, purpose="assistants")
try:
# User 1 can see their own file
mock_get_authenticated_user.return_value = user1
user1_files = client.files.list()
user1_file_ids = [f.id for f in user1_files.data]
assert user1_file.id in user1_file_ids
assert user2_file.id not in user1_file_ids # Cannot see user2's file
# User 2 can see their own file
mock_get_authenticated_user.return_value = user2
user2_files = client.files.list()
user2_file_ids = [f.id for f in user2_files.data]
assert user2_file.id in user2_file_ids
assert user1_file.id not in user2_file_ids # Cannot see user1's file
# User 1 can retrieve their own file
mock_get_authenticated_user.return_value = user1
retrieved_file = client.files.retrieve(user1_file.id)
assert retrieved_file.id == user1_file.id
# User 1 cannot retrieve user2's file
mock_get_authenticated_user.return_value = user1
with pytest.raises(ValueError, match="not found"):
client.files.retrieve(user2_file.id)
# User 1 can access their file content
mock_get_authenticated_user.return_value = user1
content_response = client.files.content(user1_file.id)
if isinstance(content_response, str):
content = bytes(content_response, "utf-8")
else:
content = content_response.content
assert content == test_content_1
# User 1 cannot access user2's file content
mock_get_authenticated_user.return_value = user1
with pytest.raises(ValueError, match="not found"):
client.files.content(user2_file.id)
# User 1 can delete their own file
mock_get_authenticated_user.return_value = user1
delete_response = client.files.delete(user1_file.id)
assert delete_response.deleted is True
# User 1 cannot delete user2's file
mock_get_authenticated_user.return_value = user1
with pytest.raises(ValueError, match="not found"):
client.files.delete(user2_file.id)
# User 2 can still access their file after user1's file is deleted
mock_get_authenticated_user.return_value = user2
retrieved_file = client.files.retrieve(user2_file.id)
assert retrieved_file.id == user2_file.id
# Cleanup user2's file
mock_get_authenticated_user.return_value = user2
client.files.delete(user2_file.id)
except Exception as e:
# Cleanup in case of failure
try:
mock_get_authenticated_user.return_value = user1
client.files.delete(user1_file.id)
except Exception:
pass
try:
mock_get_authenticated_user.return_value = user2
client.files.delete(user2_file.id)
except Exception:
pass
raise e
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
def test_files_authentication_shared_attributes(mock_get_authenticated_user, compat_client, client_with_models):
"""Test access control with users having identical attributes."""
if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
if not isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)")
client = compat_client
# Create users with identical attributes (required for default policy)
user_a = User("user-a", {"roles": ["user"], "teams": ["shared-team"]})
user_b = User("user-b", {"roles": ["user"], "teams": ["shared-team"]})
# User A uploads a file
mock_get_authenticated_user.return_value = user_a
test_content = b"Shared attributes file content"
with BytesIO(test_content) as file_buffer:
file_buffer.name = "shared_attributes_file.txt"
shared_file = client.files.create(file=file_buffer, purpose="assistants")
try:
# User B with identical attributes can access the file
mock_get_authenticated_user.return_value = user_b
files_list = client.files.list()
file_ids = [f.id for f in files_list.data]
# User B should be able to see the file due to identical attributes
assert shared_file.id in file_ids
# User B can retrieve file info
retrieved_file = client.files.retrieve(shared_file.id)
assert retrieved_file.id == shared_file.id
# User B can access file content
content_response = client.files.content(shared_file.id)
if isinstance(content_response, str):
content = bytes(content_response, "utf-8")
else:
content = content_response.content
assert content == test_content
# Cleanup
mock_get_authenticated_user.return_value = user_a
client.files.delete(shared_file.id)
except Exception as e:
# Cleanup in case of failure
try:
mock_get_authenticated_user.return_value = user_a
client.files.delete(shared_file.id)
except Exception:
pass
try:
mock_get_authenticated_user.return_value = user_b
client.files.delete(shared_file.id)
except Exception:
pass
raise e
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
def test_files_authentication_anonymous_access(mock_get_authenticated_user, compat_client, client_with_models):
"""Test anonymous user behavior when no authentication is present."""
if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
if not isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)")
client = compat_client
# Simulate anonymous user (no authentication)
mock_get_authenticated_user.return_value = None
test_content = b"Anonymous file content"
with BytesIO(test_content) as file_buffer:
file_buffer.name = "anonymous_file.txt"
anonymous_file = client.files.create(file=file_buffer, purpose="assistants")
try:
# Anonymous user should be able to access their own uploaded file
files_list = client.files.list()
file_ids = [f.id for f in files_list.data]
assert anonymous_file.id in file_ids
# Can retrieve file info
retrieved_file = client.files.retrieve(anonymous_file.id)
assert retrieved_file.id == anonymous_file.id
# Can access file content
content_response = client.files.content(anonymous_file.id)
if isinstance(content_response, str):
content = bytes(content_response, "utf-8")
else:
content = content_response.content
assert content == test_content
# Can delete the file
delete_response = client.files.delete(anonymous_file.id)
assert delete_response.deleted is True
except Exception as e:
# Cleanup in case of failure
try:
client.files.delete(anonymous_file.id)
except Exception:
pass
raise e

View file

@ -9,6 +9,7 @@ import pytest
from llama_stack.apis.common.responses import Order from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import OpenAIFilePurpose from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.distribution.access_control.access_control import default_policy
from llama_stack.providers.inline.files.localfs import ( from llama_stack.providers.inline.files.localfs import (
LocalfsFilesImpl, LocalfsFilesImpl,
LocalfsFilesImplConfig, LocalfsFilesImplConfig,
@ -38,7 +39,7 @@ async def files_provider(tmp_path):
storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()) storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix())
) )
provider = LocalfsFilesImpl(config) provider = LocalfsFilesImpl(config, default_policy())
await provider.initialize() await provider.initialize()
yield provider yield provider