Merge branch 'main' into small-ui-patches

Francisco Arceo 2025-05-15 14:10:12 -06:00 committed by GitHub
commit e9cce9ed38
54 changed files with 1825 additions and 760 deletions

.github/CODEOWNERS (2 changes)

@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan @SLR722 @leseb
+* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning

.github/TRIAGERS.md (2 changes)

@@ -1,2 +1,2 @@
 # This file documents Triage members in the Llama Stack community
-@franciscojavierarceo @leseb
+@bbrowning @booxter @franciscojavierarceo @leseb

@@ -47,8 +47,8 @@ jobs:
       - name: Create provider configuration
         run: |
-          mkdir -p /tmp/providers.d/remote/inference
-          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
+          mkdir -p /home/runner/.llama/providers.d/remote/inference
+          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml

       - name: Build distro from config file
         run: |
@@ -66,7 +66,7 @@ jobs:
       - name: Wait for Llama Stack server to be ready
         run: |
           for i in {1..30}; do
-            if ! grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
+            if ! grep -q "remote::custom_ollama from /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml" server.log; then
              echo "Waiting for Llama Stack server to load the provider..."
              sleep 1
            else

@@ -14,6 +14,8 @@ on:
       - 'docs/**'
       - 'pyproject.toml'
       - '.github/workflows/update-readthedocs.yml'
+    tags:
+      - '*'
   pull_request:
     branches:
       - main
@@ -61,7 +63,10 @@ jobs:
          response=$(curl -X POST \
            -H "Content-Type: application/json" \
-           -d "{\"token\": \"$TOKEN\"}" \
+           -d "{
+             \"token\": \"$TOKEN\",
+             \"version\": \"$GITHUB_REF_NAME\"
+           }" \
            https://readthedocs.org/api/v2/webhook/llama-stack/289768/)
          echo "Response: $response"

File diff suppressed because it is too large.

File diff suppressed because it is too large.

@@ -179,6 +179,35 @@ def _validate_has_ellipsis(method) -> str | None:
     if "..." not in source and not "NotImplementedError" in source:
         return "does not contain ellipsis (...) in its implementation"

+def _validate_has_return_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    return_type = method.__annotations__.get('return')
+    if return_type is not None and return_type != type(None) and ":returns:" not in source:
+        return "does not have a ':returns:' in its docstring"
+
+def _validate_has_params_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    sig = inspect.signature(method)
+    # Only check if the method has more than one parameter
+    if len(sig.parameters) > 1 and ":param" not in source:
+        return "does not have a ':param' in its docstring"
+
+def _validate_has_no_return_none_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    return_type = method.__annotations__.get('return')
+    if return_type is None and ":returns: None" in source:
+        return "has a ':returns: None' in its docstring which is redundant for None-returning functions"
+
+def _validate_docstring_lines_end_with_dot(method) -> str | None:
+    docstring = inspect.getdoc(method)
+    if docstring is None:
+        return None
+
+    lines = docstring.split('\n')
+    for line in lines:
+        line = line.strip()
+        if line and not any(line.endswith(char) for char in '.:{}[]()",'):
+            return f"docstring line '{line}' does not end with a valid character: . : {{ }} [ ] ( ) , \""
+
 _VALIDATORS = {
     "GET": [
@@ -186,13 +215,23 @@ _VALIDATORS = {
         _validate_list_parameters_contain_data,
         _validate_api_method_doesnt_return_list,
         _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_docstring_lines_end_with_dot,
     ],
     "DELETE": [
         _validate_api_delete_method_returns_none,
         _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_has_no_return_none_in_docstring
     ],
     "POST": [
         _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_has_no_return_none_in_docstring,
+        _validate_docstring_lines_end_with_dot,
     ],
 }
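To make the intent of these new validators concrete, here is a minimal sketch of a docstring shape that would pass them; the method and types are hypothetical, not taken from the diff. Every parameter has a `:param`, the non-None return type has a `:returns:`, and each line ends with one of the accepted characters.

```python
# Hypothetical protocol method illustrating a docstring that satisfies the
# checks added above: a ':param' per argument, a ':returns:' because the
# return type is not None, and every line ending in . : { } [ ] ( ) , or ".
async def get_widget(self, widget_id: str) -> "Widget":
    """Get a widget by its ID.

    :param widget_id: The ID of the widget to get.
    :returns: A Widget.
    """
    ...
```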

@@ -110,6 +110,8 @@ html_theme_options = {
     "canonical_url": "https://github.com/meta-llama/llama-stack",
     "collapse_navigation": False,
     # "style_nav_header_background": "#c3c9d4",
+    'display_version': True,
+    'version_selector': True,
 }

 default_dark_mode = False

@@ -178,7 +178,7 @@ image_name: ollama
 image_type: conda
 # If some providers are external, you can specify the path to the implementation
-external_providers_dir: /etc/llama-stack/providers.d
+external_providers_dir: ~/.llama/providers.d
 ```
 ```
@@ -206,7 +206,7 @@ distribution_spec:
   image_type: container
 image_name: ci-test
 # Path to external provider implementations
-external_providers_dir: /etc/llama-stack/providers.d
+external_providers_dir: ~/.llama/providers.d
 ```
 Here's an example for a custom Ollama provider:

@@ -10,7 +10,7 @@ Llama Stack supports external providers that live outside of the main codebase.
 To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:
 ```yaml
-external_providers_dir: /etc/llama-stack/providers.d/
+external_providers_dir: ~/.llama/providers.d/
 ```
 ## Directory Structure
@@ -182,7 +182,7 @@ dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
 3. Create the provider specification:
 ```yaml
-# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
+# ~/.llama/providers.d/remote/inference/custom_ollama.yaml
 adapter:
   adapter_type: custom_ollama
   pip_packages: ["ollama", "aiohttp"]
@@ -201,7 +201,7 @@ uv pip install -e .
 5. Configure Llama Stack to use external providers:
 ```yaml
-external_providers_dir: /etc/llama-stack/providers.d/
+external_providers_dir: ~/.llama/providers.d/
 ```
 The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
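For a quick sanity check of the layout these docs now assume, the following sketch creates the `~/.llama/providers.d` tree and parses a provider spec such as the `custom_ollama.yaml` example above. It assumes pyyaml is installed; paths and field names follow the docs' example and should be adjusted for your own provider.

```python
# Sketch only: set up the external providers directory described above and
# verify a provider spec file parses as expected.
from pathlib import Path
import yaml  # assumes pyyaml is installed

providers_dir = Path.home() / ".llama" / "providers.d" / "remote" / "inference"
providers_dir.mkdir(parents=True, exist_ok=True)

spec_path = providers_dir / "custom_ollama.yaml"
if spec_path.exists():
    spec = yaml.safe_load(spec_path.read_text())
    # The docs' example spec has an adapter entry with adapter_type and pip_packages.
    print(spec["adapter"]["adapter_type"], spec["adapter"]["pip_packages"])
```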

@@ -38,6 +38,67 @@ wait_for_service() {
     return 0
 }

+usage() {
+    cat << EOF
+📚 Llama-Stack Deployment Script
+
+Description:
+    This script sets up and deploys Llama-Stack with Ollama integration in containers.
+    It handles both Docker and Podman runtimes and includes automatic platform detection.
+
+Usage:
+    $(basename "$0") [OPTIONS]
+
+Options:
+    -p, --port PORT            Server port for Llama-Stack (default: ${PORT})
+    -o, --ollama-port PORT     Ollama service port (default: ${OLLAMA_PORT})
+    -m, --model MODEL          Model alias to use (default: ${MODEL_ALIAS})
+    -i, --image IMAGE          Server image (default: ${SERVER_IMAGE})
+    -t, --timeout SECONDS      Service wait timeout in seconds (default: ${WAIT_TIMEOUT})
+    -h, --help                 Show this help message
+
+For more information:
+    Documentation: https://llama-stack.readthedocs.io/
+    GitHub: https://github.com/meta-llama/llama-stack
+
+Report issues:
+    https://github.com/meta-llama/llama-stack/issues
+EOF
+}
+
+# Parse command line arguments
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        -h|--help)
+            usage
+            exit 0
+            ;;
+        -p|--port)
+            PORT="$2"
+            shift 2
+            ;;
+        -o|--ollama-port)
+            OLLAMA_PORT="$2"
+            shift 2
+            ;;
+        -m|--model)
+            MODEL_ALIAS="$2"
+            shift 2
+            ;;
+        -i|--image)
+            SERVER_IMAGE="$2"
+            shift 2
+            ;;
+        -t|--timeout)
+            WAIT_TIMEOUT="$2"
+            shift 2
+            ;;
+        *)
+            die "Unknown option: $1"
+            ;;
+    esac
+done
+
 if command -v docker &> /dev/null; then
     ENGINE="docker"
 elif command -v podman &> /dev/null; then

@@ -413,7 +413,7 @@ class Agents(Protocol):
         :param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
         :param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
         :returns: If stream=False, returns a Turn object.
-                  If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
+                  If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
         """
         ...
@@ -509,6 +509,7 @@ class Agents(Protocol):
         :param session_id: The ID of the session to get.
         :param agent_id: The ID of the agent to get the session for.
         :param turn_ids: (Optional) List of turn IDs to filter the session by.
+        :returns: A Session.
         """
         ...
@@ -606,5 +607,6 @@ class Agents(Protocol):
         :param input: Input message(s) to create the response.
         :param model: The underlying LLM used for completions.
         :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
+        :returns: An OpenAIResponseObject.
         """
         ...

@@ -38,7 +38,17 @@ class BatchInference(Protocol):
         sampling_params: SamplingParams | None = None,
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
-    ) -> Job: ...
+    ) -> Job:
+        """Generate completions for a batch of content.
+        :param model: The model to use for the completion.
+        :param content_batch: The content to complete.
+        :param sampling_params: The sampling parameters to use for the completion.
+        :param response_format: The response format to use for the completion.
+        :param logprobs: The logprobs to use for the completion.
+        :returns: A job for the completion.
+        """
+        ...

     @webmethod(route="/batch-inference/chat-completion", method="POST")
     async def chat_completion(
@@ -52,4 +62,17 @@ class BatchInference(Protocol):
         tool_prompt_format: ToolPromptFormat | None = None,
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
-    ) -> Job: ...
+    ) -> Job:
+        """Generate chat completions for a batch of messages.
+        :param model: The model to use for the chat completion.
+        :param messages_batch: The messages to complete.
+        :param sampling_params: The sampling parameters to use for the completion.
+        :param tools: The tools to use for the chat completion.
+        :param tool_choice: The tool choice to use for the chat completion.
+        :param tool_prompt_format: The tool prompt format to use for the chat completion.
+        :param response_format: The response format to use for the chat completion.
+        :param logprobs: The logprobs to use for the chat completion.
+        :returns: A job for the chat completion.
+        """
+        ...

@@ -46,13 +46,24 @@ class ListBenchmarksResponse(BaseModel):
 @runtime_checkable
 class Benchmarks(Protocol):
     @webmethod(route="/eval/benchmarks", method="GET")
-    async def list_benchmarks(self) -> ListBenchmarksResponse: ...
+    async def list_benchmarks(self) -> ListBenchmarksResponse:
+        """List all benchmarks.
+        :returns: A ListBenchmarksResponse.
+        """
+        ...

     @webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
     async def get_benchmark(
         self,
         benchmark_id: str,
-    ) -> Benchmark: ...
+    ) -> Benchmark:
+        """Get a benchmark by its ID.
+        :param benchmark_id: The ID of the benchmark to get.
+        :returns: A Benchmark.
+        """
+        ...

     @webmethod(route="/eval/benchmarks", method="POST")
     async def register_benchmark(
@@ -63,4 +74,14 @@ class Benchmarks(Protocol):
         provider_benchmark_id: str | None = None,
         provider_id: str | None = None,
         metadata: dict[str, Any] | None = None,
-    ) -> None: ...
+    ) -> None:
+        """Register a benchmark.
+        :param benchmark_id: The ID of the benchmark to register.
+        :param dataset_id: The ID of the dataset to use for the benchmark.
+        :param scoring_functions: The scoring functions to use for the benchmark.
+        :param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
+        :param provider_id: The ID of the provider to use for the benchmark.
+        :param metadata: The metadata to use for the benchmark.
+        """
+        ...

@@ -34,14 +34,21 @@ class DatasetIO(Protocol):
         - limit: Number of items to return. If None or -1, returns all items.
         The response includes:
-        - data: List of items for the current page
-        - has_more: Whether there are more items available after this set
+        - data: List of items for the current page.
+        - has_more: Whether there are more items available after this set.
         :param dataset_id: The ID of the dataset to get the rows from.
         :param start_index: Index into dataset for the first row to get. Get all rows if None.
         :param limit: The number of rows to get.
+        :returns: A PaginatedResponse.
         """
         ...

     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
+        """Append rows to a dataset.
+        :param dataset_id: The ID of the dataset to append the rows to.
+        :param rows: The rows to append to the dataset.
+        """
+        ...

@@ -137,7 +137,8 @@ class Datasets(Protocol):
         """
         Register a new dataset.
-        :param purpose: The purpose of the dataset. One of
+        :param purpose: The purpose of the dataset.
+            One of:
             - "post-training/messages": The dataset contains a messages column with list of messages for post-training.
                 {
                     "messages": [
@@ -188,8 +189,9 @@
                     ]
                 }
         :param metadata: The metadata for the dataset.
-            - E.g. {"description": "My dataset"}
+            - E.g. {"description": "My dataset"}.
         :param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
+        :returns: A Dataset.
         """
         ...
@@ -197,13 +199,29 @@
     async def get_dataset(
         self,
         dataset_id: str,
-    ) -> Dataset: ...
+    ) -> Dataset:
+        """Get a dataset by its ID.
+        :param dataset_id: The ID of the dataset to get.
+        :returns: A Dataset.
+        """
+        ...

     @webmethod(route="/datasets", method="GET")
-    async def list_datasets(self) -> ListDatasetsResponse: ...
+    async def list_datasets(self) -> ListDatasetsResponse:
+        """List all datasets.
+        :returns: A ListDatasetsResponse.
+        """
+        ...

     @webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
     async def unregister_dataset(
         self,
         dataset_id: str,
-    ) -> None: ...
+    ) -> None:
+        """Unregister a dataset by its ID.
+        :param dataset_id: The ID of the dataset to unregister.
+        """
+        ...

@@ -93,7 +93,7 @@ class Eval(Protocol):
         :param benchmark_id: The ID of the benchmark to run the evaluation on.
         :param benchmark_config: The configuration for the benchmark.
-        :return: The job that was created to run the evaluation.
+        :returns: The job that was created to run the evaluation.
         """
         ...
@@ -111,7 +111,7 @@ class Eval(Protocol):
         :param input_rows: The rows to evaluate.
         :param scoring_functions: The scoring functions to use for the evaluation.
         :param benchmark_config: The configuration for the benchmark.
-        :return: EvaluateResponse object containing generations and scores
+        :returns: EvaluateResponse object containing generations and scores.
         """
         ...
@@ -121,7 +121,7 @@ class Eval(Protocol):
         :param benchmark_id: The ID of the benchmark to run the evaluation on.
         :param job_id: The ID of the job to get the status of.
-        :return: The status of the evaluationjob.
+        :returns: The status of the evaluation job.
         """
         ...
@@ -140,6 +140,6 @@ class Eval(Protocol):
         :param benchmark_id: The ID of the benchmark to run the evaluation on.
         :param job_id: The ID of the job to get the result of.
-        :return: The result of the job.
+        :returns: The result of the job.
         """
         ...

@@ -91,10 +91,11 @@ class Files(Protocol):
         """
         Create a new upload session for a file identified by a bucket and key.
-        :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
-        :param mime_type: MIME type of the file
-        :param size: File size in bytes
+        :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
+        :param mime_type: MIME type of the file.
+        :param size: File size in bytes.
+        :returns: A FileUploadResponse.
         """
         ...
@@ -107,7 +108,8 @@ class Files(Protocol):
         Upload file content to an existing upload session.
         On the server, request body will have the raw bytes that are uploaded.
-        :param upload_id: ID of the upload session
+        :param upload_id: ID of the upload session.
+        :returns: A FileResponse or None if the upload is not complete.
         """
         ...
@@ -117,9 +119,10 @@ class Files(Protocol):
         upload_id: str,
     ) -> FileUploadResponse:
         """
-        Returns information about an existsing upload session
+        Returns information about an existsing upload session.
-        :param upload_id: ID of the upload session
+        :param upload_id: ID of the upload session.
+        :returns: A FileUploadResponse.
         """
         ...
@@ -130,6 +133,9 @@ class Files(Protocol):
     ) -> ListBucketResponse:
         """
         List all buckets.
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :returns: A ListBucketResponse.
         """
         ...
@@ -141,7 +147,8 @@ class Files(Protocol):
         """
         List all files in a bucket.
-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :returns: A ListFileResponse.
         """
         ...
@@ -154,8 +161,9 @@ class Files(Protocol):
         """
         Get a file info identified by a bucket and key.
-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
+        :returns: A FileResponse.
         """
         ...
@@ -168,7 +176,7 @@ class Files(Protocol):
         """
         Delete a file identified by a bucket and key.
-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
         """
         ...

@@ -845,13 +845,13 @@ class Inference(Protocol):
         """Generate a completion for the given content using the specified model.
         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param content: The content to generate a completion for
-        :param sampling_params: (Optional) Parameters to control the sampling strategy
-        :param response_format: (Optional) Grammar specification for guided (structured) decoding
+        :param content: The content to generate a completion for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
         :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
         :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
         :returns: If stream=False, returns a CompletionResponse with the full completion.
-                  If stream=True, returns an SSE event stream of CompletionResponseStreamChunk
+                  If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
         """
         ...
@@ -864,6 +864,15 @@ class Inference(Protocol):
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
     ) -> BatchCompletionResponse:
+        """Generate completions for a batch of content using the specified model.
+        :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param content_batch: The content to generate completions for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
+        :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :returns: A BatchCompletionResponse with the full completions.
+        """
         raise NotImplementedError("Batch completion is not implemented")

     @webmethod(route="/inference/chat-completion", method="POST")
@@ -883,9 +892,9 @@ class Inference(Protocol):
         """Generate a chat completion for the given messages using the specified model.
         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param messages: List of messages in the conversation
-        :param sampling_params: Parameters to control the sampling strategy
-        :param tools: (Optional) List of tool definitions available to the model
+        :param messages: List of messages in the conversation.
+        :param sampling_params: Parameters to control the sampling strategy.
+        :param tools: (Optional) List of tool definitions available to the model.
         :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
             .. deprecated::
                Use tool_config instead.
@@ -902,7 +911,7 @@ class Inference(Protocol):
         :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
         :param tool_config: (Optional) Configuration for tool use.
         :returns: If stream=False, returns a ChatCompletionResponse with the full completion.
-                  If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
+                  If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
         """
         ...
@@ -917,6 +926,17 @@ class Inference(Protocol):
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
     ) -> BatchChatCompletionResponse:
+        """Generate chat completions for a batch of messages using the specified model.
+        :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param messages_batch: The messages to generate completions for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param tools: (Optional) List of tool definitions available to the model.
+        :param tool_config: (Optional) Configuration for tool use.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
+        :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :returns: A BatchChatCompletionResponse with the full completions.
+        """
         raise NotImplementedError("Batch chat completion is not implemented")

     @webmethod(route="/inference/embeddings", method="POST")
@@ -935,7 +955,7 @@ class Inference(Protocol):
         :param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
         :param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
         :param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
-        :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
+        :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
         """
         ...
@@ -967,22 +987,23 @@ class Inference(Protocol):
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.
         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param prompt: The prompt to generate a completion for
-        :param best_of: (Optional) The number of completions to generate
-        :param echo: (Optional) Whether to echo the prompt
-        :param frequency_penalty: (Optional) The penalty for repeated tokens
-        :param logit_bias: (Optional) The logit bias to use
-        :param logprobs: (Optional) The log probabilities to use
-        :param max_tokens: (Optional) The maximum number of tokens to generate
-        :param n: (Optional) The number of completions to generate
-        :param presence_penalty: (Optional) The penalty for repeated tokens
-        :param seed: (Optional) The seed to use
-        :param stop: (Optional) The stop tokens to use
-        :param stream: (Optional) Whether to stream the response
-        :param stream_options: (Optional) The stream options to use
-        :param temperature: (Optional) The temperature to use
-        :param top_p: (Optional) The top p to use
-        :param user: (Optional) The user to use
+        :param prompt: The prompt to generate a completion for.
+        :param best_of: (Optional) The number of completions to generate.
+        :param echo: (Optional) Whether to echo the prompt.
+        :param frequency_penalty: (Optional) The penalty for repeated tokens.
+        :param logit_bias: (Optional) The logit bias to use.
+        :param logprobs: (Optional) The log probabilities to use.
+        :param max_tokens: (Optional) The maximum number of tokens to generate.
+        :param n: (Optional) The number of completions to generate.
+        :param presence_penalty: (Optional) The penalty for repeated tokens.
+        :param seed: (Optional) The seed to use.
+        :param stop: (Optional) The stop tokens to use.
+        :param stream: (Optional) Whether to stream the response.
+        :param stream_options: (Optional) The stream options to use.
+        :param temperature: (Optional) The temperature to use.
+        :param top_p: (Optional) The top p to use.
+        :param user: (Optional) The user to use.
+        :returns: An OpenAICompletion.
         """
         ...
@@ -1016,27 +1037,28 @@ class Inference(Protocol):
         """Generate an OpenAI-compatible chat completion for the given messages using the specified model.
         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param messages: List of messages in the conversation
-        :param frequency_penalty: (Optional) The penalty for repeated tokens
-        :param function_call: (Optional) The function call to use
-        :param functions: (Optional) List of functions to use
-        :param logit_bias: (Optional) The logit bias to use
-        :param logprobs: (Optional) The log probabilities to use
-        :param max_completion_tokens: (Optional) The maximum number of tokens to generate
-        :param max_tokens: (Optional) The maximum number of tokens to generate
-        :param n: (Optional) The number of completions to generate
-        :param parallel_tool_calls: (Optional) Whether to parallelize tool calls
-        :param presence_penalty: (Optional) The penalty for repeated tokens
-        :param response_format: (Optional) The response format to use
-        :param seed: (Optional) The seed to use
-        :param stop: (Optional) The stop tokens to use
-        :param stream: (Optional) Whether to stream the response
-        :param stream_options: (Optional) The stream options to use
-        :param temperature: (Optional) The temperature to use
-        :param tool_choice: (Optional) The tool choice to use
-        :param tools: (Optional) The tools to use
-        :param top_logprobs: (Optional) The top log probabilities to use
-        :param top_p: (Optional) The top p to use
-        :param user: (Optional) The user to use
+        :param messages: List of messages in the conversation.
+        :param frequency_penalty: (Optional) The penalty for repeated tokens.
+        :param function_call: (Optional) The function call to use.
+        :param functions: (Optional) List of functions to use.
+        :param logit_bias: (Optional) The logit bias to use.
+        :param logprobs: (Optional) The log probabilities to use.
+        :param max_completion_tokens: (Optional) The maximum number of tokens to generate.
+        :param max_tokens: (Optional) The maximum number of tokens to generate.
+        :param n: (Optional) The number of completions to generate.
+        :param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
+        :param presence_penalty: (Optional) The penalty for repeated tokens.
+        :param response_format: (Optional) The response format to use.
+        :param seed: (Optional) The seed to use.
+        :param stop: (Optional) The stop tokens to use.
+        :param stream: (Optional) Whether to stream the response.
+        :param stream_options: (Optional) The stream options to use.
+        :param temperature: (Optional) The temperature to use.
+        :param tool_choice: (Optional) The tool choice to use.
+        :param tools: (Optional) The tools to use.
+        :param top_logprobs: (Optional) The top log probabilities to use.
+        :param top_p: (Optional) The top p to use.
+        :param user: (Optional) The user to use.
+        :returns: An OpenAIChatCompletion.
         """
         ...

@@ -36,10 +36,25 @@ class ListRoutesResponse(BaseModel):
 @runtime_checkable
 class Inspect(Protocol):
     @webmethod(route="/inspect/routes", method="GET")
-    async def list_routes(self) -> ListRoutesResponse: ...
+    async def list_routes(self) -> ListRoutesResponse:
+        """List all routes.
+        :returns: A ListRoutesResponse.
+        """
+        ...

     @webmethod(route="/health", method="GET")
-    async def health(self) -> HealthInfo: ...
+    async def health(self) -> HealthInfo:
+        """Get the health of the service.
+        :returns: A HealthInfo.
+        """
+        ...

     @webmethod(route="/version", method="GET")
-    async def version(self) -> VersionInfo: ...
+    async def version(self) -> VersionInfo:
+        """Get the version of the service.
+        :returns: A VersionInfo.
+        """
+        ...

@@ -80,16 +80,32 @@ class OpenAIListModelsResponse(BaseModel):
 @trace_protocol
 class Models(Protocol):
     @webmethod(route="/models", method="GET")
-    async def list_models(self) -> ListModelsResponse: ...
+    async def list_models(self) -> ListModelsResponse:
+        """List all models.
+        :returns: A ListModelsResponse.
+        """
+        ...

     @webmethod(route="/openai/v1/models", method="GET")
-    async def openai_list_models(self) -> OpenAIListModelsResponse: ...
+    async def openai_list_models(self) -> OpenAIListModelsResponse:
+        """List models using the OpenAI API.
+        :returns: A OpenAIListModelsResponse.
+        """
+        ...

     @webmethod(route="/models/{model_id:path}", method="GET")
     async def get_model(
         self,
         model_id: str,
-    ) -> Model: ...
+    ) -> Model:
+        """Get a model by its identifier.
+        :param model_id: The identifier of the model to get.
+        :returns: A Model.
+        """
+        ...

     @webmethod(route="/models", method="POST")
     async def register_model(
@@ -99,10 +115,25 @@ class Models(Protocol):
         provider_id: str | None = None,
         metadata: dict[str, Any] | None = None,
         model_type: ModelType | None = None,
-    ) -> Model: ...
+    ) -> Model:
+        """Register a model.
+        :param model_id: The identifier of the model to register.
+        :param provider_model_id: The identifier of the model in the provider.
+        :param provider_id: The identifier of the provider.
+        :param metadata: Any additional metadata for this model.
+        :param model_type: The type of model to register.
+        :returns: A Model.
+        """
+        ...

     @webmethod(route="/models/{model_id:path}", method="DELETE")
     async def unregister_model(
         self,
         model_id: str,
-    ) -> None: ...
+    ) -> None:
+        """Unregister a model.
+        :param model_id: The identifier of the model to unregister.
+        """
+        ...

@@ -182,7 +182,19 @@ class PostTraining(Protocol):
         ),
         checkpoint_dir: str | None = None,
         algorithm_config: AlgorithmConfig | None = None,
-    ) -> PostTrainingJob: ...
+    ) -> PostTrainingJob:
+        """Run supervised fine-tuning of a model.
+        :param job_uuid: The UUID of the job to create.
+        :param training_config: The training configuration.
+        :param hyperparam_search_config: The hyperparam search configuration.
+        :param logger_config: The logger configuration.
+        :param model: The model to fine-tune.
+        :param checkpoint_dir: The directory to save checkpoint(s) to.
+        :param algorithm_config: The algorithm configuration.
+        :returns: A PostTrainingJob.
+        """
+        ...

     @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
@@ -193,16 +205,49 @@ class PostTraining(Protocol):
         training_config: TrainingConfig,
         hyperparam_search_config: dict[str, Any],
         logger_config: dict[str, Any],
-    ) -> PostTrainingJob: ...
+    ) -> PostTrainingJob:
+        """Run preference optimization of a model.
+        :param job_uuid: The UUID of the job to create.
+        :param finetuned_model: The model to fine-tune.
+        :param algorithm_config: The algorithm configuration.
+        :param training_config: The training configuration.
+        :param hyperparam_search_config: The hyperparam search configuration.
+        :param logger_config: The logger configuration.
+        :returns: A PostTrainingJob.
+        """
+        ...

     @webmethod(route="/post-training/jobs", method="GET")
-    async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
+    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
+        """Get all training jobs.
+        :returns: A ListPostTrainingJobsResponse.
+        """
+        ...

     @webmethod(route="/post-training/job/status", method="GET")
-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
+    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
+        """Get the status of a training job.
+        :param job_uuid: The UUID of the job to get the status of.
+        :returns: A PostTrainingJobStatusResponse.
+        """
+        ...

     @webmethod(route="/post-training/job/cancel", method="POST")
-    async def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        """Cancel a training job.
+        :param job_uuid: The UUID of the job to cancel.
+        """
+        ...

     @webmethod(route="/post-training/job/artifacts", method="GET")
-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
+    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
+        """Get the artifacts of a training job.
+        :param job_uuid: The UUID of the job to get the artifacts of.
+        :returns: A PostTrainingJobArtifactsResponse.
+        """
+        ...

@@ -32,7 +32,18 @@ class Providers(Protocol):
     """
     @webmethod(route="/providers", method="GET")
-    async def list_providers(self) -> ListProvidersResponse: ...
+    async def list_providers(self) -> ListProvidersResponse:
+        """List all available providers.
+        :returns: A ListProvidersResponse containing information about all providers.
+        """
+        ...

     @webmethod(route="/providers/{provider_id}", method="GET")
-    async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
+    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
+        """Get detailed information about a specific provider.
+        :param provider_id: The ID of the provider to inspect.
+        :returns: A ProviderInfo object containing the provider's details.
+        """
+        ...

@@ -54,4 +54,12 @@ class Safety(Protocol):
         shield_id: str,
         messages: list[Message],
         params: dict[str, Any],
-    ) -> RunShieldResponse: ...
+    ) -> RunShieldResponse:
+        """Run a shield.
+        :param shield_id: The identifier of the shield to run.
+        :param messages: The messages to run the shield on.
+        :param params: The parameters of the shield.
+        :returns: A RunShieldResponse.
+        """
+        ...

@@ -61,7 +61,15 @@ class Scoring(Protocol):
         dataset_id: str,
         scoring_functions: dict[str, ScoringFnParams | None],
         save_results_dataset: bool = False,
-    ) -> ScoreBatchResponse: ...
+    ) -> ScoreBatchResponse:
+        """Score a batch of rows.
+        :param dataset_id: The ID of the dataset to score.
+        :param scoring_functions: The scoring functions to use for the scoring.
+        :param save_results_dataset: Whether to save the results to a dataset.
+        :returns: A ScoreBatchResponse.
+        """
+        ...

     @webmethod(route="/scoring/score", method="POST")
     async def score(
@@ -73,6 +81,6 @@ class Scoring(Protocol):
         :param input_rows: The rows to score.
         :param scoring_functions: The scoring functions to use for the scoring.
-        :return: ScoreResponse object containing rows and aggregated results
+        :returns: A ScoreResponse object containing rows and aggregated results.
         """
         ...

@@ -134,10 +134,21 @@ class ListScoringFunctionsResponse(BaseModel):
 @runtime_checkable
 class ScoringFunctions(Protocol):
     @webmethod(route="/scoring-functions", method="GET")
-    async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
+    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
+        """List all scoring functions.
+        :returns: A ListScoringFunctionsResponse.
+        """
+        ...

     @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
-    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
+    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
+        """Get a scoring function by its ID.
+        :param scoring_fn_id: The ID of the scoring function to get.
+        :returns: A ScoringFn.
+        """
+        ...

     @webmethod(route="/scoring-functions", method="POST")
     async def register_scoring_function(
@@ -148,4 +159,14 @@ class ScoringFunctions(Protocol):
         provider_scoring_fn_id: str | None = None,
         provider_id: str | None = None,
         params: ScoringFnParams | None = None,
-    ) -> None: ...
+    ) -> None:
+        """Register a scoring function.
+        :param scoring_fn_id: The ID of the scoring function to register.
+        :param description: The description of the scoring function.
+        :param return_type: The return type of the scoring function.
+        :param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
+        :param provider_id: The ID of the provider to use for the scoring function.
+        :param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
+        """
+        ...

@@ -46,10 +46,21 @@ class ListShieldsResponse(BaseModel):
 @trace_protocol
 class Shields(Protocol):
     @webmethod(route="/shields", method="GET")
-    async def list_shields(self) -> ListShieldsResponse: ...
+    async def list_shields(self) -> ListShieldsResponse:
+        """List all shields.
+        :returns: A ListShieldsResponse.
+        """
+        ...

     @webmethod(route="/shields/{identifier:path}", method="GET")
-    async def get_shield(self, identifier: str) -> Shield: ...
+    async def get_shield(self, identifier: str) -> Shield:
+        """Get a shield by its identifier.
+        :param identifier: The identifier of the shield to get.
+        :returns: A Shield.
+        """
+        ...

     @webmethod(route="/shields", method="POST")
     async def register_shield(
@@ -58,4 +69,13 @@ class Shields(Protocol):
         provider_shield_id: str | None = None,
         provider_id: str | None = None,
         params: dict[str, Any] | None = None,
-    ) -> Shield: ...
+    ) -> Shield:
+        """Register a shield.
+        :param shield_id: The identifier of the shield to register.
+        :param provider_shield_id: The identifier of the shield in the provider.
+        :param provider_id: The identifier of the provider.
+        :param params: The parameters of the shield.
+        :returns: A Shield.
+        """
+        ...

@@ -247,7 +247,17 @@ class QueryMetricsResponse(BaseModel):
 @runtime_checkable
 class Telemetry(Protocol):
     @webmethod(route="/telemetry/events", method="POST")
-    async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
+    async def log_event(
+        self,
+        event: Event,
+        ttl_seconds: int = DEFAULT_TTL_DAYS * 86400,
+    ) -> None:
+        """Log an event.
+        :param event: The event to log.
+        :param ttl_seconds: The time to live of the event.
+        """
+        ...

     @webmethod(route="/telemetry/traces", method="POST")
     async def query_traces(
@@ -256,13 +266,35 @@ class Telemetry(Protocol):
         limit: int | None = 100,
         offset: int | None = 0,
         order_by: list[str] | None = None,
-    ) -> QueryTracesResponse: ...
+    ) -> QueryTracesResponse:
+        """Query traces.
+        :param attribute_filters: The attribute filters to apply to the traces.
+        :param limit: The limit of traces to return.
+        :param offset: The offset of the traces to return.
+        :param order_by: The order by of the traces to return.
+        :returns: A QueryTracesResponse.
+        """
+        ...

     @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
-    async def get_trace(self, trace_id: str) -> Trace: ...
+    async def get_trace(self, trace_id: str) -> Trace:
+        """Get a trace by its ID.
+        :param trace_id: The ID of the trace to get.
+        :returns: A Trace.
+        """
+        ...

     @webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
-    async def get_span(self, trace_id: str, span_id: str) -> Span: ...
+    async def get_span(self, trace_id: str, span_id: str) -> Span:
+        """Get a span by its ID.
+        :param trace_id: The ID of the trace to get the span from.
+        :param span_id: The ID of the span to get.
+        :returns: A Span.
+        """
+        ...

     @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
     async def get_span_tree(
@@ -270,7 +302,15 @@ class Telemetry(Protocol):
         span_id: str,
         attributes_to_return: list[str] | None = None,
         max_depth: int | None = None,
-    ) -> QuerySpanTreeResponse: ...
+    ) -> QuerySpanTreeResponse:
+        """Get a span tree by its ID.
+        :param span_id: The ID of the span to get the tree from.
+        :param attributes_to_return: The attributes to return in the tree.
+        :param max_depth: The maximum depth of the tree.
+        :returns: A QuerySpanTreeResponse.
+        """
+        ...

     @webmethod(route="/telemetry/spans", method="POST")
     async def query_spans(
@@ -278,7 +318,15 @@ class Telemetry(Protocol):
         attribute_filters: list[QueryCondition],
         attributes_to_return: list[str],
         max_depth: int | None = None,
-    ) -> QuerySpansResponse: ...
+    ) -> QuerySpansResponse:
+        """Query spans.
+        :param attribute_filters: The attribute filters to apply to the spans.
+        :param attributes_to_return: The attributes to return in the spans.
+        :param max_depth: The maximum depth of the tree.
+        :returns: A QuerySpansResponse.
+        """
+        ...

     @webmethod(route="/telemetry/spans/export", method="POST")
     async def save_spans_to_dataset(
@@ -287,7 +335,15 @@ class Telemetry(Protocol):
         attributes_to_save: list[str],
         dataset_id: str,
         max_depth: int | None = None,
-    ) -> None: ...
+    ) -> None:
+        """Save spans to a dataset.
+        :param attribute_filters: The attribute filters to apply to the spans.
+        :param attributes_to_save: The attributes to save to the dataset.
+        :param dataset_id: The ID of the dataset to save the spans to.
+        :param max_depth: The maximum depth of the tree.
+        """
+        ...

     @webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
     async def query_metrics(
@@ -298,4 +354,15 @@ class Telemetry(Protocol):
         granularity: str | None = "1d",
         query_type: MetricQueryType = MetricQueryType.RANGE,
         label_matchers: list[MetricLabelMatcher] | None = None,
-    ) -> QueryMetricsResponse: ...
+    ) -> QueryMetricsResponse:
+        """Query metrics.
+        :param metric_name: The name of the metric to query.
+        :param start_time: The start time of the metric to query.
+        :param end_time: The end time of the metric to query.
+        :param granularity: The granularity of the metric to query.
+        :param query_type: The type of query to perform.
+        :param label_matchers: The label matchers to apply to the metric.
+        :returns: A QueryMetricsResponse.
+        """
+        ...

@@ -103,37 +103,65 @@ class ToolGroups(Protocol):
         mcp_endpoint: URL | None = None,
         args: dict[str, Any] | None = None,
     ) -> None:
-        """Register a tool group"""
+        """Register a tool group.
+        :param toolgroup_id: The ID of the tool group to register.
+        :param provider_id: The ID of the provider to use for the tool group.
+        :param mcp_endpoint: The MCP endpoint to use for the tool group.
+        :param args: A dictionary of arguments to pass to the tool group.
+        """
         ...

     @webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
     async def get_tool_group(
         self,
         toolgroup_id: str,
-    ) -> ToolGroup: ...
+    ) -> ToolGroup:
+        """Get a tool group by its ID.
+        :param toolgroup_id: The ID of the tool group to get.
+        :returns: A ToolGroup.
+        """
+        ...

     @webmethod(route="/toolgroups", method="GET")
     async def list_tool_groups(self) -> ListToolGroupsResponse:
-        """List tool groups with optional provider"""
+        """List tool groups with optional provider.
+        :returns: A ListToolGroupsResponse.
+        """
        ...

     @webmethod(route="/tools", method="GET")
     async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
-        """List tools with optional tool group"""
+        """List tools with optional tool group.
+        :param toolgroup_id: The ID of the tool group to list tools for.
+        :returns: A ListToolsResponse.
+        """
         ...

     @webmethod(route="/tools/{tool_name:path}", method="GET")
     async def get_tool(
         self,
         tool_name: str,
-    ) -> Tool: ...
+    ) -> Tool:
+        """Get a tool by its name.
+        :param tool_name: The name of the tool to get.
+        :returns: A Tool.
+        """
+        ...

     @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
     async def unregister_toolgroup(
         self,
         toolgroup_id: str,
     ) -> None:
-        """Unregister a tool group"""
+        """Unregister a tool group.
+        :param toolgroup_id: The ID of the tool group to unregister.
+        """
         ...
@@ -152,9 +180,21 @@ class ToolRuntime(Protocol):
     @webmethod(route="/tool-runtime/list-tools", method="GET")
     async def list_runtime_tools(
         self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
-    ) -> ListToolDefsResponse: ...
+    ) -> ListToolDefsResponse:
+        """List all tools in the runtime.
+        :param tool_group_id: The ID of the tool group to list tools for.
+        :param mcp_endpoint: The MCP endpoint to use for the tool group.
+        :returns: A ListToolDefsResponse.
+        """
+        ...

     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
-        """Run a tool with the given arguments"""
+        """Run a tool with the given arguments.
+        :param tool_name: The name of the tool to invoke.
+        :param kwargs: A dictionary of arguments to pass to the tool.
+        :returns: A ToolInvocationResult.
+        """
         ...

View file

@ -44,13 +44,24 @@ class ListVectorDBsResponse(BaseModel):
@trace_protocol @trace_protocol
class VectorDBs(Protocol): class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET") @webmethod(route="/vector-dbs", method="GET")
async def list_vector_dbs(self) -> ListVectorDBsResponse: ... async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""List all vector databases.
:returns: A ListVectorDBsResponse.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET") @webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
async def get_vector_db( async def get_vector_db(
self, self,
vector_db_id: str, vector_db_id: str,
) -> VectorDB: ... ) -> VectorDB:
"""Get a vector database by its identifier.
:param vector_db_id: The identifier of the vector database to get.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs", method="POST") @webmethod(route="/vector-dbs", method="POST")
async def register_vector_db( async def register_vector_db(
@ -60,7 +71,22 @@ class VectorDBs(Protocol):
embedding_dimension: int | None = 384, embedding_dimension: int | None = 384,
provider_id: str | None = None, provider_id: str | None = None,
provider_vector_db_id: str | None = None, provider_vector_db_id: str | None = None,
) -> VectorDB: ... ) -> VectorDB:
"""Register a vector database.
:param vector_db_id: The identifier of the vector database to register.
:param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider.
:param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE") @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
async def unregister_vector_db(self, vector_db_id: str) -> None: ... async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database.
:param vector_db_id: The identifier of the vector database to unregister.
"""
...
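
For orientation, a small in-memory sketch of the register/get/unregister flow the docstrings above describe. `VectorDBRecord` and `InMemoryVectorDBs` are hypothetical stand-ins; the 384 default mirrors the signature shown.

```python
# In-memory sketch of the register/get/unregister flow documented above;
# VectorDBRecord and InMemoryVectorDBs are illustrative stand-ins.
from dataclasses import dataclass


@dataclass
class VectorDBRecord:
    identifier: str
    embedding_model: str
    embedding_dimension: int
    provider_id: str | None


class InMemoryVectorDBs:
    def __init__(self) -> None:
        self._dbs: dict[str, VectorDBRecord] = {}

    def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,  # mirrors the default in the signature above
        provider_id: str | None = None,
    ) -> VectorDBRecord:
        record = VectorDBRecord(vector_db_id, embedding_model, embedding_dimension or 384, provider_id)
        self._dbs[vector_db_id] = record
        return record

    def get_vector_db(self, vector_db_id: str) -> VectorDBRecord:
        return self._dbs[vector_db_id]

    def unregister_vector_db(self, vector_db_id: str) -> None:
        self._dbs.pop(vector_db_id, None)


dbs = InMemoryVectorDBs()
dbs.register_vector_db("docs", embedding_model="all-MiniLM-L6-v2")
print(dbs.get_vector_db("docs").embedding_dimension)  # 384
dbs.unregister_vector_db("docs")
```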

View file

@ -46,7 +46,14 @@ class VectorIO(Protocol):
vector_db_id: str, vector_db_id: str,
chunks: list[Chunk], chunks: list[Chunk],
ttl_seconds: int | None = None, ttl_seconds: int | None = None,
) -> None: ... ) -> None:
"""Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into.
:param chunks: The chunks to insert.
:param ttl_seconds: The time to live of the chunks.
"""
...
@webmethod(route="/vector-io/query", method="POST") @webmethod(route="/vector-io/query", method="POST")
async def query_chunks( async def query_chunks(
@ -54,4 +61,12 @@ class VectorIO(Protocol):
vector_db_id: str, vector_db_id: str,
query: InterleavedContent, query: InterleavedContent,
params: dict[str, Any] | None = None, params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ... ) -> QueryChunksResponse:
"""Query chunks from a vector database.
:param vector_db_id: The identifier of the vector database to query.
:param query: The query to search for.
:param params: The parameters of the query.
:returns: A QueryChunksResponse.
"""
...
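
A naive, self-contained sketch of the insert/query semantics documented above. `Chunk`, `QueryChunksResponse`, and the substring "search" are simplified stand-ins; real providers embed and rank chunks rather than matching text.

```python
# Naive in-memory sketch of insert_chunks / query_chunks; Chunk and
# QueryChunksResponse are simplified stand-ins, and the substring "search"
# stands in for the embedding-based retrieval real providers perform.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Chunk:
    content: str
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class QueryChunksResponse:
    chunks: list[Chunk]
    scores: list[float]


class InMemoryVectorIO:
    def __init__(self) -> None:
        self._store: dict[str, list[Chunk]] = {}

    def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
        self._store.setdefault(vector_db_id, []).extend(chunks)

    def query_chunks(self, vector_db_id: str, query: str, params: dict[str, Any] | None = None) -> QueryChunksResponse:
        hits = [c for c in self._store.get(vector_db_id, []) if query.lower() in c.content.lower()]
        return QueryChunksResponse(chunks=hits, scores=[1.0] * len(hits))


io = InMemoryVectorIO()
io.insert_chunks("docs", [Chunk("Llama Stack exposes a vector IO API")])
print(len(io.query_chunks("docs", "vector").chunks))  # 1
```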

View file

@ -36,7 +36,8 @@ from llama_stack.distribution.datatypes import (
) )
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, run_command from llama_stack.distribution.utils.exec import formulate_run_args, run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType from llama_stack.distribution.utils.image_types import LlamaStackImageType
@ -202,7 +203,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
else: else:
with open(args.config) as f: with open(args.config) as f:
try: try:
build_config = BuildConfig(**yaml.safe_load(f)) contents = yaml.safe_load(f)
contents = replace_env_vars(contents)
build_config = BuildConfig(**contents)
except Exception as e: except Exception as e:
cprint( cprint(
f"Could not parse config file {args.config}: {e}", f"Could not parse config file {args.config}: {e}",
@ -248,6 +251,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
run_config = Path(run_config) run_config = Path(run_config)
config_dict = yaml.safe_load(run_config.read_text()) config_dict = yaml.safe_load(run_config.read_text())
config = parse_and_maybe_upgrade_config(config_dict) config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(str(config.external_providers_dir)):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))]) run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
run_command(run_args) run_command(run_args)
@ -267,7 +272,9 @@ def _generate_run_config(
image_name=image_name, image_name=image_name,
apis=apis, apis=apis,
providers={}, providers={},
external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None, external_providers_dir=build_config.external_providers_dir
if build_config.external_providers_dir
else EXTERNAL_PROVIDERS_DIR,
) )
# build providers dict # build providers dict
provider_registry = get_provider_registry(build_config) provider_registry = get_provider_registry(build_config)
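
For context on the `replace_env_vars` call now applied to build configs above, a hedged sketch of the kind of substitution it performs on `${env.NAME:default}` placeholders (the syntax used by the run.yaml files in this diff). The real helper in `llama_stack.distribution.stack` may handle additional forms, such as the `${env.NAME+value}` style also seen in run.yaml.

```python
# Hedged sketch of ${env.NAME:default} substitution, assuming semantics similar
# to replace_env_vars; the real helper in llama_stack.distribution.stack may
# cover more cases than shown here.
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")


def substitute_env_vars(value):
    if isinstance(value, dict):
        return {k: substitute_env_vars(v) for k, v in value.items()}
    if isinstance(value, list):
        return [substitute_env_vars(v) for v in value]
    if isinstance(value, str):
        return _ENV_PATTERN.sub(lambda m: os.environ.get(m.group(1), m.group(2) or ""), value)
    return value


os.environ.pop("SQLITE_STORE_DIR", None)
print(substitute_env_vars({"db_path": "${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db"}))
# {'db_path': '~/.llama/distributions/starter/registry.db'}
```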

View file

@ -33,7 +33,8 @@ class StackRun(Subcommand):
self.parser.add_argument( self.parser.add_argument(
"config", "config",
type=str, type=str,
help="Path to config file to use for the run", nargs="?", # Make it optional
help="Path to config file to use for the run. Required for venv and conda environments.",
) )
self.parser.add_argument( self.parser.add_argument(
"--port", "--port",
@ -82,44 +83,55 @@ class StackRun(Subcommand):
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_command from llama_stack.distribution.utils.exec import formulate_run_args, run_command
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
logger.info(f"Using run configuration: {config_file}")
try:
config_dict = yaml.safe_load(config_file.read_text())
except yaml.parser.ParserError as e:
self.parser.error(f"failed to load config file '{config_file}':\n {e}")
try:
config = parse_and_maybe_upgrade_config(config_dict)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
image_type, image_name = self._get_image_type_and_name(args) image_type, image_name = self._get_image_type_and_name(args)
# Check if config is required based on image type
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not args.config:
self.parser.error("Config file is required for venv and conda environments")
if args.config:
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
logger.info(f"Using run configuration: {config_file}")
try:
config_dict = yaml.safe_load(config_file.read_text())
except yaml.parser.ParserError as e:
self.parser.error(f"failed to load config file '{config_file}':\n {e}")
try:
config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(str(config.external_providers_dir)):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
else:
config = None
config_file = None
template_name = None
# If neither image type nor image name is provided, assume the server should be run directly # If neither image type nor image name is provided, assume the server should be run directly
# using the current environment packages. # using the current environment packages.
if not image_type and not image_name: if not image_type and not image_name:
@ -141,7 +153,10 @@ class StackRun(Subcommand):
else: else:
run_args = formulate_run_args(image_type, image_name, config, template_name) run_args = formulate_run_args(image_type, image_name, config, template_name)
run_args.extend([str(config_file), str(args.port)]) run_args.extend([str(args.port)])
if config_file:
run_args.extend(["--config", str(config_file)])
if args.env: if args.env:
for env_var in args.env: for env_var in args.env:

View file

@ -154,6 +154,12 @@ get_python_cmd() {
fi fi
} }
# Add other required item commands generic to all containers
add_to_container << EOF
# Allows running as non-root user
RUN mkdir -p /.llama/providers.d /.cache
EOF
if [ -n "$run_config" ]; then if [ -n "$run_config" ]; then
# Copy the run config to the build context since it's an absolute path # Copy the run config to the build context since it's an absolute path
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml" cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
@ -166,17 +172,19 @@ EOF
# and update the configuration to reference the new container path # and update the configuration to reference the new container path
python_cmd=$(get_python_cmd) python_cmd=$(get_python_cmd)
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')") external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
if [ -n "$external_providers_dir" ]; then external_providers_dir=$(eval echo "$external_providers_dir")
if [ -n "$external_providers_dir" ] && [ -d "$external_providers_dir" ]; then
echo "Copying external providers directory: $external_providers_dir" echo "Copying external providers directory: $external_providers_dir"
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
add_to_container << EOF add_to_container << EOF
COPY $external_providers_dir /app/providers.d COPY providers.d /.llama/providers.d
EOF EOF
# Edit the run.yaml file to change the external_providers_dir to /app/providers.d # Edit the run.yaml file to change the external_providers_dir to /.llama/providers.d
if [ "$(uname)" = "Darwin" ]; then if [ "$(uname)" = "Darwin" ]; then
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml" sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak" rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
else else
sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml" sed -i 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
fi fi
fi fi
fi fi
@ -255,9 +263,6 @@ fi
# Add other required item commands generic to all containers # Add other required item commands generic to all containers
add_to_container << EOF add_to_container << EOF
# Allows running as non-root user
RUN mkdir -p /.llama /.cache
RUN chmod -R g+rw /app /.llama /.cache RUN chmod -R g+rw /app /.llama /.cache
EOF EOF
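
The sed edit above points `external_providers_dir` in the copied run.yaml at the in-container path. A rough Python equivalent, for readers following the logic; the function name and default path argument are illustrative, and the build script deliberately uses sed to avoid extra dependencies.

```python
# Rough Python equivalent of the sed edit above: rewrite external_providers_dir
# in the copied run.yaml to the in-container providers path.
import yaml


def point_providers_dir_at_container(run_yaml_path: str, container_dir: str = "/.llama/providers.d") -> None:
    with open(run_yaml_path) as f:
        config = yaml.safe_load(f)
    config["external_providers_dir"] = container_dir
    with open(run_yaml_path, "w") as f:
        yaml.safe_dump(config, f, sort_keys=False)
```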

View file

@ -17,6 +17,7 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
get_provider_registry, get_provider_registry,
) )
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api, ProviderSpec
@ -170,4 +171,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
return StackRunConfig(**config_dict) return StackRunConfig(**config_dict)

View file

@ -5,9 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Annotated, Any from typing import Annotated, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
@ -312,11 +313,20 @@ a default SQLite store will be used.""",
description="Configuration for the HTTP(S) server", description="Configuration for the HTTP(S) server",
) )
external_providers_dir: str | None = Field( external_providers_dir: Path | None = Field(
default=None, default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.", description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
) )
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v
class BuildConfig(BaseModel): class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
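
A self-contained sketch of the coercion the new `field_validator` provides: a string value for `external_providers_dir` ends up as a `pathlib.Path`. `RunConfigSketch` is an illustrative stand-in for `StackRunConfig`, assuming pydantic v2.

```python
# Self-contained sketch of the Path coercion above; RunConfigSketch is an
# illustrative stand-in for StackRunConfig (pydantic v2 assumed).
from pathlib import Path

from pydantic import BaseModel, field_validator


class RunConfigSketch(BaseModel):
    external_providers_dir: Path | None = None

    @field_validator("external_providers_dir")
    @classmethod
    def validate_external_providers_dir(cls, v):
        if v is None:
            return None
        if isinstance(v, str):
            return Path(v)
        return v


cfg = RunConfigSketch(external_providers_dir="~/.llama/providers.d")
print(isinstance(cfg.external_providers_dir, Path))  # True
```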

View file

@ -145,7 +145,7 @@ def get_provider_registry(
# Check if config has the external_providers_dir attribute # Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir: if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
external_providers_dir = os.path.abspath(config.external_providers_dir) external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir): if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}") raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}") logger.info(f"Loading external providers from {external_providers_dir}")

View file

@ -29,7 +29,7 @@ error_handler() {
trap 'error_handler ${LINENO}' ERR trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then if [ $# -lt 3 ]; then
echo "Usage: $0 <env_type> <env_path_or_name> <yaml_config> <port> <script_args...>" echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
exit 1 exit 1
fi fi
@ -40,37 +40,51 @@ env_path_or_name="$1"
container_image="localhost/$env_path_or_name" container_image="localhost/$env_path_or_name"
shift shift
yaml_config="$1"
shift
port="$1" port="$1"
shift shift
SCRIPT_DIR=$(dirname "$(readlink -f "$0")") SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh" source "$SCRIPT_DIR/common.sh"
# Initialize env_vars as an string # Initialize variables
yaml_config=""
env_vars="" env_vars=""
other_args="" other_args=""
# Process environment variables from --env arguments
# Process remaining arguments
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
case "$1" in case "$1" in
--env) --config|--yaml-config)
if [[ -n "$2" ]]; then
if [[ -n "$2" ]]; then yaml_config="$2"
env_vars="$env_vars --env $2" shift 2
shift 2 else
else echo -e "${RED}Error: $1 requires a CONFIG argument${NC}" >&2
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2 exit 1
exit 1 fi
fi ;;
;; --env)
*) if [[ -n "$2" ]]; then
other_args="$other_args $1" env_vars="$env_vars --env $2"
shift shift 2
;; else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac esac
done done
# Check if yaml_config is required based on env_type
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]] && [ -z "$yaml_config" ]; then
echo -e "${RED}Error: --config is required for venv and conda environments${NC}" >&2
exit 1
fi
PYTHON_BINARY="python" PYTHON_BINARY="python"
case "$env_type" in case "$env_type" in
"venv") "venv")
@ -106,8 +120,14 @@ esac
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x set -x
if [ -n "$yaml_config" ]; then
yaml_config_arg="--yaml-config $yaml_config"
else
yaml_config_arg=""
fi
$PYTHON_BINARY -m llama_stack.distribution.server.server \ $PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \ $yaml_config_arg \
--port "$port" \ --port "$port" \
$env_vars \ $env_vars \
$other_args $other_args
@ -149,15 +169,26 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version') version_tag=$(curl -s $URL | jq -r '.info.version')
fi fi
$CONTAINER_BINARY run $CONTAINER_OPTS -it \ # Build the command with optional yaml config
cmd="$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \ -p $port:$port \
$env_vars \ $env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \ $mounts \
--env LLAMA_STACK_PORT=$port \ --env LLAMA_STACK_PORT=$port \
--entrypoint python \ --entrypoint python \
$container_image:$version_tag \ $container_image:$version_tag \
-m llama_stack.distribution.server.server \ -m llama_stack.distribution.server.server"
--yaml-config /app/config.yaml \
$other_args # Add yaml config if provided, otherwise use default
if [ -n "$yaml_config" ]; then
cmd="$cmd -v $yaml_config:/app/run.yaml --yaml-config /app/run.yaml"
else
cmd="$cmd --yaml-config /app/run.yaml"
fi
# Add any other args
cmd="$cmd $other_args"
# Execute the command
eval $cmd
fi fi

View file

@ -14,3 +14,5 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints" DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime" RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"

View file

@ -22,8 +22,10 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list: def formulate_run_args(image_type, image_name, config, template_name) -> list:
env_name = "" env_name = ""
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image: if image_type == LlamaStackImageType.CONTAINER.value:
env_name = f"distribution-{template_name}" if template_name else config.container_image env_name = (
f"distribution-{template_name}" if template_name else (config.container_image if config else image_name)
)
elif image_type == LlamaStackImageType.CONDA.value: elif image_type == LlamaStackImageType.CONDA.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV") current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
env_name = image_name or current_conda_env env_name = image_name or current_conda_env

View file

@ -162,7 +162,7 @@ def _process_vllm_chat_completion_end_of_stream(
finish_reason: str | None, finish_reason: str | None,
last_chunk_content: str | None, last_chunk_content: str | None,
current_event_type: ChatCompletionResponseEventType, current_event_type: ChatCompletionResponseEventType,
tool_call_buf: UnparseableToolCall, tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
) -> list[OpenAIChatCompletionChunk]: ) -> list[OpenAIChatCompletionChunk]:
chunks = [] chunks = []
@ -171,9 +171,8 @@ def _process_vllm_chat_completion_end_of_stream(
else: else:
stop_reason = StopReason.end_of_message stop_reason = StopReason.end_of_message
if tool_call_buf.tool_name: tool_call_bufs = tool_call_bufs or {}
# at least one tool call request is received for _index, tool_call_buf in sorted(tool_call_bufs.items()):
args_str = tool_call_buf.arguments or "{}" args_str = tool_call_buf.arguments or "{}"
try: try:
args = json.loads(args_str) args = json.loads(args_str)
@ -225,8 +224,14 @@ def _process_vllm_chat_completion_end_of_stream(
async def _process_vllm_chat_completion_stream_response( async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None], stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator: ) -> AsyncGenerator:
event_type = ChatCompletionResponseEventType.start yield ChatCompletionResponseStreamChunk(
tool_call_buf = UnparseableToolCall() event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
event_type = ChatCompletionResponseEventType.progress
tool_call_bufs: dict[str, UnparseableToolCall] = {}
end_of_stream_processed = False end_of_stream_processed = False
async for chunk in stream: async for chunk in stream:
@ -235,17 +240,22 @@ async def _process_vllm_chat_completion_stream_response(
return return
choice = chunk.choices[0] choice = chunk.choices[0]
if choice.delta.tool_calls: if choice.delta.tool_calls:
tool_call = convert_tool_call(choice.delta.tool_calls[0]) for delta_tool_call in choice.delta.tool_calls:
tool_call_buf.tool_name += str(tool_call.tool_name) tool_call = convert_tool_call(delta_tool_call)
tool_call_buf.call_id += tool_call.call_id if delta_tool_call.index not in tool_call_bufs:
# TODO: remove str() when dict type for 'arguments' is no longer allowed tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
tool_call_buf.arguments += str(tool_call.arguments) tool_call_buf = tool_call_bufs[delta_tool_call.index]
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
tool_call_buf.arguments += (
tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
)
if choice.finish_reason: if choice.finish_reason:
chunks = _process_vllm_chat_completion_end_of_stream( chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=choice.finish_reason, finish_reason=choice.finish_reason,
last_chunk_content=choice.delta.content, last_chunk_content=choice.delta.content,
current_event_type=event_type, current_event_type=event_type,
tool_call_buf=tool_call_buf, tool_call_bufs=tool_call_bufs,
) )
for c in chunks: for c in chunks:
yield c yield c
@ -266,7 +276,7 @@ async def _process_vllm_chat_completion_stream_response(
# the stream ended without a chunk containing finish_reason - we have to generate the # the stream ended without a chunk containing finish_reason - we have to generate the
# respective completion chunks manually # respective completion chunks manually
chunks = _process_vllm_chat_completion_end_of_stream( chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_buf=tool_call_buf finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
) )
for c in chunks: for c in chunks:
yield c yield c
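
A standalone sketch of the per-index buffering introduced above: streamed tool-call deltas are accumulated per `index`, so parallel tool calls no longer collapse into a single buffer. `ToolCallDelta` and `ToolCallBuffer` are simplified stand-ins for the OpenAI delta models and `UnparseableToolCall`.

```python
# Standalone sketch of the per-index buffering above; ToolCallDelta and
# ToolCallBuffer are simplified stand-ins for the OpenAI delta models and
# UnparseableToolCall.
from dataclasses import dataclass


@dataclass
class ToolCallDelta:
    index: int
    call_id: str = ""
    name: str = ""
    arguments: str = ""


@dataclass
class ToolCallBuffer:
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""


def accumulate(deltas: list[ToolCallDelta]) -> dict[int, ToolCallBuffer]:
    bufs: dict[int, ToolCallBuffer] = {}
    for delta in deltas:
        buf = bufs.setdefault(delta.index, ToolCallBuffer())
        buf.call_id += delta.call_id
        buf.tool_name += delta.name
        buf.arguments += delta.arguments
    return bufs


stream = [
    ToolCallDelta(index=1, call_id="tc_1", name="power"),
    ToolCallDelta(index=2, call_id="tc_2", name="multiple", arguments='{"first_number": 4'),
    ToolCallDelta(index=1, arguments='{"number": 28, "power": 3}'),
    ToolCallDelta(index=2, arguments=', "second_number": 7}'),
]
bufs = accumulate(stream)
print(bufs[1].arguments)  # {"number": 28, "power": 3}
print(bufs[2].arguments)  # {"first_number": 4, "second_number": 7}
```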

View file

@ -531,13 +531,19 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
tool_name = tc.tool_name tool_name = tc.tool_name
if isinstance(tool_name, BuiltinTool): if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value tool_name = tool_name.value
# arguments_json can be None, so attempt it first and fall back to arguments
if hasattr(tc, "arguments_json") and tc.arguments_json:
arguments = tc.arguments_json
else:
arguments = json.dumps(tc.arguments)
result["tool_calls"].append( result["tool_calls"].append(
{ {
"id": tc.call_id, "id": tc.call_id,
"type": "function", "type": "function",
"function": { "function": {
"name": tool_name, "name": tool_name,
"arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments), "arguments": arguments,
}, },
} }
) )
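
A minimal sketch of the fallback above: `arguments_json` may be present but `None`, in which case the dict-form `arguments` are serialized instead. `FakeToolCall` is an illustrative stand-in for the real tool-call model.

```python
# Minimal sketch of the arguments_json fallback above; FakeToolCall is an
# illustrative stand-in for the real tool-call model.
import json
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeToolCall:
    call_id: str
    tool_name: str
    arguments: dict[str, Any]
    arguments_json: str | None = None


def arguments_for_openai(tc: FakeToolCall) -> str:
    # try arguments_json first, falling back to serializing the dict form
    if getattr(tc, "arguments_json", None):
        return tc.arguments_json
    return json.dumps(tc.arguments)


print(arguments_for_openai(FakeToolCall("tc_1", "power", {"number": 28})))  # {"number": 28}
```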

View file

@ -152,46 +152,6 @@
"sentence-transformers --no-deps", "sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu" "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
], ],
"dev": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"fireworks": [ "fireworks": [
"aiosqlite", "aiosqlite",
"autoevals", "autoevals",
@ -642,6 +602,46 @@
"sentence-transformers --no-deps", "sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu" "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
], ],
"starter": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"tgi": [ "tgi": [
"aiohttp", "aiohttp",
"aiosqlite", "aiosqlite",

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .dev import get_distribution_template # noqa: F401 from .starter import get_distribution_template # noqa: F401

View file

@ -1,6 +1,6 @@
version: '2' version: '2'
distribution_spec: distribution_spec:
description: Distribution for running e2e tests in CI description: Quick start template for running Llama Stack with several popular providers
providers: providers:
inference: inference:
- remote::openai - remote::openai

View file

@ -1,5 +1,5 @@
version: '2' version: '2'
image_name: dev image_name: starter
apis: apis:
- agents - agents
- datasetio - datasetio
@ -46,7 +46,7 @@ providers:
- provider_id: sqlite-vec - provider_id: sqlite-vec
provider_type: inline::sqlite-vec provider_type: inline::sqlite-vec
config: config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/sqlite_vec.db
- provider_id: ${env.ENABLE_CHROMADB+chromadb} - provider_id: ${env.ENABLE_CHROMADB+chromadb}
provider_type: remote::chromadb provider_type: remote::chromadb
config: config:
@ -71,14 +71,14 @@ providers:
persistence_store: persistence_store:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/agents_store.db
telemetry: telemetry:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
config: config:
service_name: ${env.OTEL_SERVICE_NAME:} service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite} sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/trace_store.db sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/trace_store.db
eval: eval:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -86,7 +86,7 @@ providers:
kvstore: kvstore:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/meta_reference_eval.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/meta_reference_eval.db
datasetio: datasetio:
- provider_id: huggingface - provider_id: huggingface
provider_type: remote::huggingface provider_type: remote::huggingface
@ -94,14 +94,14 @@ providers:
kvstore: kvstore:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/huggingface_datasetio.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/huggingface_datasetio.db
- provider_id: localfs - provider_id: localfs
provider_type: inline::localfs provider_type: inline::localfs
config: config:
kvstore: kvstore:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/localfs_datasetio.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/localfs_datasetio.db
scoring: scoring:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
@ -132,7 +132,7 @@ providers:
config: {} config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db
models: models:
- metadata: {} - metadata: {}
model_id: openai/gpt-4o model_id: openai/gpt-4o

View file

@ -46,6 +46,7 @@ from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOC
from llama_stack.providers.remote.vector_io.pgvector.config import ( from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig, PGVectorVectorIOConfig,
) )
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.templates.template import ( from llama_stack.templates.template import (
DistributionTemplate, DistributionTemplate,
RunConfigSettings, RunConfigSettings,
@ -53,7 +54,7 @@ from llama_stack.templates.template import (
) )
def get_inference_providers() -> tuple[list[Provider], list[ModelInput]]: def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
# in this template, we allow each API key to be optional # in this template, we allow each API key to be optional
providers = [ providers = [
( (
@ -119,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
"remote::model-context-protocol", "remote::model-context-protocol",
], ],
} }
name = "dev" name = "starter"
vector_io_providers = [ vector_io_providers = [
Provider( Provider(
@ -171,7 +172,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate( return DistributionTemplate(
name=name, name=name,
distro_type="self_hosted", distro_type="self_hosted",
description="Distribution for running e2e tests in CI", description="Quick start template for running Llama Stack with several popular providers",
container_image=None, container_image=None,
template_path=None, template_path=None,
providers=providers, providers=providers,

View file

@ -5,13 +5,13 @@ We use shadcn/ui [Shadcn UI](https://ui.shadcn.com/) for the UI components.
## Getting Started ## Getting Started
## Install the NPM dependencies First, install dependencies:
```bash ```bash
npm install npm install next react react-dom
``` ```
Run the development server: Then, run the development server:
```bash ```bash
npm run dev npm run dev

View file

@ -45,10 +45,10 @@ export function AppSidebar() {
{logItems.map((item) => ( {logItems.map((item) => (
<SidebarMenuItem key={item.title}> <SidebarMenuItem key={item.title}>
<SidebarMenuButton asChild> <SidebarMenuButton asChild>
<a href={item.url}> <Link href={item.url}>
<item.icon /> <item.icon />
<span>{item.title}</span> <span>{item.title}</span>
</a> </Link>
</SidebarMenuButton> </SidebarMenuButton>
</SidebarMenuItem> </SidebarMenuItem>
))} ))}

View file

@ -304,7 +304,6 @@ exclude = [
"^llama_stack/strong_typing/inspection\\.py$", "^llama_stack/strong_typing/inspection\\.py$",
"^llama_stack/strong_typing/schema\\.py$", "^llama_stack/strong_typing/schema\\.py$",
"^llama_stack/strong_typing/serializer\\.py$", "^llama_stack/strong_typing/serializer\\.py$",
"^llama_stack/templates/dev/dev\\.py$",
"^llama_stack/templates/groq/groq\\.py$", "^llama_stack/templates/groq/groq\\.py$",
"^llama_stack/templates/llama_api/llama_api\\.py$", "^llama_stack/templates/llama_api/llama_api\\.py$",
"^llama_stack/templates/sambanova/sambanova\\.py$", "^llama_stack/templates/sambanova/sambanova\\.py$",

View file

@ -6,4 +6,4 @@ distribution_spec:
- remote::custom_ollama - remote::custom_ollama
image_type: container image_type: container
image_name: ci-test image_name: ci-test
external_providers_dir: /tmp/providers.d external_providers_dir: ~/.llama/providers.d

View file

@ -91,4 +91,4 @@ tool_groups:
provider_id: wolfram-alpha provider_id: wolfram-alpha
server: server:
port: 8321 port: 8321
external_providers_dir: /tmp/providers.d external_providers_dir: ~/.llama/providers.d

View file

@ -266,6 +266,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
assert found_tool_execution assert found_tool_execution
@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
def test_builtin_tool_code_execution(llama_stack_client, agent_config): def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = { agent_config = {
**agent_config, **agent_config,
@ -346,7 +347,7 @@ def test_custom_tool(llama_stack_client, agent_config):
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": "What is the boiling point of polyjuice?", "content": "What is the boiling point of the liquid polyjuice in celsius?",
}, },
], ],
session_id=session_id, session_id=session_id,
@ -420,7 +421,7 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": "What is the boiling point of polyjuice?", "content": "What is the boiling point of the liquid polyjuice in celsius?",
}, },
], ],
session_id=session_id, session_id=session_id,
@ -674,8 +675,8 @@ def test_create_turn_response(llama_stack_client, agent_config, client_tools):
def test_multi_tool_calls(llama_stack_client, agent_config): def test_multi_tool_calls(llama_stack_client, agent_config):
if "gpt" not in agent_config["model"]: if "gpt" not in agent_config["model"] and "llama-4" not in agent_config["model"].lower():
pytest.xfail("Only tested on GPT models") pytest.xfail("Only tested on GPT and Llama 4 models")
agent_config = { agent_config = {
**agent_config, **agent_config,
@ -689,23 +690,34 @@ def test_multi_tool_calls(llama_stack_client, agent_config):
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?", "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?.\nUse the tool responses to answer the question.",
}, },
], ],
session_id=session_id, session_id=session_id,
stream=False, stream=False,
) )
steps = response.steps steps = response.steps
assert len(steps) == 7
assert steps[0].step_type == "shield_call"
assert steps[1].step_type == "inference"
assert steps[2].step_type == "shield_call"
assert steps[3].step_type == "tool_execution"
assert steps[4].step_type == "shield_call"
assert steps[5].step_type == "inference"
assert steps[6].step_type == "shield_call"
tool_execution_step = steps[3] has_input_shield = agent_config.get("input_shields")
has_output_shield = agent_config.get("output_shields")
assert len(steps) == 3 + (2 if has_input_shield else 0) + (2 if has_output_shield else 0)
if has_input_shield:
assert steps[0].step_type == "shield_call"
steps.pop(0)
assert steps[0].step_type == "inference"
if has_output_shield:
assert steps[1].step_type == "shield_call"
steps.pop(1)
assert steps[1].step_type == "tool_execution"
tool_execution_step = steps[1]
if has_input_shield:
assert steps[2].step_type == "shield_call"
steps.pop(2)
assert steps[2].step_type == "inference"
if has_output_shield:
assert steps[3].step_type == "shield_call"
steps.pop(3)
assert len(tool_execution_step.tool_calls) == 2 assert len(tool_execution_step.tool_calls) == 2
assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point") assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point")
assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point") assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point")
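
For readers tracing the updated assertions, a hedged sketch of the step layout they encode: a two-tool-call turn yields inference, tool_execution, inference, with a shield_call wrapped around each inference step when input/output shields are configured.

```python
# Hedged sketch of the step layout the assertions above encode; the helper name
# is illustrative, not part of the test suite.
def expected_step_types(has_input_shield: bool, has_output_shield: bool) -> list[str]:
    def wrap(step: str) -> list[str]:
        steps = ["shield_call"] if has_input_shield else []
        steps.append(step)
        if has_output_shield:
            steps.append("shield_call")
        return steps

    return wrap("inference") + ["tool_execution"] + wrap("inference")


print(expected_step_types(True, True))
# ['shield_call', 'inference', 'shield_call', 'tool_execution', 'shield_call', 'inference', 'shield_call']
print(len(expected_step_types(False, False)))  # 3, matching 3 + 2*input + 2*output
```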

View file

@ -24,6 +24,12 @@ from openai.types.chat.chat_completion_chunk import (
from openai.types.chat.chat_completion_chunk import ( from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta, ChoiceDelta as OpenAIChoiceDelta,
) )
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
)
from openai.types.model import Model as OpenAIModel from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@ -206,8 +212,164 @@ async def test_tool_call_delta_empty_tool_call_buf():
yield chunk yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())] chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 1 assert len(chunks) == 2
assert chunks[0].event.stop_reason == StopReason.end_of_turn assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "complete"
assert chunks[1].event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio
async def test_tool_call_delta_streaming_arguments_dict():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments="",
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "complete"
@pytest.mark.asyncio
async def test_multiple_tool_calls():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=2,
function=OpenAIChoiceDeltaToolCallFunction(
name="multiple",
arguments='{"first_number": 4, "second_number": 7}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 4
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "progress"
assert chunks[2].event.delta.type == "tool_call"
assert chunks[2].event.delta.parse_status.value == "succeeded"
assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
assert chunks[3].event.event_type.value == "complete"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -231,7 +393,8 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
yield chunk yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())] chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 0 assert len(chunks) == 1
assert chunks[0].event.event_type.value == "start"
def test_chat_completion_doesnt_block_event_loop(caplog): def test_chat_completion_doesnt_block_event_loop(caplog):
@ -369,7 +532,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
yield chunk yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())] chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2 assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call" assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@ -422,7 +585,7 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
yield chunk yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())] chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2 assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call" assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@ -471,7 +634,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
yield chunk yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())] chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2 assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call" assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name