Merge branch 'main' into small-ui-patches

Francisco Arceo 2025-05-15 14:10:12 -06:00 committed by GitHub
commit e9cce9ed38
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
54 changed files with 1825 additions and 760 deletions

.github/CODEOWNERS (2 changes)

@ -2,4 +2,4 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan @SLR722 @leseb
* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning

.github/TRIAGERS.md (2 changes)

@ -1,2 +1,2 @@
# This file documents Triage members in the Llama Stack community
@franciscojavierarceo @leseb
@bbrowning @booxter @franciscojavierarceo @leseb


@ -47,8 +47,8 @@ jobs:
- name: Create provider configuration
run: |
mkdir -p /tmp/providers.d/remote/inference
cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
mkdir -p /home/runner/.llama/providers.d/remote/inference
cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml
- name: Build distro from config file
run: |
@ -66,7 +66,7 @@ jobs:
- name: Wait for Llama Stack server to be ready
run: |
for i in {1..30}; do
if ! grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
if ! grep -q "remote::custom_ollama from /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml" server.log; then
echo "Waiting for Llama Stack server to load the provider..."
sleep 1
else


@ -14,6 +14,8 @@ on:
- 'docs/**'
- 'pyproject.toml'
- '.github/workflows/update-readthedocs.yml'
tags:
- '*'
pull_request:
branches:
- main
@ -61,7 +63,10 @@ jobs:
response=$(curl -X POST \
-H "Content-Type: application/json" \
-d "{\"token\": \"$TOKEN\"}" \
-d "{
\"token\": \"$TOKEN\",
\"version\": \"$GITHUB_REF_NAME\"
}" \
https://readthedocs.org/api/v2/webhook/llama-stack/289768/)
echo "Response: $response"

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -179,6 +179,35 @@ def _validate_has_ellipsis(method) -> str | None:
if "..." not in source and not "NotImplementedError" in source:
return "does not contain ellipsis (...) in its implementation"
def _validate_has_return_in_docstring(method) -> str | None:
source = inspect.getsource(method)
return_type = method.__annotations__.get('return')
if return_type is not None and return_type != type(None) and ":returns:" not in source:
return "does not have a ':returns:' in its docstring"
def _validate_has_params_in_docstring(method) -> str | None:
source = inspect.getsource(method)
sig = inspect.signature(method)
# Only check if the method has more than one parameter
if len(sig.parameters) > 1 and ":param" not in source:
return "does not have a ':param' in its docstring"
def _validate_has_no_return_none_in_docstring(method) -> str | None:
source = inspect.getsource(method)
return_type = method.__annotations__.get('return')
if return_type is None and ":returns: None" in source:
return "has a ':returns: None' in its docstring which is redundant for None-returning functions"
def _validate_docstring_lines_end_with_dot(method) -> str | None:
docstring = inspect.getdoc(method)
if docstring is None:
return None
lines = docstring.split('\n')
for line in lines:
line = line.strip()
if line and not any(line.endswith(char) for char in '.:{}[]()",'):
return f"docstring line '{line}' does not end with a valid character: . : {{ }} [ ] ( ) , \""
_VALIDATORS = {
"GET": [
@ -186,13 +215,23 @@ _VALIDATORS = {
_validate_list_parameters_contain_data,
_validate_api_method_doesnt_return_list,
_validate_has_ellipsis,
_validate_has_return_in_docstring,
_validate_has_params_in_docstring,
_validate_docstring_lines_end_with_dot,
],
"DELETE": [
_validate_api_delete_method_returns_none,
_validate_has_ellipsis,
_validate_has_return_in_docstring,
_validate_has_params_in_docstring,
_validate_has_no_return_none_in_docstring
],
"POST": [
_validate_has_ellipsis,
_validate_has_return_in_docstring,
_validate_has_params_in_docstring,
_validate_has_no_return_none_in_docstring,
_validate_docstring_lines_end_with_dot,
],
}
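To make the new requirements concrete, here is a minimal sketch of a method docstring that passes the GET-method checks. The `Examples` protocol below is hypothetical (it is not part of this change), and the assertions assume they run in the same module as the validators above.

```python
from typing import Protocol


class Examples(Protocol):
    async def get_example(self, example_id: str) -> str:
        """Get an example by its ID.

        :param example_id: The ID of the example to get.
        :returns: The example, serialized as a string.
        """
        ...


# Each validator returns None when the method is compliant, otherwise an error string.
for check in (
    _validate_has_ellipsis,
    _validate_has_return_in_docstring,
    _validate_has_params_in_docstring,
    _validate_docstring_lines_end_with_dot,
):
    assert check(Examples.get_example) is None
```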


@ -110,6 +110,8 @@ html_theme_options = {
"canonical_url": "https://github.com/meta-llama/llama-stack",
"collapse_navigation": False,
# "style_nav_header_background": "#c3c9d4",
'display_version': True,
'version_selector': True,
}
default_dark_mode = False


@ -178,7 +178,7 @@ image_name: ollama
image_type: conda
# If some providers are external, you can specify the path to the implementation
external_providers_dir: /etc/llama-stack/providers.d
external_providers_dir: ~/.llama/providers.d
```
```
@ -206,7 +206,7 @@ distribution_spec:
image_type: container
image_name: ci-test
# Path to external provider implementations
external_providers_dir: /etc/llama-stack/providers.d
external_providers_dir: ~/.llama/providers.d
```
Here's an example for a custom Ollama provider:


@ -10,7 +10,7 @@ Llama Stack supports external providers that live outside of the main codebase.
To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:
```yaml
external_providers_dir: /etc/llama-stack/providers.d/
external_providers_dir: ~/.llama/providers.d/
```
## Directory Structure
@ -182,7 +182,7 @@ dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
3. Create the provider specification:
```yaml
# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
# ~/.llama/providers.d/remote/inference/custom_ollama.yaml
adapter:
adapter_type: custom_ollama
pip_packages: ["ollama", "aiohttp"]
@ -201,7 +201,7 @@ uv pip install -e .
5. Configure Llama Stack to use external providers:
```yaml
external_providers_dir: /etc/llama-stack/providers.d/
external_providers_dir: ~/.llama/providers.d/
```
The provider will now be available in Llama Stack with the type `remote::custom_ollama`.


@ -38,6 +38,67 @@ wait_for_service() {
return 0
}
usage() {
cat << EOF
📚 Llama-Stack Deployment Script
Description:
This script sets up and deploys Llama-Stack with Ollama integration in containers.
It handles both Docker and Podman runtimes and includes automatic platform detection.
Usage:
$(basename "$0") [OPTIONS]
Options:
-p, --port PORT Server port for Llama-Stack (default: ${PORT})
-o, --ollama-port PORT Ollama service port (default: ${OLLAMA_PORT})
-m, --model MODEL Model alias to use (default: ${MODEL_ALIAS})
-i, --image IMAGE Server image (default: ${SERVER_IMAGE})
-t, --timeout SECONDS Service wait timeout in seconds (default: ${WAIT_TIMEOUT})
-h, --help Show this help message
For more information:
Documentation: https://llama-stack.readthedocs.io/
GitHub: https://github.com/meta-llama/llama-stack
Report issues:
https://github.com/meta-llama/llama-stack/issues
EOF
}
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
-h|--help)
usage
exit 0
;;
-p|--port)
PORT="$2"
shift 2
;;
-o|--ollama-port)
OLLAMA_PORT="$2"
shift 2
;;
-m|--model)
MODEL_ALIAS="$2"
shift 2
;;
-i|--image)
SERVER_IMAGE="$2"
shift 2
;;
-t|--timeout)
WAIT_TIMEOUT="$2"
shift 2
;;
*)
die "Unknown option: $1"
;;
esac
done
if command -v docker &> /dev/null; then
ENGINE="docker"
elif command -v podman &> /dev/null; then


@ -413,7 +413,7 @@ class Agents(Protocol):
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
"""
...
@ -509,6 +509,7 @@ class Agents(Protocol):
:param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by.
:returns: A Session.
"""
...
@ -606,5 +607,6 @@ class Agents(Protocol):
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:returns: An OpenAIResponseObject.
"""
...


@ -38,7 +38,17 @@ class BatchInference(Protocol):
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> Job: ...
) -> Job:
"""Generate completions for a batch of content.
:param model: The model to use for the completion.
:param content_batch: The content to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param response_format: The response format to use for the completion.
:param logprobs: The logprobs to use for the completion.
:returns: A job for the completion.
"""
...
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def chat_completion(
@ -52,4 +62,17 @@ class BatchInference(Protocol):
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> Job: ...
) -> Job:
"""Generate chat completions for a batch of messages.
:param model: The model to use for the chat completion.
:param messages_batch: The messages to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param tools: The tools to use for the chat completion.
:param tool_choice: The tool choice to use for the chat completion.
:param tool_prompt_format: The tool prompt format to use for the chat completion.
:param response_format: The response format to use for the chat completion.
:param logprobs: The logprobs to use for the chat completion.
:returns: A job for the chat completion.
"""
...


@ -46,13 +46,24 @@ class ListBenchmarksResponse(BaseModel):
@runtime_checkable
class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET")
async def list_benchmarks(self) -> ListBenchmarksResponse: ...
async def list_benchmarks(self) -> ListBenchmarksResponse:
"""List all benchmarks.
:returns: A ListBenchmarksResponse.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
async def get_benchmark(
self,
benchmark_id: str,
) -> Benchmark: ...
) -> Benchmark:
"""Get a benchmark by its ID.
:param benchmark_id: The ID of the benchmark to get.
:returns: A Benchmark.
"""
...
@webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark(
@ -63,4 +74,14 @@ class Benchmarks(Protocol):
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None: ...
) -> None:
"""Register a benchmark.
:param benchmark_id: The ID of the benchmark to register.
:param dataset_id: The ID of the dataset to use for the benchmark.
:param scoring_functions: The scoring functions to use for the benchmark.
:param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
:param provider_id: The ID of the provider to use for the benchmark.
:param metadata: The metadata to use for the benchmark.
"""
...


@ -34,14 +34,21 @@ class DatasetIO(Protocol):
- limit: Number of items to return. If None or -1, returns all items.
The response includes:
- data: List of items for the current page
- has_more: Whether there are more items available after this set
- data: List of items for the current page.
- has_more: Whether there are more items available after this set.
:param dataset_id: The ID of the dataset to get the rows from.
:param start_index: Index into dataset for the first row to get. Get all rows if None.
:param limit: The number of rows to get.
:returns: A PaginatedResponse.
"""
...
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
"""Append rows to a dataset.
:param dataset_id: The ID of the dataset to append the rows to.
:param rows: The rows to append to the dataset.
"""
...
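The pagination contract spelled out above (a `data` page plus a `has_more` flag) lends itself to a simple client loop. A minimal sketch, assuming an object `datasetio` that implements this protocol and that the paginated method documented above is named `iterrows` (the method name is not visible in this hunk):

```python
from typing import Any


async def fetch_all_rows(datasetio, dataset_id: str, page_size: int = 100) -> list[dict[str, Any]]:
    """Collect every row of a dataset by walking its offset-based pages."""
    rows: list[dict[str, Any]] = []
    start_index = 0
    while True:
        page = await datasetio.iterrows(dataset_id=dataset_id, start_index=start_index, limit=page_size)
        rows.extend(page.data)
        # The server reports has_more; stop once there is nothing left to fetch.
        if not page.has_more:
            return rows
        start_index += len(page.data)
```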


@ -137,7 +137,8 @@ class Datasets(Protocol):
"""
Register a new dataset.
:param purpose: The purpose of the dataset. One of
:param purpose: The purpose of the dataset.
One of:
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
{
"messages": [
@ -188,8 +189,9 @@ class Datasets(Protocol):
]
}
:param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"}
- E.g. {"description": "My dataset"}.
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
:returns: A Dataset.
"""
...
@ -197,13 +199,29 @@ class Datasets(Protocol):
async def get_dataset(
self,
dataset_id: str,
) -> Dataset: ...
) -> Dataset:
"""Get a dataset by its ID.
:param dataset_id: The ID of the dataset to get.
:returns: A Dataset.
"""
...
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...
async def list_datasets(self) -> ListDatasetsResponse:
"""List all datasets.
:returns: A ListDatasetsResponse.
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
async def unregister_dataset(
self,
dataset_id: str,
) -> None: ...
) -> None:
"""Unregister a dataset by its ID.
:param dataset_id: The ID of the dataset to unregister.
"""
...


@ -93,7 +93,7 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
:returns: The job that was created to run the evaluation.
"""
...
@ -111,7 +111,7 @@ class Eval(Protocol):
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
:returns: EvaluateResponse object containing generations and scores.
"""
...
@ -121,7 +121,7 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:return: The status of the evaluationjob.
:returns: The status of the evaluation job.
"""
...
@ -140,6 +140,6 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
:returns: The result of the job.
"""
...


@ -91,10 +91,11 @@ class Files(Protocol):
"""
Create a new upload session for a file identified by a bucket and key.
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
:param mime_type: MIME type of the file
:param size: File size in bytes
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
:param mime_type: MIME type of the file.
:param size: File size in bytes.
:returns: A FileUploadResponse.
"""
...
@ -107,7 +108,8 @@ class Files(Protocol):
Upload file content to an existing upload session.
On the server, request body will have the raw bytes that are uploaded.
:param upload_id: ID of the upload session
:param upload_id: ID of the upload session.
:returns: A FileResponse or None if the upload is not complete.
"""
...
@ -117,9 +119,10 @@ class Files(Protocol):
upload_id: str,
) -> FileUploadResponse:
"""
Returns information about an existsing upload session
Returns information about an existing upload session.
:param upload_id: ID of the upload session
:param upload_id: ID of the upload session.
:returns: A FileUploadResponse.
"""
...
@ -130,6 +133,9 @@ class Files(Protocol):
) -> ListBucketResponse:
"""
List all buckets.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:returns: A ListBucketResponse.
"""
...
@ -141,7 +147,8 @@ class Files(Protocol):
"""
List all files in a bucket.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:returns: A ListFileResponse.
"""
...
@ -154,8 +161,9 @@ class Files(Protocol):
"""
Get a file info identified by a bucket and key.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
:returns: A FileResponse.
"""
...
@ -168,7 +176,7 @@ class Files(Protocol):
"""
Delete a file identified by a bucket and key.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
"""
...


@ -845,13 +845,13 @@ class Inference(Protocol):
"""Generate a completion for the given content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param content: The content to generate a completion for
:param sampling_params: (Optional) Parameters to control the sampling strategy
:param response_format: (Optional) Grammar specification for guided (structured) decoding
:param content: The content to generate a completion for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: If stream=False, returns a CompletionResponse with the full completion.
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
"""
...
@ -864,6 +864,15 @@ class Inference(Protocol):
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
"""Generate completions for a batch of content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param content_batch: The content to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch completion is not implemented")
@webmethod(route="/inference/chat-completion", method="POST")
@ -883,9 +892,9 @@ class Inference(Protocol):
"""Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation
:param sampling_params: Parameters to control the sampling strategy
:param tools: (Optional) List of tool definitions available to the model
:param messages: List of messages in the conversation.
:param sampling_params: Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model.
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
.. deprecated::
Use tool_config instead.
@ -902,7 +911,7 @@ class Inference(Protocol):
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:param tool_config: (Optional) Configuration for tool use.
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
"""
...
@ -917,6 +926,17 @@ class Inference(Protocol):
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
"""Generate chat completions for a batch of messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages_batch: The messages to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model.
:param tool_config: (Optional) Configuration for tool use.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchChatCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch chat completion is not implemented")
@webmethod(route="/inference/embeddings", method="POST")
@ -935,7 +955,7 @@ class Inference(Protocol):
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
"""
...
@ -967,22 +987,23 @@ class Inference(Protocol):
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for
:param best_of: (Optional) The number of completions to generate
:param echo: (Optional) Whether to echo the prompt
:param frequency_penalty: (Optional) The penalty for repeated tokens
:param logit_bias: (Optional) The logit bias to use
:param logprobs: (Optional) The log probabilities to use
:param max_tokens: (Optional) The maximum number of tokens to generate
:param n: (Optional) The number of completions to generate
:param presence_penalty: (Optional) The penalty for repeated tokens
:param seed: (Optional) The seed to use
:param stop: (Optional) The stop tokens to use
:param stream: (Optional) Whether to stream the response
:param stream_options: (Optional) The stream options to use
:param temperature: (Optional) The temperature to use
:param top_p: (Optional) The top p to use
:param user: (Optional) The user to use
:param prompt: The prompt to generate a completion for.
:param best_of: (Optional) The number of completions to generate.
:param echo: (Optional) Whether to echo the prompt.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:returns: An OpenAICompletion.
"""
...
@ -1016,27 +1037,28 @@ class Inference(Protocol):
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation
:param frequency_penalty: (Optional) The penalty for repeated tokens
:param function_call: (Optional) The function call to use
:param functions: (Optional) List of functions to use
:param logit_bias: (Optional) The logit bias to use
:param logprobs: (Optional) The log probabilities to use
:param max_completion_tokens: (Optional) The maximum number of tokens to generate
:param max_tokens: (Optional) The maximum number of tokens to generate
:param n: (Optional) The number of completions to generate
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls
:param presence_penalty: (Optional) The penalty for repeated tokens
:param response_format: (Optional) The response format to use
:param seed: (Optional) The seed to use
:param stop: (Optional) The stop tokens to use
:param stream: (Optional) Whether to stream the response
:param stream_options: (Optional) The stream options to use
:param temperature: (Optional) The temperature to use
:param tool_choice: (Optional) The tool choice to use
:param tools: (Optional) The tools to use
:param top_logprobs: (Optional) The top log probabilities to use
:param top_p: (Optional) The top p to use
:param user: (Optional) The user to use
:param messages: List of messages in the conversation.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param function_call: (Optional) The function call to use.
:param functions: (Optional) List of functions to use.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param response_format: (Optional) The response format to use.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param tool_choice: (Optional) The tool choice to use.
:param tools: (Optional) The tools to use.
:param top_logprobs: (Optional) The top log probabilities to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:returns: An OpenAIChatCompletion.
"""
...


@ -36,10 +36,25 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/inspect/routes", method="GET")
async def list_routes(self) -> ListRoutesResponse: ...
async def list_routes(self) -> ListRoutesResponse:
"""List all routes.
:returns: A ListRoutesResponse.
"""
...
@webmethod(route="/health", method="GET")
async def health(self) -> HealthInfo: ...
async def health(self) -> HealthInfo:
"""Get the health of the service.
:returns: A HealthInfo.
"""
...
@webmethod(route="/version", method="GET")
async def version(self) -> VersionInfo: ...
async def version(self) -> VersionInfo:
"""Get the version of the service.
:returns: A VersionInfo.
"""
...


@ -80,16 +80,32 @@ class OpenAIListModelsResponse(BaseModel):
@trace_protocol
class Models(Protocol):
@webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ...
async def list_models(self) -> ListModelsResponse:
"""List all models.
:returns: A ListModelsResponse.
"""
...
@webmethod(route="/openai/v1/models", method="GET")
async def openai_list_models(self) -> OpenAIListModelsResponse: ...
async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API.
:returns: An OpenAIListModelsResponse.
"""
...
@webmethod(route="/models/{model_id:path}", method="GET")
async def get_model(
self,
model_id: str,
) -> Model: ...
) -> Model:
"""Get a model by its identifier.
:param model_id: The identifier of the model to get.
:returns: A Model.
"""
...
@webmethod(route="/models", method="POST")
async def register_model(
@ -99,10 +115,25 @@ class Models(Protocol):
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model: ...
) -> Model:
"""Register a model.
:param model_id: The identifier of the model to register.
:param provider_model_id: The identifier of the model in the provider.
:param provider_id: The identifier of the provider.
:param metadata: Any additional metadata for this model.
:param model_type: The type of model to register.
:returns: A Model.
"""
...
@webmethod(route="/models/{model_id:path}", method="DELETE")
async def unregister_model(
self,
model_id: str,
) -> None: ...
) -> None:
"""Unregister a model.
:param model_id: The identifier of the model to unregister.
"""
...


@ -182,7 +182,19 @@ class PostTraining(Protocol):
),
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob: ...
) -> PostTrainingJob:
"""Run supervised fine-tuning of a model.
:param job_uuid: The UUID of the job to create.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:param model: The model to fine-tune.
:param checkpoint_dir: The directory to save checkpoint(s) to.
:param algorithm_config: The algorithm configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/preference-optimize", method="POST")
async def preference_optimize(
@ -193,16 +205,49 @@ class PostTraining(Protocol):
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob: ...
) -> PostTrainingJob:
"""Run preference optimization of a model.
:param job_uuid: The UUID of the job to create.
:param finetuned_model: The model to fine-tune.
:param algorithm_config: The algorithm configuration.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all training jobs.
:returns: A ListPostTrainingJobsResponse.
"""
...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
"""Get the status of a training job.
:param job_uuid: The UUID of the job to get the status of.
:returns: A PostTrainingJobStatusResponse.
"""
...
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
async def cancel_training_job(self, job_uuid: str) -> None:
"""Cancel a training job.
:param job_uuid: The UUID of the job to cancel.
"""
...
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
"""Get the artifacts of a training job.
:param job_uuid: The UUID of the job to get the artifacts of.
:returns: A PostTrainingJobArtifactsResponse.
"""
...


@ -32,7 +32,18 @@ class Providers(Protocol):
"""
@webmethod(route="/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
async def list_providers(self) -> ListProvidersResponse:
"""List all available providers.
:returns: A ListProvidersResponse containing information about all providers.
"""
...
@webmethod(route="/providers/{provider_id}", method="GET")
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get detailed information about a specific provider.
:param provider_id: The ID of the provider to inspect.
:returns: A ProviderInfo object containing the provider's details.
"""
...


@ -54,4 +54,12 @@ class Safety(Protocol):
shield_id: str,
messages: list[Message],
params: dict[str, Any],
) -> RunShieldResponse: ...
) -> RunShieldResponse:
"""Run a shield.
:param shield_id: The identifier of the shield to run.
:param messages: The messages to run the shield on.
:param params: The parameters of the shield.
:returns: A RunShieldResponse.
"""
...


@ -61,7 +61,15 @@ class Scoring(Protocol):
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None],
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
) -> ScoreBatchResponse:
"""Score a batch of rows.
:param dataset_id: The ID of the dataset to score.
:param scoring_functions: The scoring functions to use for the scoring.
:param save_results_dataset: Whether to save the results to a dataset.
:returns: A ScoreBatchResponse.
"""
...
@webmethod(route="/scoring/score", method="POST")
async def score(
@ -73,6 +81,6 @@ class Scoring(Protocol):
:param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring.
:return: ScoreResponse object containing rows and aggregated results
:returns: A ScoreResponse object containing rows and aggregated results.
"""
...


@ -134,10 +134,21 @@ class ListScoringFunctionsResponse(BaseModel):
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
"""List all scoring functions.
:returns: A ListScoringFunctionsResponse.
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
"""Get a scoring function by its ID.
:param scoring_fn_id: The ID of the scoring function to get.
:returns: A ScoringFn.
"""
...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(
@ -148,4 +159,14 @@ class ScoringFunctions(Protocol):
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None: ...
) -> None:
"""Register a scoring function.
:param scoring_fn_id: The ID of the scoring function to register.
:param description: The description of the scoring function.
:param return_type: The return type of the scoring function.
:param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
:param provider_id: The ID of the provider to use for the scoring function.
:param params: The parameters for the scoring function for benchmark eval; these can be overridden for app eval.
"""
...


@ -46,10 +46,21 @@ class ListShieldsResponse(BaseModel):
@trace_protocol
class Shields(Protocol):
@webmethod(route="/shields", method="GET")
async def list_shields(self) -> ListShieldsResponse: ...
async def list_shields(self) -> ListShieldsResponse:
"""List all shields.
:returns: A ListShieldsResponse.
"""
...
@webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Shield: ...
async def get_shield(self, identifier: str) -> Shield:
"""Get a shield by its identifier.
:param identifier: The identifier of the shield to get.
:returns: A Shield.
"""
...
@webmethod(route="/shields", method="POST")
async def register_shield(
@ -58,4 +69,13 @@ class Shields(Protocol):
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield: ...
) -> Shield:
"""Register a shield.
:param shield_id: The identifier of the shield to register.
:param provider_shield_id: The identifier of the shield in the provider.
:param provider_id: The identifier of the provider.
:param params: The parameters of the shield.
:returns: A Shield.
"""
...


@ -247,7 +247,17 @@ class QueryMetricsResponse(BaseModel):
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/events", method="POST")
async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
async def log_event(
self,
event: Event,
ttl_seconds: int = DEFAULT_TTL_DAYS * 86400,
) -> None:
"""Log an event.
:param event: The event to log.
:param ttl_seconds: The time to live of the event.
"""
...
@webmethod(route="/telemetry/traces", method="POST")
async def query_traces(
@ -256,13 +266,35 @@ class Telemetry(Protocol):
limit: int | None = 100,
offset: int | None = 0,
order_by: list[str] | None = None,
) -> QueryTracesResponse: ...
) -> QueryTracesResponse:
"""Query traces.
:param attribute_filters: The attribute filters to apply to the traces.
:param limit: The limit of traces to return.
:param offset: The offset of the traces to return.
:param order_by: The fields to order the returned traces by.
:returns: A QueryTracesResponse.
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ...
async def get_trace(self, trace_id: str) -> Trace:
"""Get a trace by its ID.
:param trace_id: The ID of the trace to get.
:returns: A Trace.
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
async def get_span(self, trace_id: str, span_id: str) -> Span:
"""Get a span by its ID.
:param trace_id: The ID of the trace to get the span from.
:param span_id: The ID of the span to get.
:returns: A Span.
"""
...
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
async def get_span_tree(
@ -270,7 +302,15 @@ class Telemetry(Protocol):
span_id: str,
attributes_to_return: list[str] | None = None,
max_depth: int | None = None,
) -> QuerySpanTreeResponse: ...
) -> QuerySpanTreeResponse:
"""Get a span tree by its ID.
:param span_id: The ID of the span to get the tree from.
:param attributes_to_return: The attributes to return in the tree.
:param max_depth: The maximum depth of the tree.
:returns: A QuerySpanTreeResponse.
"""
...
@webmethod(route="/telemetry/spans", method="POST")
async def query_spans(
@ -278,7 +318,15 @@ class Telemetry(Protocol):
attribute_filters: list[QueryCondition],
attributes_to_return: list[str],
max_depth: int | None = None,
) -> QuerySpansResponse: ...
) -> QuerySpansResponse:
"""Query spans.
:param attribute_filters: The attribute filters to apply to the spans.
:param attributes_to_return: The attributes to return in the spans.
:param max_depth: The maximum depth of the tree.
:returns: A QuerySpansResponse.
"""
...
@webmethod(route="/telemetry/spans/export", method="POST")
async def save_spans_to_dataset(
@ -287,7 +335,15 @@ class Telemetry(Protocol):
attributes_to_save: list[str],
dataset_id: str,
max_depth: int | None = None,
) -> None: ...
) -> None:
"""Save spans to a dataset.
:param attribute_filters: The attribute filters to apply to the spans.
:param attributes_to_save: The attributes to save to the dataset.
:param dataset_id: The ID of the dataset to save the spans to.
:param max_depth: The maximum depth of the tree.
"""
...
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
async def query_metrics(
@ -298,4 +354,15 @@ class Telemetry(Protocol):
granularity: str | None = "1d",
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse: ...
) -> QueryMetricsResponse:
"""Query metrics.
:param metric_name: The name of the metric to query.
:param start_time: The start time of the metric to query.
:param end_time: The end time of the metric to query.
:param granularity: The granularity of the metric to query.
:param query_type: The type of query to perform.
:param label_matchers: The label matchers to apply to the metric.
:returns: A QueryMetricsResponse.
"""
...


@ -103,37 +103,65 @@ class ToolGroups(Protocol):
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
"""Register a tool group"""
"""Register a tool group.
:param toolgroup_id: The ID of the tool group to register.
:param provider_id: The ID of the provider to use for the tool group.
:param mcp_endpoint: The MCP endpoint to use for the tool group.
:param args: A dictionary of arguments to pass to the tool group.
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
async def get_tool_group(
self,
toolgroup_id: str,
) -> ToolGroup: ...
) -> ToolGroup:
"""Get a tool group by its ID.
:param toolgroup_id: The ID of the tool group to get.
:returns: A ToolGroup.
"""
...
@webmethod(route="/toolgroups", method="GET")
async def list_tool_groups(self) -> ListToolGroupsResponse:
"""List tool groups with optional provider"""
"""List tool groups with optional provider.
:returns: A ListToolGroupsResponse.
"""
...
@webmethod(route="/tools", method="GET")
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
"""List tools with optional tool group"""
"""List tools with optional tool group.
:param toolgroup_id: The ID of the tool group to list tools for.
:returns: A ListToolsResponse.
"""
...
@webmethod(route="/tools/{tool_name:path}", method="GET")
async def get_tool(
self,
tool_name: str,
) -> Tool: ...
) -> Tool:
"""Get a tool by its name.
:param tool_name: The name of the tool to get.
:returns: A Tool.
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
async def unregister_toolgroup(
self,
toolgroup_id: str,
) -> None:
"""Unregister a tool group"""
"""Unregister a tool group.
:param toolgroup_id: The ID of the tool group to unregister.
"""
...
@ -152,9 +180,21 @@ class ToolRuntime(Protocol):
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ...
) -> ListToolDefsResponse:
"""List all tools in the runtime.
:param tool_group_id: The ID of the tool group to list tools for.
:param mcp_endpoint: The MCP endpoint to use for the tool group.
:returns: A ListToolDefsResponse.
"""
...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
"""Run a tool with the given arguments.
:param tool_name: The name of the tool to invoke.
:param kwargs: A dictionary of arguments to pass to the tool.
:returns: A ToolInvocationResult.
"""
...


@ -44,13 +44,24 @@ class ListVectorDBsResponse(BaseModel):
@trace_protocol
class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET")
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""List all vector databases.
:returns: A ListVectorDBsResponse.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
async def get_vector_db(
self,
vector_db_id: str,
) -> VectorDB: ...
) -> VectorDB:
"""Get a vector database by its identifier.
:param vector_db_id: The identifier of the vector database to get.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs", method="POST")
async def register_vector_db(
@ -60,7 +71,22 @@ class VectorDBs(Protocol):
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB: ...
) -> VectorDB:
"""Register a vector database.
:param vector_db_id: The identifier of the vector database to register.
:param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider.
:param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database.
:param vector_db_id: The identifier of the vector database to unregister.
"""
...


@ -46,7 +46,14 @@ class VectorIO(Protocol):
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None: ...
) -> None:
"""Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into.
:param chunks: The chunks to insert.
:param ttl_seconds: The time to live of the chunks.
"""
...
@webmethod(route="/vector-io/query", method="POST")
async def query_chunks(
@ -54,4 +61,12 @@ class VectorIO(Protocol):
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ...
) -> QueryChunksResponse:
"""Query chunks from a vector database.
:param vector_db_id: The identifier of the vector database to query.
:param query: The query to search for.
:param params: The parameters of the query.
:returns: A QueryChunksResponse.
"""
...
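As a usage illustration of the two methods documented above, here is a minimal round trip. Only the method signatures come from this diff; the `vector_io` object, the `Chunk` import path, and its `content`/`metadata` fields are assumptions.

```python
from llama_stack.apis.vector_io import Chunk  # import path assumed


async def index_and_search(vector_io, vector_db_id: str):
    # Insert one chunk, then query it back; Chunk field names are assumed here.
    await vector_io.insert_chunks(
        vector_db_id=vector_db_id,
        chunks=[Chunk(content="Llama Stack supports external providers.", metadata={"source": "docs"})],
        ttl_seconds=3600,
    )
    return await vector_io.query_chunks(vector_db_id=vector_db_id, query="external providers")
```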


@ -36,7 +36,8 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
@ -202,7 +203,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
else:
with open(args.config) as f:
try:
build_config = BuildConfig(**yaml.safe_load(f))
contents = yaml.safe_load(f)
contents = replace_env_vars(contents)
build_config = BuildConfig(**contents)
except Exception as e:
cprint(
f"Could not parse config file {args.config}: {e}",
@ -248,6 +251,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
run_config = Path(run_config)
config_dict = yaml.safe_load(run_config.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(str(config.external_providers_dir)):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
run_command(run_args)
@ -267,7 +272,9 @@ def _generate_run_config(
image_name=image_name,
apis=apis,
providers={},
external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None,
external_providers_dir=build_config.external_providers_dir
if build_config.external_providers_dir
else EXTERNAL_PROVIDERS_DIR,
)
# build providers dict
provider_registry = get_provider_registry(build_config)


@ -33,7 +33,8 @@ class StackRun(Subcommand):
self.parser.add_argument(
"config",
type=str,
help="Path to config file to use for the run",
nargs="?", # Make it optional
help="Path to config file to use for the run. Required for venv and conda environments.",
)
self.parser.add_argument(
"--port",
@ -82,44 +83,55 @@ class StackRun(Subcommand):
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
logger.info(f"Using run configuration: {config_file}")
try:
config_dict = yaml.safe_load(config_file.read_text())
except yaml.parser.ParserError as e:
self.parser.error(f"failed to load config file '{config_file}':\n {e}")
try:
config = parse_and_maybe_upgrade_config(config_dict)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
image_type, image_name = self._get_image_type_and_name(args)
# Check if config is required based on image type
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not args.config:
self.parser.error("Config file is required for venv and conda environments")
if args.config:
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
logger.info(f"Using run configuration: {config_file}")
try:
config_dict = yaml.safe_load(config_file.read_text())
except yaml.parser.ParserError as e:
self.parser.error(f"failed to load config file '{config_file}':\n {e}")
try:
config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(str(config.external_providers_dir)):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
else:
config = None
config_file = None
template_name = None
# If neither image type nor image name is provided, assume the server should be run directly
# using the current environment packages.
if not image_type and not image_name:
@ -141,7 +153,10 @@ class StackRun(Subcommand):
else:
run_args = formulate_run_args(image_type, image_name, config, template_name)
run_args.extend([str(config_file), str(args.port)])
run_args.extend([str(args.port)])
if config_file:
run_args.extend(["--config", str(config_file)])
if args.env:
for env_var in args.env:


@ -154,6 +154,12 @@ get_python_cmd() {
fi
}
# Add other required item commands generic to all containers
add_to_container << EOF
# Allows running as non-root user
RUN mkdir -p /.llama/providers.d /.cache
EOF
if [ -n "$run_config" ]; then
# Copy the run config to the build context since it's an absolute path
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
@ -166,17 +172,19 @@ EOF
# and update the configuration to reference the new container path
python_cmd=$(get_python_cmd)
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
if [ -n "$external_providers_dir" ]; then
external_providers_dir=$(eval echo "$external_providers_dir")
if [ -n "$external_providers_dir" ] && [ -d "$external_providers_dir" ]; then
echo "Copying external providers directory: $external_providers_dir"
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
add_to_container << EOF
COPY $external_providers_dir /app/providers.d
COPY providers.d /.llama/providers.d
EOF
# Edit the run.yaml file to change the external_providers_dir to /app/providers.d
# Edit the run.yaml file to change the external_providers_dir to /.llama/providers.d
if [ "$(uname)" = "Darwin" ]; then
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
else
sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
sed -i 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
fi
fi
fi
@ -255,9 +263,6 @@ fi
# Add other required item commands generic to all containers
add_to_container << EOF
# Allows running as non-root user
RUN mkdir -p /.llama /.cache
RUN chmod -R g+rw /app /.llama /.cache
EOF


@ -17,6 +17,7 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.datatypes import Api, ProviderSpec
@ -170,4 +171,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
return StackRunConfig(**config_dict)


@ -5,9 +5,10 @@
# the root directory of this source tree.
from enum import Enum
from pathlib import Path
from typing import Annotated, Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
@ -312,11 +313,20 @@ a default SQLite store will be used.""",
description="Configuration for the HTTP(S) server",
)
external_providers_dir: str | None = Field(
external_providers_dir: Path | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
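Together with the `os.path.expanduser` change in the provider-registry hunk below, the net effect is that string paths from YAML are carried as `Path` objects while `~` expansion happens later, where the directory is actually read. A minimal sketch of the same validator pattern, using a stand-in model rather than the real `StackRunConfig`:

```python
from pathlib import Path

from pydantic import BaseModel, Field, field_validator


class ConfigSketch(BaseModel):
    # Stand-in for StackRunConfig.external_providers_dir (not the real class).
    external_providers_dir: Path | None = Field(default=None)

    @field_validator("external_providers_dir")
    @classmethod
    def validate_external_providers_dir(cls, v):
        if v is None:
            return None
        if isinstance(v, str):
            return Path(v)
        return v


cfg = ConfigSketch(external_providers_dir="~/.llama/providers.d")
assert cfg.external_providers_dir == Path("~/.llama/providers.d")  # stored as Path, not yet expanded
assert ConfigSketch().external_providers_dir is None
```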


@ -145,7 +145,7 @@ def get_provider_registry(
# Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
external_providers_dir = os.path.abspath(config.external_providers_dir)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")


@ -29,7 +29,7 @@ error_handler() {
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <env_type> <env_path_or_name> <yaml_config> <port> <script_args...>"
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
exit 1
fi
@ -40,37 +40,51 @@ env_path_or_name="$1"
container_image="localhost/$env_path_or_name"
shift
yaml_config="$1"
shift
port="$1"
shift
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Initialize env_vars as an string
# Initialize variables
yaml_config=""
env_vars=""
other_args=""
# Process environment variables from --env arguments
# Process remaining arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
--config|--yaml-config)
if [[ -n "$2" ]]; then
yaml_config="$2"
shift 2
else
echo -e "${RED}Error: $1 requires a CONFIG argument${NC}" >&2
exit 1
fi
;;
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
# Check if yaml_config is required based on env_type
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]] && [ -z "$yaml_config" ]; then
echo -e "${RED}Error: --config is required for venv and conda environments${NC}" >&2
exit 1
fi
PYTHON_BINARY="python"
case "$env_type" in
"venv")
@ -106,8 +120,14 @@ esac
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
if [ -n "$yaml_config" ]; then
yaml_config_arg="--yaml-config $yaml_config"
else
yaml_config_arg=""
fi
$PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
$yaml_config_arg \
--port "$port" \
$env_vars \
$other_args
@ -149,15 +169,26 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version')
fi
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
# Build the command with optional yaml config
cmd="$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \
$env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \
--env LLAMA_STACK_PORT=$port \
--entrypoint python \
$container_image:$version_tag \
-m llama_stack.distribution.server.server \
--yaml-config /app/config.yaml \
$other_args
-m llama_stack.distribution.server.server"
# Add yaml config if provided, otherwise use default
if [ -n "$yaml_config" ]; then
cmd="$cmd -v $yaml_config:/app/run.yaml --yaml-config /app/run.yaml"
else
cmd="$cmd --yaml-config /app/run.yaml"
fi
# Add any other args
cmd="$cmd $other_args"
# Execute the command
eval $cmd
fi
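A hedged Python mirror of the new argument contract (names and values are illustrative): port is now positional, the config file moves behind --config/--yaml-config, --env may repeat, and a config is mandatory only for venv/conda since containers fall back to the baked-in /app/run.yaml:

```python
import argparse

parser = argparse.ArgumentParser(prog="run.sh (sketch)")
parser.add_argument("env_type", choices=["venv", "conda", "container"])
parser.add_argument("env_path_or_name")
parser.add_argument("port", type=int)
parser.add_argument("--config", "--yaml-config", dest="yaml_config", default=None)
parser.add_argument("--env", action="append", default=[], metavar="KEY=VALUE")

args, other_args = parser.parse_known_args(
    ["venv", "my-env", "8321", "--config", "run.yaml", "--env", "OPENAI_API_KEY=dummy-key"]
)
if args.env_type in ("venv", "conda") and not args.yaml_config:
    parser.error("--config is required for venv and conda environments")
print(args.yaml_config, args.env, other_args)  # run.yaml ['OPENAI_API_KEY=dummy-key'] []
```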

View file

@ -14,3 +14,5 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"

View file

@ -22,8 +22,10 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list:
env_name = ""
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
env_name = f"distribution-{template_name}" if template_name else config.container_image
if image_type == LlamaStackImageType.CONTAINER.value:
env_name = (
f"distribution-{template_name}" if template_name else (config.container_image if config else image_name)
)
elif image_type == LlamaStackImageType.CONDA.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
env_name = image_name or current_conda_env
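A small sketch of the new precedence for the container environment name (template name first, then the config's container image, then the plain image name when no config is given):

```python
def resolve_container_env_name(template_name, config, image_name):
    # Mirrors the expression above; config may be None when only an image name is supplied.
    if template_name:
        return f"distribution-{template_name}"
    return config.container_image if config else image_name

print(resolve_container_env_name(None, None, "my-image"))       # "my-image"
print(resolve_container_env_name("starter", None, "my-image"))  # "distribution-starter"
```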

View file

@ -162,7 +162,7 @@ def _process_vllm_chat_completion_end_of_stream(
finish_reason: str | None,
last_chunk_content: str | None,
current_event_type: ChatCompletionResponseEventType,
tool_call_buf: UnparseableToolCall,
tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
) -> list[OpenAIChatCompletionChunk]:
chunks = []
@ -171,9 +171,8 @@ def _process_vllm_chat_completion_end_of_stream(
else:
stop_reason = StopReason.end_of_message
if tool_call_buf.tool_name:
# at least one tool call request is received
tool_call_bufs = tool_call_bufs or {}
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
args_str = tool_call_buf.arguments or "{}"
try:
args = json.loads(args_str)
@ -225,8 +224,14 @@ def _process_vllm_chat_completion_end_of_stream(
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
event_type = ChatCompletionResponseEventType.start
tool_call_buf = UnparseableToolCall()
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
event_type = ChatCompletionResponseEventType.progress
tool_call_bufs: dict[str, UnparseableToolCall] = {}
end_of_stream_processed = False
async for chunk in stream:
@ -235,17 +240,22 @@ async def _process_vllm_chat_completion_stream_response(
return
choice = chunk.choices[0]
if choice.delta.tool_calls:
tool_call = convert_tool_call(choice.delta.tool_calls[0])
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
# TODO: remove str() when dict type for 'arguments' is no longer allowed
tool_call_buf.arguments += str(tool_call.arguments)
for delta_tool_call in choice.delta.tool_calls:
tool_call = convert_tool_call(delta_tool_call)
if delta_tool_call.index not in tool_call_bufs:
tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
tool_call_buf = tool_call_bufs[delta_tool_call.index]
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
tool_call_buf.arguments += (
tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
)
if choice.finish_reason:
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=choice.finish_reason,
last_chunk_content=choice.delta.content,
current_event_type=event_type,
tool_call_buf=tool_call_buf,
tool_call_bufs=tool_call_bufs,
)
for c in chunks:
yield c
@ -266,7 +276,7 @@ async def _process_vllm_chat_completion_stream_response(
# the stream ended without a chunk containing finish_reason - we have to generate the
# respective completion chunks manually
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_buf=tool_call_buf
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
)
for c in chunks:
yield c
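A minimal sketch (with an assumed delta shape) of the per-index buffering introduced above: each streamed tool-call fragment is accumulated under its own index and flushed in index order at end of stream, so parallel tool calls no longer clobber a single shared buffer:

```python
import json
from dataclasses import dataclass

@dataclass
class ToolCallBuf:
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""

bufs: dict[int, ToolCallBuf] = {}
deltas = [  # (index, call_id, name, arguments) fragments as they arrive on the stream
    (1, "tc_1", "power", '{"number": 28'),
    (2, "tc_2", "multiple", '{"first_number": 4, "second_number": 7}'),
    (1, "", "", ', "power": 3}'),
]
for index, call_id, name, arguments in deltas:
    buf = bufs.setdefault(index, ToolCallBuf())
    buf.call_id += call_id
    buf.tool_name += name
    buf.arguments += arguments

for index, buf in sorted(bufs.items()):  # end of stream: emit one tool call per index
    print(index, buf.tool_name, json.loads(buf.arguments or "{}"))
```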

View file

@ -531,13 +531,19 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
tool_name = tc.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
# arguments_json can be None, so attempt it first and fall back to arguments
if hasattr(tc, "arguments_json") and tc.arguments_json:
arguments = tc.arguments_json
else:
arguments = json.dumps(tc.arguments)
result["tool_calls"].append(
{
"id": tc.call_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
"arguments": arguments,
},
}
)
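A sketch of the fallback above: arguments_json can exist on the tool call but be None, so check truthiness first and only then serialize the structured arguments:

```python
import json
from types import SimpleNamespace

def tool_call_arguments(tc) -> str:
    arguments_json = getattr(tc, "arguments_json", None)
    if arguments_json:
        return arguments_json
    return json.dumps(tc.arguments)

print(tool_call_arguments(SimpleNamespace(arguments_json=None, arguments={"city": "Paris"})))
# '{"city": "Paris"}'
```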

View file

@ -152,46 +152,6 @@
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"dev": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"fireworks": [
"aiosqlite",
"autoevals",
@ -642,6 +602,46 @@
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"starter": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"tgi": [
"aiohttp",
"aiosqlite",

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .dev import get_distribution_template # noqa: F401
from .starter import get_distribution_template # noqa: F401

View file

@ -1,6 +1,6 @@
version: '2'
distribution_spec:
description: Distribution for running e2e tests in CI
description: Quick start template for running Llama Stack with several popular providers
providers:
inference:
- remote::openai

View file

@ -1,5 +1,5 @@
version: '2'
image_name: dev
image_name: starter
apis:
- agents
- datasetio
@ -46,7 +46,7 @@ providers:
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/sqlite_vec.db
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
provider_type: remote::chromadb
config:
@ -71,14 +71,14 @@ providers:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/trace_store.db
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
@ -86,7 +86,7 @@ providers:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/meta_reference_eval.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
@ -94,14 +94,14 @@ providers:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/huggingface_datasetio.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/localfs_datasetio.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
@ -132,7 +132,7 @@ providers:
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db
models:
- metadata: {}
model_id: openai/gpt-4o

View file

@ -46,6 +46,7 @@ from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOC
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
@ -53,7 +54,7 @@ from llama_stack.templates.template import (
)
def get_inference_providers() -> tuple[list[Provider], list[ModelInput]]:
def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
# in this template, we allow each API key to be optional
providers = [
(
@ -119,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
"remote::model-context-protocol",
],
}
name = "dev"
name = "starter"
vector_io_providers = [
Provider(
@ -171,7 +172,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Distribution for running e2e tests in CI",
description="Quick start template for running Llama Stack with several popular providers",
container_image=None,
template_path=None,
providers=providers,

View file

@ -5,13 +5,13 @@ We use shadcn/ui [Shadcn UI](https://ui.shadcn.com/) for the UI components.
## Getting Started
## Install the NPM dependencies
First, install dependencies:
```bash
npm install
npm install next react react-dom
```
Run the development server:
Then, run the development server:
```bash
npm run dev

View file

@ -45,10 +45,10 @@ export function AppSidebar() {
{logItems.map((item) => (
<SidebarMenuItem key={item.title}>
<SidebarMenuButton asChild>
<a href={item.url}>
<Link href={item.url}>
<item.icon />
<span>{item.title}</span>
</a>
</Link>
</SidebarMenuButton>
</SidebarMenuItem>
))}

View file

@ -304,7 +304,6 @@ exclude = [
"^llama_stack/strong_typing/inspection\\.py$",
"^llama_stack/strong_typing/schema\\.py$",
"^llama_stack/strong_typing/serializer\\.py$",
"^llama_stack/templates/dev/dev\\.py$",
"^llama_stack/templates/groq/groq\\.py$",
"^llama_stack/templates/llama_api/llama_api\\.py$",
"^llama_stack/templates/sambanova/sambanova\\.py$",

View file

@ -6,4 +6,4 @@ distribution_spec:
- remote::custom_ollama
image_type: container
image_name: ci-test
external_providers_dir: /tmp/providers.d
external_providers_dir: ~/.llama/providers.d

View file

@ -91,4 +91,4 @@ tool_groups:
provider_id: wolfram-alpha
server:
port: 8321
external_providers_dir: /tmp/providers.d
external_providers_dir: ~/.llama/providers.d

View file

@ -266,6 +266,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
assert found_tool_execution
@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = {
**agent_config,
@ -346,7 +347,7 @@ def test_custom_tool(llama_stack_client, agent_config):
messages=[
{
"role": "user",
"content": "What is the boiling point of polyjuice?",
"content": "What is the boiling point of the liquid polyjuice in celsius?",
},
],
session_id=session_id,
@ -420,7 +421,7 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
messages=[
{
"role": "user",
"content": "What is the boiling point of polyjuice?",
"content": "What is the boiling point of the liquid polyjuice in celsius?",
},
],
session_id=session_id,
@ -674,8 +675,8 @@ def test_create_turn_response(llama_stack_client, agent_config, client_tools):
def test_multi_tool_calls(llama_stack_client, agent_config):
if "gpt" not in agent_config["model"]:
pytest.xfail("Only tested on GPT models")
if "gpt" not in agent_config["model"] and "llama-4" not in agent_config["model"].lower():
pytest.xfail("Only tested on GPT and Llama 4 models")
agent_config = {
**agent_config,
@ -689,23 +690,34 @@ def test_multi_tool_calls(llama_stack_client, agent_config):
messages=[
{
"role": "user",
"content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?",
"content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?.\nUse the tool responses to answer the question.",
},
],
session_id=session_id,
stream=False,
)
steps = response.steps
assert len(steps) == 7
assert steps[0].step_type == "shield_call"
assert steps[1].step_type == "inference"
assert steps[2].step_type == "shield_call"
assert steps[3].step_type == "tool_execution"
assert steps[4].step_type == "shield_call"
assert steps[5].step_type == "inference"
assert steps[6].step_type == "shield_call"
tool_execution_step = steps[3]
has_input_shield = agent_config.get("input_shields")
has_output_shield = agent_config.get("output_shields")
assert len(steps) == 3 + (2 if has_input_shield else 0) + (2 if has_output_shield else 0)
if has_input_shield:
assert steps[0].step_type == "shield_call"
steps.pop(0)
assert steps[0].step_type == "inference"
if has_output_shield:
assert steps[1].step_type == "shield_call"
steps.pop(1)
assert steps[1].step_type == "tool_execution"
tool_execution_step = steps[1]
if has_input_shield:
assert steps[2].step_type == "shield_call"
steps.pop(2)
assert steps[2].step_type == "inference"
if has_output_shield:
assert steps[3].step_type == "shield_call"
steps.pop(3)
assert len(tool_execution_step.tool_calls) == 2
assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point")
assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point")
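For clarity, a standalone illustration (not repo test code) of the step sequence the reworked assertions encode: three fixed steps (inference, tool_execution, inference) plus two shield_call steps per configured shield kind:

```python
def expected_step_types(has_input_shield: bool, has_output_shield: bool) -> list[str]:
    steps = []
    if has_input_shield:
        steps.append("shield_call")
    steps.append("inference")
    if has_output_shield:
        steps.append("shield_call")
    steps.append("tool_execution")
    if has_input_shield:
        steps.append("shield_call")
    steps.append("inference")
    if has_output_shield:
        steps.append("shield_call")
    return steps

print(len(expected_step_types(True, True)))    # 7 -- matches the old hard-coded assertion
print(len(expected_step_types(False, False)))  # 3
```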

View file

@ -24,6 +24,12 @@ from openai.types.chat.chat_completion_chunk import (
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
)
from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
@ -206,8 +212,164 @@ async def test_tool_call_delta_empty_tool_call_buf():
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 1
assert chunks[0].event.stop_reason == StopReason.end_of_turn
assert len(chunks) == 2
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "complete"
assert chunks[1].event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio
async def test_tool_call_delta_streaming_arguments_dict():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments="",
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "complete"
@pytest.mark.asyncio
async def test_multiple_tool_calls():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=2,
function=OpenAIChoiceDeltaToolCallFunction(
name="multiple",
arguments='{"first_number": 4, "second_number": 7}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 4
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "progress"
assert chunks[2].event.delta.type == "tool_call"
assert chunks[2].event.delta.parse_status.value == "succeeded"
assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
assert chunks[3].event.event_type.value == "complete"
@pytest.mark.asyncio
@ -231,7 +393,8 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 0
assert len(chunks) == 1
assert chunks[0].event.event_type.value == "start"
def test_chat_completion_doesnt_block_event_loop(caplog):
@ -369,7 +532,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@ -422,7 +585,7 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@ -471,7 +634,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name