Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 16:24:44 +00:00)

Merge branch 'main' into small-ui-patches

This commit is contained in commit e9cce9ed38.
54 changed files with 1825 additions and 760 deletions.
.github/CODEOWNERS (2 lines changed)
@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan @SLR722 @leseb
+* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning
.github/TRIAGERS.md (2 lines changed)
@@ -1,2 +1,2 @@
 # This file documents Triage members in the Llama Stack community
-@franciscojavierarceo @leseb
+@bbrowning @booxter @franciscojavierarceo @leseb
@@ -47,8 +47,8 @@ jobs:
 - name: Create provider configuration
   run: |
-    mkdir -p /tmp/providers.d/remote/inference
-    cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
+    mkdir -p /home/runner/.llama/providers.d/remote/inference
+    cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml

 - name: Build distro from config file
   run: |

@@ -66,7 +66,7 @@ jobs:
 - name: Wait for Llama Stack server to be ready
   run: |
     for i in {1..30}; do
-      if ! grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
+      if ! grep -q "remote::custom_ollama from /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml" server.log; then
        echo "Waiting for Llama Stack server to load the provider..."
        sleep 1
      else
.github/workflows/update-readthedocs.yml (7 lines changed)
@@ -14,6 +14,8 @@ on:
 - 'docs/**'
 - 'pyproject.toml'
 - '.github/workflows/update-readthedocs.yml'
 tags:
   - '*'
 pull_request:
 branches:
 - main

@@ -61,7 +63,10 @@
 response=$(curl -X POST \
   -H "Content-Type: application/json" \
-  -d "{\"token\": \"$TOKEN\"}" \
+  -d "{
+    \"token\": \"$TOKEN\",
+    \"version\": \"$GITHUB_REF_NAME\"
+  }" \
   https://readthedocs.org/api/v2/webhook/llama-stack/289768/)

 echo "Response: $response"
docs/_static/llama-stack-spec.html (647 lines changed): diff not shown, file too large.
docs/_static/llama-stack-spec.yaml (536 lines changed): diff not shown, file too large.
@@ -179,6 +179,35 @@ def _validate_has_ellipsis(method) -> str | None:
     if "..." not in source and not "NotImplementedError" in source:
         return "does not contain ellipsis (...) in its implementation"

+def _validate_has_return_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    return_type = method.__annotations__.get('return')
+    if return_type is not None and return_type != type(None) and ":returns:" not in source:
+        return "does not have a ':returns:' in its docstring"
+
+def _validate_has_params_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    sig = inspect.signature(method)
+    # Only check if the method has more than one parameter
+    if len(sig.parameters) > 1 and ":param" not in source:
+        return "does not have a ':param' in its docstring"
+
+def _validate_has_no_return_none_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    return_type = method.__annotations__.get('return')
+    if return_type is None and ":returns: None" in source:
+        return "has a ':returns: None' in its docstring which is redundant for None-returning functions"
+
+def _validate_docstring_lines_end_with_dot(method) -> str | None:
+    docstring = inspect.getdoc(method)
+    if docstring is None:
+        return None
+
+    lines = docstring.split('\n')
+    for line in lines:
+        line = line.strip()
+        if line and not any(line.endswith(char) for char in '.:{}[]()",'):
+            return f"docstring line '{line}' does not end with a valid character: . : {{ }} [ ] ( ) , \""
+
 _VALIDATORS = {
     "GET": [

@@ -186,13 +215,23 @@ _VALIDATORS = {
         _validate_list_parameters_contain_data,
         _validate_api_method_doesnt_return_list,
         _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_docstring_lines_end_with_dot,
     ],
     "DELETE": [
         _validate_api_delete_method_returns_none,
         _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_has_no_return_none_in_docstring
     ],
     "POST": [
         _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_has_no_return_none_in_docstring,
+        _validate_docstring_lines_end_with_dot,
     ],
 }
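For reference, a minimal editor-added sketch (not part of this commit) of the docstring shape the new validators accept: every line ends with an allowed character, each parameter has a `:param` entry, and a non-None return annotation has a matching `:returns:` entry. The method and class names here are hypothetical.

```python
from typing import Protocol


class WidgetAPI(Protocol):
    async def get_widget(self, widget_id: str) -> dict:
        """Get a widget by its ID.

        :param widget_id: The ID of the widget to get.
        :returns: A dictionary describing the widget.
        """
        ...
```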
@@ -110,6 +110,8 @@ html_theme_options = {
     "canonical_url": "https://github.com/meta-llama/llama-stack",
     "collapse_navigation": False,
     # "style_nav_header_background": "#c3c9d4",
+    'display_version': True,
+    'version_selector': True,
 }

 default_dark_mode = False

@@ -178,7 +178,7 @@ image_name: ollama
 image_type: conda

 # If some providers are external, you can specify the path to the implementation
-external_providers_dir: /etc/llama-stack/providers.d
+external_providers_dir: ~/.llama/providers.d
 ```

 ```

@@ -206,7 +206,7 @@ distribution_spec:
 image_type: container
 image_name: ci-test
 # Path to external provider implementations
-external_providers_dir: /etc/llama-stack/providers.d
+external_providers_dir: ~/.llama/providers.d
 ```

 Here's an example for a custom Ollama provider:

@@ -10,7 +10,7 @@ Llama Stack supports external providers that live outside of the main codebase.
 To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:

 ```yaml
-external_providers_dir: /etc/llama-stack/providers.d/
+external_providers_dir: ~/.llama/providers.d/
 ```

 ## Directory Structure

@@ -182,7 +182,7 @@ dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
 3. Create the provider specification:

 ```yaml
-# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
+# ~/.llama/providers.d/remote/inference/custom_ollama.yaml
 adapter:
   adapter_type: custom_ollama
   pip_packages: ["ollama", "aiohttp"]

@@ -201,7 +201,7 @@ uv pip install -e .
 5. Configure Llama Stack to use external providers:

 ```yaml
-external_providers_dir: /etc/llama-stack/providers.d/
+external_providers_dir: ~/.llama/providers.d/
 ```

 The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
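One practical note, added by the editor rather than taken from the diff: a home-relative default such as `~/.llama/providers.d` resolves per user, whereas `/etc/llama-stack/providers.d` required root-owned system configuration. A minimal sketch of how such a path is typically expanded and scanned in Python (hypothetical helper, not llama-stack code):

```python
from pathlib import Path

def list_provider_specs(providers_dir: str = "~/.llama/providers.d") -> list[Path]:
    """Expand the user's home directory and collect provider spec YAML files."""
    root = Path(providers_dir).expanduser()
    if not root.exists():
        return []
    # e.g. remote/inference/custom_ollama.yaml under the providers directory
    return sorted(root.glob("*/*/*.yaml"))

print(list_provider_specs())
```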
install.sh (61 lines changed)
@@ -38,6 +38,67 @@ wait_for_service() {
   return 0
 }

+usage() {
+    cat << EOF
+📚 Llama-Stack Deployment Script
+
+Description:
+  This script sets up and deploys Llama-Stack with Ollama integration in containers.
+  It handles both Docker and Podman runtimes and includes automatic platform detection.
+
+Usage:
+  $(basename "$0") [OPTIONS]
+
+Options:
+  -p, --port PORT         Server port for Llama-Stack (default: ${PORT})
+  -o, --ollama-port PORT  Ollama service port (default: ${OLLAMA_PORT})
+  -m, --model MODEL       Model alias to use (default: ${MODEL_ALIAS})
+  -i, --image IMAGE       Server image (default: ${SERVER_IMAGE})
+  -t, --timeout SECONDS   Service wait timeout in seconds (default: ${WAIT_TIMEOUT})
+  -h, --help              Show this help message
+
+For more information:
+  Documentation: https://llama-stack.readthedocs.io/
+  GitHub: https://github.com/meta-llama/llama-stack
+
+Report issues:
+  https://github.com/meta-llama/llama-stack/issues
+EOF
+}
+
+# Parse command line arguments
+while [[ $# -gt 0 ]]; do
+    case $1 in
+    -h|--help)
+        usage
+        exit 0
+        ;;
+    -p|--port)
+        PORT="$2"
+        shift 2
+        ;;
+    -o|--ollama-port)
+        OLLAMA_PORT="$2"
+        shift 2
+        ;;
+    -m|--model)
+        MODEL_ALIAS="$2"
+        shift 2
+        ;;
+    -i|--image)
+        SERVER_IMAGE="$2"
+        shift 2
+        ;;
+    -t|--timeout)
+        WAIT_TIMEOUT="$2"
+        shift 2
+        ;;
+    *)
+        die "Unknown option: $1"
+        ;;
+    esac
+done
+
 if command -v docker &> /dev/null; then
     ENGINE="docker"
 elif command -v podman &> /dev/null; then
@@ -413,7 +413,7 @@ class Agents(Protocol):
         :param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
         :param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
         :returns: If stream=False, returns a Turn object.
-                  If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
+                  If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
         """
         ...

@@ -509,6 +509,7 @@ class Agents(Protocol):
         :param session_id: The ID of the session to get.
         :param agent_id: The ID of the agent to get the session for.
         :param turn_ids: (Optional) List of turn IDs to filter the session by.
+        :returns: A Session.
         """
         ...

@@ -606,5 +607,6 @@ class Agents(Protocol):
         :param input: Input message(s) to create the response.
         :param model: The underlying LLM used for completions.
         :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
+        :returns: An OpenAIResponseObject.
         """
         ...

@@ -38,7 +38,17 @@ class BatchInference(Protocol):
         sampling_params: SamplingParams | None = None,
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
-    ) -> Job: ...
+    ) -> Job:
+        """Generate completions for a batch of content.
+
+        :param model: The model to use for the completion.
+        :param content_batch: The content to complete.
+        :param sampling_params: The sampling parameters to use for the completion.
+        :param response_format: The response format to use for the completion.
+        :param logprobs: The logprobs to use for the completion.
+        :returns: A job for the completion.
+        """
+        ...

     @webmethod(route="/batch-inference/chat-completion", method="POST")
     async def chat_completion(

@@ -52,4 +62,17 @@ class BatchInference(Protocol):
         tool_prompt_format: ToolPromptFormat | None = None,
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
-    ) -> Job: ...
+    ) -> Job:
+        """Generate chat completions for a batch of messages.
+
+        :param model: The model to use for the chat completion.
+        :param messages_batch: The messages to complete.
+        :param sampling_params: The sampling parameters to use for the completion.
+        :param tools: The tools to use for the chat completion.
+        :param tool_choice: The tool choice to use for the chat completion.
+        :param tool_prompt_format: The tool prompt format to use for the chat completion.
+        :param response_format: The response format to use for the chat completion.
+        :param logprobs: The logprobs to use for the chat completion.
+        :returns: A job for the chat completion.
+        """
+        ...

@@ -46,13 +46,24 @@ class ListBenchmarksResponse(BaseModel):
 @runtime_checkable
 class Benchmarks(Protocol):
     @webmethod(route="/eval/benchmarks", method="GET")
-    async def list_benchmarks(self) -> ListBenchmarksResponse: ...
+    async def list_benchmarks(self) -> ListBenchmarksResponse:
+        """List all benchmarks.
+
+        :returns: A ListBenchmarksResponse.
+        """
+        ...

     @webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
     async def get_benchmark(
         self,
         benchmark_id: str,
-    ) -> Benchmark: ...
+    ) -> Benchmark:
+        """Get a benchmark by its ID.
+
+        :param benchmark_id: The ID of the benchmark to get.
+        :returns: A Benchmark.
+        """
+        ...

     @webmethod(route="/eval/benchmarks", method="POST")
     async def register_benchmark(

@@ -63,4 +74,14 @@ class Benchmarks(Protocol):
         provider_benchmark_id: str | None = None,
         provider_id: str | None = None,
         metadata: dict[str, Any] | None = None,
-    ) -> None: ...
+    ) -> None:
+        """Register a benchmark.
+
+        :param benchmark_id: The ID of the benchmark to register.
+        :param dataset_id: The ID of the dataset to use for the benchmark.
+        :param scoring_functions: The scoring functions to use for the benchmark.
+        :param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
+        :param provider_id: The ID of the provider to use for the benchmark.
+        :param metadata: The metadata to use for the benchmark.
+        """
+        ...

@@ -34,14 +34,21 @@ class DatasetIO(Protocol):
         - limit: Number of items to return. If None or -1, returns all items.

         The response includes:
-        - data: List of items for the current page
-        - has_more: Whether there are more items available after this set
+        - data: List of items for the current page.
+        - has_more: Whether there are more items available after this set.
+
+        :param dataset_id: The ID of the dataset to get the rows from.
+        :param start_index: Index into dataset for the first row to get. Get all rows if None.
+        :param limit: The number of rows to get.
+        :returns: A PaginatedResponse.
         """
         ...

     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
+        """Append rows to a dataset.
+
+        :param dataset_id: The ID of the dataset to append the rows to.
+        :param rows: The rows to append to the dataset.
+        """
+        ...
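The pagination contract spelled out in the DatasetIO docstring above (start_index/limit in, data/has_more out) can be consumed with a simple loop. The sketch below is an editor illustration; the `fetch_rows` callable is a hypothetical stand-in for the dataset IO client and is not code from this commit.

```python
from typing import Any, Callable

def read_all_rows(
    fetch_rows: Callable[..., dict[str, Any]],
    dataset_id: str,
    page_size: int = 100,
) -> list[dict[str, Any]]:
    """Drain a paginated dataset by following the has_more flag."""
    rows: list[dict[str, Any]] = []
    start_index = 0
    while True:
        # fetch_rows is assumed to return {"data": [...], "has_more": bool}
        page = fetch_rows(dataset_id=dataset_id, start_index=start_index, limit=page_size)
        rows.extend(page["data"])
        if not page.get("has_more"):
            return rows
        start_index += len(page["data"])
```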
@@ -137,7 +137,8 @@ class Datasets(Protocol):
         """
         Register a new dataset.

-        :param purpose: The purpose of the dataset. One of
+        :param purpose: The purpose of the dataset.
+            One of:
             - "post-training/messages": The dataset contains a messages column with list of messages for post-training.
               {
                 "messages": [

@@ -188,8 +189,9 @@ class Datasets(Protocol):
                 ]
               }
         :param metadata: The metadata for the dataset.
-            - E.g. {"description": "My dataset"}
+            - E.g. {"description": "My dataset"}.
         :param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
+        :returns: A Dataset.
         """
         ...

@@ -197,13 +199,29 @@
     async def get_dataset(
         self,
         dataset_id: str,
-    ) -> Dataset: ...
+    ) -> Dataset:
+        """Get a dataset by its ID.
+
+        :param dataset_id: The ID of the dataset to get.
+        :returns: A Dataset.
+        """
+        ...

     @webmethod(route="/datasets", method="GET")
-    async def list_datasets(self) -> ListDatasetsResponse: ...
+    async def list_datasets(self) -> ListDatasetsResponse:
+        """List all datasets.
+
+        :returns: A ListDatasetsResponse.
+        """
+        ...

     @webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
     async def unregister_dataset(
         self,
         dataset_id: str,
-    ) -> None: ...
+    ) -> None:
+        """Unregister a dataset by its ID.
+
+        :param dataset_id: The ID of the dataset to unregister.
+        """
+        ...

@@ -93,7 +93,7 @@ class Eval(Protocol):

         :param benchmark_id: The ID of the benchmark to run the evaluation on.
         :param benchmark_config: The configuration for the benchmark.
-        :return: The job that was created to run the evaluation.
+        :returns: The job that was created to run the evaluation.
         """
         ...

@@ -111,7 +111,7 @@ class Eval(Protocol):
         :param input_rows: The rows to evaluate.
         :param scoring_functions: The scoring functions to use for the evaluation.
         :param benchmark_config: The configuration for the benchmark.
-        :return: EvaluateResponse object containing generations and scores
+        :returns: EvaluateResponse object containing generations and scores.
         """
         ...

@@ -121,7 +121,7 @@ class Eval(Protocol):

         :param benchmark_id: The ID of the benchmark to run the evaluation on.
         :param job_id: The ID of the job to get the status of.
-        :return: The status of the evaluationjob.
+        :returns: The status of the evaluation job.
         """
         ...

@@ -140,6 +140,6 @@ class Eval(Protocol):

         :param benchmark_id: The ID of the benchmark to run the evaluation on.
         :param job_id: The ID of the job to get the result of.
-        :return: The result of the job.
+        :returns: The result of the job.
         """
         ...
@@ -91,10 +91,11 @@ class Files(Protocol):
         """
         Create a new upload session for a file identified by a bucket and key.

-        :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
-        :param mime_type: MIME type of the file
-        :param size: File size in bytes
+        :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
+        :param mime_type: MIME type of the file.
+        :param size: File size in bytes.
+        :returns: A FileUploadResponse.
         """
         ...

@@ -107,7 +108,8 @@ class Files(Protocol):
         Upload file content to an existing upload session.
         On the server, request body will have the raw bytes that are uploaded.

-        :param upload_id: ID of the upload session
+        :param upload_id: ID of the upload session.
+        :returns: A FileResponse or None if the upload is not complete.
         """
         ...

@@ -117,9 +119,10 @@ class Files(Protocol):
         upload_id: str,
     ) -> FileUploadResponse:
         """
-        Returns information about an existsing upload session
+        Returns information about an existsing upload session.

-        :param upload_id: ID of the upload session
+        :param upload_id: ID of the upload session.
+        :returns: A FileUploadResponse.
         """
         ...

@@ -130,6 +133,9 @@ class Files(Protocol):
     ) -> ListBucketResponse:
         """
         List all buckets.
+
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :returns: A ListBucketResponse.
         """
         ...

@@ -141,7 +147,8 @@ class Files(Protocol):
         """
         List all files in a bucket.

-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :returns: A ListFileResponse.
         """
         ...

@@ -154,8 +161,9 @@ class Files(Protocol):
         """
         Get a file info identified by a bucket and key.

-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
+        :returns: A FileResponse.
         """
         ...

@@ -168,7 +176,7 @@ class Files(Protocol):
         """
         Delete a file identified by a bucket and key.

-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
         """
         ...
@@ -845,13 +845,13 @@ class Inference(Protocol):
         """Generate a completion for the given content using the specified model.

         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param content: The content to generate a completion for
-        :param sampling_params: (Optional) Parameters to control the sampling strategy
-        :param response_format: (Optional) Grammar specification for guided (structured) decoding
+        :param content: The content to generate a completion for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
         :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
         :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
         :returns: If stream=False, returns a CompletionResponse with the full completion.
-                  If stream=True, returns an SSE event stream of CompletionResponseStreamChunk
+                  If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
         """
         ...

@@ -864,6 +864,15 @@ class Inference(Protocol):
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
     ) -> BatchCompletionResponse:
+        """Generate completions for a batch of content using the specified model.
+
+        :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param content_batch: The content to generate completions for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
+        :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :returns: A BatchCompletionResponse with the full completions.
+        """
         raise NotImplementedError("Batch completion is not implemented")

     @webmethod(route="/inference/chat-completion", method="POST")

@@ -883,9 +892,9 @@ class Inference(Protocol):
         """Generate a chat completion for the given messages using the specified model.

         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param messages: List of messages in the conversation
-        :param sampling_params: Parameters to control the sampling strategy
-        :param tools: (Optional) List of tool definitions available to the model
+        :param messages: List of messages in the conversation.
+        :param sampling_params: Parameters to control the sampling strategy.
+        :param tools: (Optional) List of tool definitions available to the model.
         :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
             .. deprecated::
                Use tool_config instead.

@@ -902,7 +911,7 @@ class Inference(Protocol):
         :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
         :param tool_config: (Optional) Configuration for tool use.
         :returns: If stream=False, returns a ChatCompletionResponse with the full completion.
-                  If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
+                  If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
         """
         ...

@@ -917,6 +926,17 @@ class Inference(Protocol):
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
     ) -> BatchChatCompletionResponse:
+        """Generate chat completions for a batch of messages using the specified model.
+
+        :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param messages_batch: The messages to generate completions for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param tools: (Optional) List of tool definitions available to the model.
+        :param tool_config: (Optional) Configuration for tool use.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
+        :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :returns: A BatchChatCompletionResponse with the full completions.
+        """
         raise NotImplementedError("Batch chat completion is not implemented")

     @webmethod(route="/inference/embeddings", method="POST")

@@ -935,7 +955,7 @@ class Inference(Protocol):
         :param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
         :param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
         :param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
-        :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
+        :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
         """
         ...

@@ -967,22 +987,23 @@ class Inference(Protocol):
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.

         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param prompt: The prompt to generate a completion for
-        :param best_of: (Optional) The number of completions to generate
-        :param echo: (Optional) Whether to echo the prompt
-        :param frequency_penalty: (Optional) The penalty for repeated tokens
-        :param logit_bias: (Optional) The logit bias to use
-        :param logprobs: (Optional) The log probabilities to use
-        :param max_tokens: (Optional) The maximum number of tokens to generate
-        :param n: (Optional) The number of completions to generate
-        :param presence_penalty: (Optional) The penalty for repeated tokens
-        :param seed: (Optional) The seed to use
-        :param stop: (Optional) The stop tokens to use
-        :param stream: (Optional) Whether to stream the response
-        :param stream_options: (Optional) The stream options to use
-        :param temperature: (Optional) The temperature to use
-        :param top_p: (Optional) The top p to use
-        :param user: (Optional) The user to use
+        :param prompt: The prompt to generate a completion for.
+        :param best_of: (Optional) The number of completions to generate.
+        :param echo: (Optional) Whether to echo the prompt.
+        :param frequency_penalty: (Optional) The penalty for repeated tokens.
+        :param logit_bias: (Optional) The logit bias to use.
+        :param logprobs: (Optional) The log probabilities to use.
+        :param max_tokens: (Optional) The maximum number of tokens to generate.
+        :param n: (Optional) The number of completions to generate.
+        :param presence_penalty: (Optional) The penalty for repeated tokens.
+        :param seed: (Optional) The seed to use.
+        :param stop: (Optional) The stop tokens to use.
+        :param stream: (Optional) Whether to stream the response.
+        :param stream_options: (Optional) The stream options to use.
+        :param temperature: (Optional) The temperature to use.
+        :param top_p: (Optional) The top p to use.
+        :param user: (Optional) The user to use.
+        :returns: An OpenAICompletion.
         """
         ...

@@ -1016,27 +1037,28 @@ class Inference(Protocol):
         """Generate an OpenAI-compatible chat completion for the given messages using the specified model.

         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param messages: List of messages in the conversation
-        :param frequency_penalty: (Optional) The penalty for repeated tokens
-        :param function_call: (Optional) The function call to use
-        :param functions: (Optional) List of functions to use
-        :param logit_bias: (Optional) The logit bias to use
-        :param logprobs: (Optional) The log probabilities to use
-        :param max_completion_tokens: (Optional) The maximum number of tokens to generate
-        :param max_tokens: (Optional) The maximum number of tokens to generate
-        :param n: (Optional) The number of completions to generate
-        :param parallel_tool_calls: (Optional) Whether to parallelize tool calls
-        :param presence_penalty: (Optional) The penalty for repeated tokens
-        :param response_format: (Optional) The response format to use
-        :param seed: (Optional) The seed to use
-        :param stop: (Optional) The stop tokens to use
-        :param stream: (Optional) Whether to stream the response
-        :param stream_options: (Optional) The stream options to use
-        :param temperature: (Optional) The temperature to use
-        :param tool_choice: (Optional) The tool choice to use
-        :param tools: (Optional) The tools to use
-        :param top_logprobs: (Optional) The top log probabilities to use
-        :param top_p: (Optional) The top p to use
-        :param user: (Optional) The user to use
+        :param messages: List of messages in the conversation.
+        :param frequency_penalty: (Optional) The penalty for repeated tokens.
+        :param function_call: (Optional) The function call to use.
+        :param functions: (Optional) List of functions to use.
+        :param logit_bias: (Optional) The logit bias to use.
+        :param logprobs: (Optional) The log probabilities to use.
+        :param max_completion_tokens: (Optional) The maximum number of tokens to generate.
+        :param max_tokens: (Optional) The maximum number of tokens to generate.
+        :param n: (Optional) The number of completions to generate.
+        :param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
+        :param presence_penalty: (Optional) The penalty for repeated tokens.
+        :param response_format: (Optional) The response format to use.
+        :param seed: (Optional) The seed to use.
+        :param stop: (Optional) The stop tokens to use.
+        :param stream: (Optional) Whether to stream the response.
+        :param stream_options: (Optional) The stream options to use.
+        :param temperature: (Optional) The temperature to use.
+        :param tool_choice: (Optional) The tool choice to use.
+        :param tools: (Optional) The tools to use.
+        :param top_logprobs: (Optional) The top log probabilities to use.
+        :param top_p: (Optional) The top p to use.
+        :param user: (Optional) The user to use.
+        :returns: An OpenAIChatCompletion.
         """
         ...
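As an editor-added illustration (not from the commit), a call to the OpenAI-compatible chat completion documented above typically looks like the sketch below; `client` stands in for any object implementing the Inference protocol, and the model identifier is a placeholder.

```python
import asyncio

async def demo(client) -> None:
    # "client" is assumed to implement the Inference protocol shown above.
    completion = await client.openai_chat_completion(
        model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder model identifier
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        temperature=0.2,
        stream=False,
    )
    print(completion)

# asyncio.run(demo(my_inference_client))  # assuming a concrete client instance
```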
@@ -36,10 +36,25 @@ class ListRoutesResponse(BaseModel):
 @runtime_checkable
 class Inspect(Protocol):
     @webmethod(route="/inspect/routes", method="GET")
-    async def list_routes(self) -> ListRoutesResponse: ...
+    async def list_routes(self) -> ListRoutesResponse:
+        """List all routes.
+
+        :returns: A ListRoutesResponse.
+        """
+        ...

     @webmethod(route="/health", method="GET")
-    async def health(self) -> HealthInfo: ...
+    async def health(self) -> HealthInfo:
+        """Get the health of the service.
+
+        :returns: A HealthInfo.
+        """
+        ...

     @webmethod(route="/version", method="GET")
-    async def version(self) -> VersionInfo: ...
+    async def version(self) -> VersionInfo:
+        """Get the version of the service.
+
+        :returns: A VersionInfo.
+        """
+        ...

@@ -80,16 +80,32 @@ class OpenAIListModelsResponse(BaseModel):
 @trace_protocol
 class Models(Protocol):
     @webmethod(route="/models", method="GET")
-    async def list_models(self) -> ListModelsResponse: ...
+    async def list_models(self) -> ListModelsResponse:
+        """List all models.
+
+        :returns: A ListModelsResponse.
+        """
+        ...

     @webmethod(route="/openai/v1/models", method="GET")
-    async def openai_list_models(self) -> OpenAIListModelsResponse: ...
+    async def openai_list_models(self) -> OpenAIListModelsResponse:
+        """List models using the OpenAI API.
+
+        :returns: A OpenAIListModelsResponse.
+        """
+        ...

     @webmethod(route="/models/{model_id:path}", method="GET")
     async def get_model(
         self,
         model_id: str,
-    ) -> Model: ...
+    ) -> Model:
+        """Get a model by its identifier.
+
+        :param model_id: The identifier of the model to get.
+        :returns: A Model.
+        """
+        ...

     @webmethod(route="/models", method="POST")
     async def register_model(

@@ -99,10 +115,25 @@ class Models(Protocol):
         provider_id: str | None = None,
         metadata: dict[str, Any] | None = None,
         model_type: ModelType | None = None,
-    ) -> Model: ...
+    ) -> Model:
+        """Register a model.
+
+        :param model_id: The identifier of the model to register.
+        :param provider_model_id: The identifier of the model in the provider.
+        :param provider_id: The identifier of the provider.
+        :param metadata: Any additional metadata for this model.
+        :param model_type: The type of model to register.
+        :returns: A Model.
+        """
+        ...

     @webmethod(route="/models/{model_id:path}", method="DELETE")
     async def unregister_model(
         self,
         model_id: str,
-    ) -> None: ...
+    ) -> None:
+        """Unregister a model.
+
+        :param model_id: The identifier of the model to unregister.
+        """
+        ...
@@ -182,7 +182,19 @@ class PostTraining(Protocol):
         ),
         checkpoint_dir: str | None = None,
         algorithm_config: AlgorithmConfig | None = None,
-    ) -> PostTrainingJob: ...
+    ) -> PostTrainingJob:
+        """Run supervised fine-tuning of a model.
+
+        :param job_uuid: The UUID of the job to create.
+        :param training_config: The training configuration.
+        :param hyperparam_search_config: The hyperparam search configuration.
+        :param logger_config: The logger configuration.
+        :param model: The model to fine-tune.
+        :param checkpoint_dir: The directory to save checkpoint(s) to.
+        :param algorithm_config: The algorithm configuration.
+        :returns: A PostTrainingJob.
+        """
+        ...

     @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(

@@ -193,16 +205,49 @@ class PostTraining(Protocol):
         training_config: TrainingConfig,
         hyperparam_search_config: dict[str, Any],
         logger_config: dict[str, Any],
-    ) -> PostTrainingJob: ...
+    ) -> PostTrainingJob:
+        """Run preference optimization of a model.
+
+        :param job_uuid: The UUID of the job to create.
+        :param finetuned_model: The model to fine-tune.
+        :param algorithm_config: The algorithm configuration.
+        :param training_config: The training configuration.
+        :param hyperparam_search_config: The hyperparam search configuration.
+        :param logger_config: The logger configuration.
+        :returns: A PostTrainingJob.
+        """
+        ...

     @webmethod(route="/post-training/jobs", method="GET")
-    async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
+    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
+        """Get all training jobs.
+
+        :returns: A ListPostTrainingJobsResponse.
+        """
+        ...

     @webmethod(route="/post-training/job/status", method="GET")
-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
+    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
+        """Get the status of a training job.
+
+        :param job_uuid: The UUID of the job to get the status of.
+        :returns: A PostTrainingJobStatusResponse.
+        """
+        ...

     @webmethod(route="/post-training/job/cancel", method="POST")
-    async def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        """Cancel a training job.
+
+        :param job_uuid: The UUID of the job to cancel.
+        """
+        ...

     @webmethod(route="/post-training/job/artifacts", method="GET")
-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
+    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
+        """Get the artifacts of a training job.
+
+        :param job_uuid: The UUID of the job to get the artifacts of.
+        :returns: A PostTrainingJobArtifactsResponse.
+        """
+        ...

@@ -32,7 +32,18 @@ class Providers(Protocol):
     """

     @webmethod(route="/providers", method="GET")
-    async def list_providers(self) -> ListProvidersResponse: ...
+    async def list_providers(self) -> ListProvidersResponse:
+        """List all available providers.
+
+        :returns: A ListProvidersResponse containing information about all providers.
+        """
+        ...

     @webmethod(route="/providers/{provider_id}", method="GET")
-    async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
+    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
+        """Get detailed information about a specific provider.
+
+        :param provider_id: The ID of the provider to inspect.
+        :returns: A ProviderInfo object containing the provider's details.
+        """
+        ...

@@ -54,4 +54,12 @@ class Safety(Protocol):
         shield_id: str,
         messages: list[Message],
         params: dict[str, Any],
-    ) -> RunShieldResponse: ...
+    ) -> RunShieldResponse:
+        """Run a shield.
+
+        :param shield_id: The identifier of the shield to run.
+        :param messages: The messages to run the shield on.
+        :param params: The parameters of the shield.
+        :returns: A RunShieldResponse.
+        """
+        ...
@@ -61,7 +61,15 @@ class Scoring(Protocol):
         dataset_id: str,
         scoring_functions: dict[str, ScoringFnParams | None],
         save_results_dataset: bool = False,
-    ) -> ScoreBatchResponse: ...
+    ) -> ScoreBatchResponse:
+        """Score a batch of rows.
+
+        :param dataset_id: The ID of the dataset to score.
+        :param scoring_functions: The scoring functions to use for the scoring.
+        :param save_results_dataset: Whether to save the results to a dataset.
+        :returns: A ScoreBatchResponse.
+        """
+        ...

     @webmethod(route="/scoring/score", method="POST")
     async def score(

@@ -73,6 +81,6 @@ class Scoring(Protocol):

         :param input_rows: The rows to score.
         :param scoring_functions: The scoring functions to use for the scoring.
-        :return: ScoreResponse object containing rows and aggregated results
+        :returns: A ScoreResponse object containing rows and aggregated results.
         """
         ...

@@ -134,10 +134,21 @@ class ListScoringFunctionsResponse(BaseModel):
 @runtime_checkable
 class ScoringFunctions(Protocol):
     @webmethod(route="/scoring-functions", method="GET")
-    async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
+    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
+        """List all scoring functions.
+
+        :returns: A ListScoringFunctionsResponse.
+        """
+        ...

     @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
-    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
+    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
+        """Get a scoring function by its ID.
+
+        :param scoring_fn_id: The ID of the scoring function to get.
+        :returns: A ScoringFn.
+        """
+        ...

     @webmethod(route="/scoring-functions", method="POST")
     async def register_scoring_function(

@@ -148,4 +159,14 @@ class ScoringFunctions(Protocol):
         provider_scoring_fn_id: str | None = None,
         provider_id: str | None = None,
         params: ScoringFnParams | None = None,
-    ) -> None: ...
+    ) -> None:
+        """Register a scoring function.
+
+        :param scoring_fn_id: The ID of the scoring function to register.
+        :param description: The description of the scoring function.
+        :param return_type: The return type of the scoring function.
+        :param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
+        :param provider_id: The ID of the provider to use for the scoring function.
+        :param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
+        """
+        ...

@@ -46,10 +46,21 @@ class ListShieldsResponse(BaseModel):
 @trace_protocol
 class Shields(Protocol):
     @webmethod(route="/shields", method="GET")
-    async def list_shields(self) -> ListShieldsResponse: ...
+    async def list_shields(self) -> ListShieldsResponse:
+        """List all shields.
+
+        :returns: A ListShieldsResponse.
+        """
+        ...

     @webmethod(route="/shields/{identifier:path}", method="GET")
-    async def get_shield(self, identifier: str) -> Shield: ...
+    async def get_shield(self, identifier: str) -> Shield:
+        """Get a shield by its identifier.
+
+        :param identifier: The identifier of the shield to get.
+        :returns: A Shield.
+        """
+        ...

     @webmethod(route="/shields", method="POST")
     async def register_shield(

@@ -58,4 +69,13 @@ class Shields(Protocol):
         provider_shield_id: str | None = None,
         provider_id: str | None = None,
         params: dict[str, Any] | None = None,
-    ) -> Shield: ...
+    ) -> Shield:
+        """Register a shield.
+
+        :param shield_id: The identifier of the shield to register.
+        :param provider_shield_id: The identifier of the shield in the provider.
+        :param provider_id: The identifier of the provider.
+        :param params: The parameters of the shield.
+        :returns: A Shield.
+        """
+        ...
@@ -247,7 +247,17 @@ class QueryMetricsResponse(BaseModel):
 @runtime_checkable
 class Telemetry(Protocol):
     @webmethod(route="/telemetry/events", method="POST")
-    async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
+    async def log_event(
+        self,
+        event: Event,
+        ttl_seconds: int = DEFAULT_TTL_DAYS * 86400,
+    ) -> None:
+        """Log an event.
+
+        :param event: The event to log.
+        :param ttl_seconds: The time to live of the event.
+        """
+        ...

     @webmethod(route="/telemetry/traces", method="POST")
     async def query_traces(

@@ -256,13 +266,35 @@ class Telemetry(Protocol):
         limit: int | None = 100,
         offset: int | None = 0,
         order_by: list[str] | None = None,
-    ) -> QueryTracesResponse: ...
+    ) -> QueryTracesResponse:
+        """Query traces.
+
+        :param attribute_filters: The attribute filters to apply to the traces.
+        :param limit: The limit of traces to return.
+        :param offset: The offset of the traces to return.
+        :param order_by: The order by of the traces to return.
+        :returns: A QueryTracesResponse.
+        """
+        ...

     @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
-    async def get_trace(self, trace_id: str) -> Trace: ...
+    async def get_trace(self, trace_id: str) -> Trace:
+        """Get a trace by its ID.
+
+        :param trace_id: The ID of the trace to get.
+        :returns: A Trace.
+        """
+        ...

     @webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
-    async def get_span(self, trace_id: str, span_id: str) -> Span: ...
+    async def get_span(self, trace_id: str, span_id: str) -> Span:
+        """Get a span by its ID.
+
+        :param trace_id: The ID of the trace to get the span from.
+        :param span_id: The ID of the span to get.
+        :returns: A Span.
+        """
+        ...

     @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
     async def get_span_tree(

@@ -270,7 +302,15 @@ class Telemetry(Protocol):
         span_id: str,
         attributes_to_return: list[str] | None = None,
         max_depth: int | None = None,
-    ) -> QuerySpanTreeResponse: ...
+    ) -> QuerySpanTreeResponse:
+        """Get a span tree by its ID.
+
+        :param span_id: The ID of the span to get the tree from.
+        :param attributes_to_return: The attributes to return in the tree.
+        :param max_depth: The maximum depth of the tree.
+        :returns: A QuerySpanTreeResponse.
+        """
+        ...

     @webmethod(route="/telemetry/spans", method="POST")
     async def query_spans(

@@ -278,7 +318,15 @@ class Telemetry(Protocol):
         attribute_filters: list[QueryCondition],
         attributes_to_return: list[str],
         max_depth: int | None = None,
-    ) -> QuerySpansResponse: ...
+    ) -> QuerySpansResponse:
+        """Query spans.
+
+        :param attribute_filters: The attribute filters to apply to the spans.
+        :param attributes_to_return: The attributes to return in the spans.
+        :param max_depth: The maximum depth of the tree.
+        :returns: A QuerySpansResponse.
+        """
+        ...

     @webmethod(route="/telemetry/spans/export", method="POST")
     async def save_spans_to_dataset(

@@ -287,7 +335,15 @@ class Telemetry(Protocol):
         attributes_to_save: list[str],
         dataset_id: str,
         max_depth: int | None = None,
-    ) -> None: ...
+    ) -> None:
+        """Save spans to a dataset.
+
+        :param attribute_filters: The attribute filters to apply to the spans.
+        :param attributes_to_save: The attributes to save to the dataset.
+        :param dataset_id: The ID of the dataset to save the spans to.
+        :param max_depth: The maximum depth of the tree.
+        """
+        ...

     @webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
     async def query_metrics(

@@ -298,4 +354,15 @@ class Telemetry(Protocol):
         granularity: str | None = "1d",
         query_type: MetricQueryType = MetricQueryType.RANGE,
         label_matchers: list[MetricLabelMatcher] | None = None,
-    ) -> QueryMetricsResponse: ...
+    ) -> QueryMetricsResponse:
+        """Query metrics.
+
+        :param metric_name: The name of the metric to query.
+        :param start_time: The start time of the metric to query.
+        :param end_time: The end time of the metric to query.
+        :param granularity: The granularity of the metric to query.
+        :param query_type: The type of query to perform.
+        :param label_matchers: The label matchers to apply to the metric.
+        :returns: A QueryMetricsResponse.
+        """
+        ...
@@ -103,37 +103,65 @@ class ToolGroups(Protocol):
         mcp_endpoint: URL | None = None,
         args: dict[str, Any] | None = None,
     ) -> None:
-        """Register a tool group"""
+        """Register a tool group.
+
+        :param toolgroup_id: The ID of the tool group to register.
+        :param provider_id: The ID of the provider to use for the tool group.
+        :param mcp_endpoint: The MCP endpoint to use for the tool group.
+        :param args: A dictionary of arguments to pass to the tool group.
+        """
         ...

     @webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
     async def get_tool_group(
         self,
         toolgroup_id: str,
-    ) -> ToolGroup: ...
+    ) -> ToolGroup:
+        """Get a tool group by its ID.
+
+        :param toolgroup_id: The ID of the tool group to get.
+        :returns: A ToolGroup.
+        """
+        ...

     @webmethod(route="/toolgroups", method="GET")
     async def list_tool_groups(self) -> ListToolGroupsResponse:
-        """List tool groups with optional provider"""
+        """List tool groups with optional provider.
+
+        :returns: A ListToolGroupsResponse.
+        """
         ...

     @webmethod(route="/tools", method="GET")
     async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
-        """List tools with optional tool group"""
+        """List tools with optional tool group.
+
+        :param toolgroup_id: The ID of the tool group to list tools for.
+        :returns: A ListToolsResponse.
+        """
         ...

     @webmethod(route="/tools/{tool_name:path}", method="GET")
     async def get_tool(
         self,
         tool_name: str,
-    ) -> Tool: ...
+    ) -> Tool:
+        """Get a tool by its name.
+
+        :param tool_name: The name of the tool to get.
+        :returns: A Tool.
+        """
+        ...

     @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
     async def unregister_toolgroup(
         self,
         toolgroup_id: str,
     ) -> None:
-        """Unregister a tool group"""
+        """Unregister a tool group.
+
+        :param toolgroup_id: The ID of the tool group to unregister.
+        """
        ...

@@ -152,9 +180,21 @@ class ToolRuntime(Protocol):
     @webmethod(route="/tool-runtime/list-tools", method="GET")
     async def list_runtime_tools(
         self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
-    ) -> ListToolDefsResponse: ...
+    ) -> ListToolDefsResponse:
+        """List all tools in the runtime.
+
+        :param tool_group_id: The ID of the tool group to list tools for.
+        :param mcp_endpoint: The MCP endpoint to use for the tool group.
+        :returns: A ListToolDefsResponse.
+        """
+        ...

     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
-        """Run a tool with the given arguments"""
+        """Run a tool with the given arguments.
+
+        :param tool_name: The name of the tool to invoke.
+        :param kwargs: A dictionary of arguments to pass to the tool.
+        :returns: A ToolInvocationResult.
+        """
         ...
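As an editor-added illustration (not part of the diff), the `invoke_tool` signature documented above is called with a tool name and a plain dict of arguments; the tool name and arguments below are placeholders.

```python
import asyncio

async def run_search(runtime) -> None:
    # "runtime" is any object implementing the ToolRuntime protocol shown above;
    # "web_search" and its arguments are placeholders for illustration.
    result = await runtime.invoke_tool(
        tool_name="web_search",
        kwargs={"query": "llama stack external providers"},
    )
    print(result)

# asyncio.run(run_search(my_runtime))  # assuming a concrete runtime instance
```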
@@ -44,13 +44,24 @@ class ListVectorDBsResponse(BaseModel):
 @trace_protocol
 class VectorDBs(Protocol):
     @webmethod(route="/vector-dbs", method="GET")
-    async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
+    async def list_vector_dbs(self) -> ListVectorDBsResponse:
+        """List all vector databases.
+
+        :returns: A ListVectorDBsResponse.
+        """
+        ...

     @webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
     async def get_vector_db(
         self,
         vector_db_id: str,
-    ) -> VectorDB: ...
+    ) -> VectorDB:
+        """Get a vector database by its identifier.
+
+        :param vector_db_id: The identifier of the vector database to get.
+        :returns: A VectorDB.
+        """
+        ...

     @webmethod(route="/vector-dbs", method="POST")
     async def register_vector_db(

@@ -60,7 +71,22 @@ class VectorDBs(Protocol):
         embedding_dimension: int | None = 384,
         provider_id: str | None = None,
         provider_vector_db_id: str | None = None,
-    ) -> VectorDB: ...
+    ) -> VectorDB:
+        """Register a vector database.
+
+        :param vector_db_id: The identifier of the vector database to register.
+        :param embedding_model: The embedding model to use.
+        :param embedding_dimension: The dimension of the embedding model.
+        :param provider_id: The identifier of the provider.
+        :param provider_vector_db_id: The identifier of the vector database in the provider.
+        :returns: A VectorDB.
+        """
+        ...

     @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
-    async def unregister_vector_db(self, vector_db_id: str) -> None: ...
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        """Unregister a vector database.
+
+        :param vector_db_id: The identifier of the vector database to unregister.
+        """
+        ...

@@ -46,7 +46,14 @@ class VectorIO(Protocol):
         vector_db_id: str,
         chunks: list[Chunk],
         ttl_seconds: int | None = None,
-    ) -> None: ...
+    ) -> None:
+        """Insert chunks into a vector database.
+
+        :param vector_db_id: The identifier of the vector database to insert the chunks into.
+        :param chunks: The chunks to insert.
+        :param ttl_seconds: The time to live of the chunks.
+        """
+        ...

     @webmethod(route="/vector-io/query", method="POST")
     async def query_chunks(

@@ -54,4 +61,12 @@ class VectorIO(Protocol):
         vector_db_id: str,
         query: InterleavedContent,
         params: dict[str, Any] | None = None,
-    ) -> QueryChunksResponse: ...
+    ) -> QueryChunksResponse:
+        """Query chunks from a vector database.
+
+        :param vector_db_id: The identifier of the vector database to query.
+        :param query: The query to search for.
+        :param params: The parameters of the query.
+        :returns: A QueryChunksResponse.
+        """
+        ...
@@ -36,7 +36,8 @@ from llama_stack.distribution.datatypes import (
 )
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import InvalidProviderError
-from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
+from llama_stack.distribution.stack import replace_env_vars
+from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.distribution.utils.exec import formulate_run_args, run_command
 from llama_stack.distribution.utils.image_types import LlamaStackImageType

@@ -202,7 +203,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
     else:
         with open(args.config) as f:
             try:
-                build_config = BuildConfig(**yaml.safe_load(f))
+                contents = yaml.safe_load(f)
+                contents = replace_env_vars(contents)
+                build_config = BuildConfig(**contents)
             except Exception as e:
                 cprint(
                     f"Could not parse config file {args.config}: {e}",
||||
|
@ -248,6 +251,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
run_config = Path(run_config)
|
||||
config_dict = yaml.safe_load(run_config.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if not os.path.exists(str(config.external_providers_dir)):
|
||||
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
||||
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
|
||||
run_command(run_args)
|
||||
|
@ -267,7 +272,9 @@ def _generate_run_config(
|
|||
image_name=image_name,
|
||||
apis=apis,
|
||||
providers={},
|
||||
external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None,
|
||||
external_providers_dir=build_config.external_providers_dir
|
||||
if build_config.external_providers_dir
|
||||
else EXTERNAL_PROVIDERS_DIR,
|
||||
)
|
||||
# build providers dict
|
||||
provider_registry = get_provider_registry(build_config)
|
||||
|
|
|
@@ -33,7 +33,8 @@ class StackRun(Subcommand):
self.parser.add_argument(
"config",
type=str,
help="Path to config file to use for the run",
nargs="?", # Make it optional
help="Path to config file to use for the run. Required for venv and conda environments.",
)
self.parser.add_argument(
"--port",

@@ -82,44 +83,55 @@ class StackRun(Subcommand):
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_command

config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None

if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config

if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")

if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)

if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)

logger.info(f"Using run configuration: {config_file}")

try:
config_dict = yaml.safe_load(config_file.read_text())
except yaml.parser.ParserError as e:
self.parser.error(f"failed to load config file '{config_file}':\n {e}")

try:
config = parse_and_maybe_upgrade_config(config_dict)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")

image_type, image_name = self._get_image_type_and_name(args)

# Check if config is required based on image type
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not args.config:
self.parser.error("Config file is required for venv and conda environments")

if args.config:
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None

if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config

if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")

if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)

if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)

logger.info(f"Using run configuration: {config_file}")

try:
config_dict = yaml.safe_load(config_file.read_text())
except yaml.parser.ParserError as e:
self.parser.error(f"failed to load config file '{config_file}':\n {e}")

try:
config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(str(config.external_providers_dir)):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
else:
config = None
config_file = None
template_name = None

# If neither image type nor image name is provided, assume the server should be run directly
# using the current environment packages.
if not image_type and not image_name:

@@ -141,7 +153,10 @@ class StackRun(Subcommand):
else:
run_args = formulate_run_args(image_type, image_name, config, template_name)

run_args.extend([str(config_file), str(args.port)])
run_args.extend([str(args.port)])

if config_file:
run_args.extend(["--config", str(config_file)])

if args.env:
for env_var in args.env:
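Taken together with the start-up script changes further below, the practical effect is that a run config is now optional for container runs but still mandatory for venv and conda. A hedged sketch of the resulting CLI behavior follows; the distribution names, paths, and environment values are illustrative, not taken from this diff.

```bash
# venv/conda: a run config (template name, build name, or path) is still required.
llama stack run ~/.llama/distributions/llamastack-my-stack/my-stack-run.yaml --image-type venv

# container: the config may be omitted; the server falls back to the run.yaml
# that was baked into the image at build time.
llama stack run --image-type container --image-name my-stack --port 8321

# --env continues to forward environment variables to the server process.
llama stack run my-template --image-type conda --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```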
@@ -154,6 +154,12 @@ get_python_cmd() {
fi
}

# Add other required item commands generic to all containers
add_to_container << EOF
# Allows running as non-root user
RUN mkdir -p /.llama/providers.d /.cache
EOF

if [ -n "$run_config" ]; then
# Copy the run config to the build context since it's an absolute path
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"

@@ -166,17 +172,19 @@ EOF
# and update the configuration to reference the new container path
python_cmd=$(get_python_cmd)
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
if [ -n "$external_providers_dir" ]; then
external_providers_dir=$(eval echo "$external_providers_dir")
if [ -n "$external_providers_dir" ] && [ -d "$external_providers_dir" ]; then
echo "Copying external providers directory: $external_providers_dir"
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
add_to_container << EOF
COPY $external_providers_dir /app/providers.d
COPY providers.d /.llama/providers.d
EOF
# Edit the run.yaml file to change the external_providers_dir to /app/providers.d
# Edit the run.yaml file to change the external_providers_dir to /.llama/providers.d
if [ "$(uname)" = "Darwin" ]; then
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
else
sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
sed -i 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
fi
fi
fi

@@ -255,9 +263,6 @@ fi
# Add other require item commands genearic to all containers
add_to_container << EOF

# Allows running as non-root user
RUN mkdir -p /.llama /.cache

RUN chmod -R g+rw /app /.llama /.cache
EOF
@@ -17,6 +17,7 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.datatypes import Api, ProviderSpec

@@ -170,4 +171,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION

if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR

return StackRunConfig(**config_dict)
@@ -5,9 +5,10 @@
# the root directory of this source tree.

from enum import Enum
from pathlib import Path
from typing import Annotated, Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO

@@ -312,11 +313,20 @@ a default SQLite store will be used.""",
description="Configuration for the HTTP(S) server",
)

external_providers_dir: str | None = Field(
external_providers_dir: Path | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)

@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v


class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
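The validator above normalizes string values of `external_providers_dir` into `Path` objects before the config is used. A standalone sketch of the same pattern follows, using a toy model rather than the real `StackRunConfig`:

```python
from pathlib import Path

from pydantic import BaseModel, field_validator


class ToyRunConfig(BaseModel):
    # Mirrors the field and validator added in the diff above; the model itself is illustrative.
    external_providers_dir: Path | None = None

    @field_validator("external_providers_dir")
    @classmethod
    def validate_external_providers_dir(cls, v):
        if v is None:
            return None
        if isinstance(v, str):
            return Path(v)
        return v


print(ToyRunConfig(external_providers_dir="~/.llama/providers.d").external_providers_dir)
# PosixPath('~/.llama/providers.d') -- the "~" is expanded later,
# e.g. via os.path.expanduser() in get_provider_registry() (see the next hunk).
```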
@@ -145,7 +145,7 @@ def get_provider_registry(
# Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
external_providers_dir = os.path.abspath(config.external_providers_dir)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
@@ -29,7 +29,7 @@ error_handler() {
trap 'error_handler ${LINENO}' ERR

if [ $# -lt 3 ]; then
echo "Usage: $0 <env_type> <env_path_or_name> <yaml_config> <port> <script_args...>"
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
exit 1
fi

@@ -40,37 +40,51 @@ env_path_or_name="$1"
container_image="localhost/$env_path_or_name"
shift

yaml_config="$1"
shift

port="$1"
shift

SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"

# Initialize env_vars as an string
# Initialize variables
yaml_config=""
env_vars=""
other_args=""
# Process environment variables from --env arguments

# Process remaining arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--env)

if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
--config|--yaml-config)
if [[ -n "$2" ]]; then
yaml_config="$2"
shift 2
else
echo -e "${RED}Error: $1 requires a CONFIG argument${NC}" >&2
exit 1
fi
;;
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done

# Check if yaml_config is required based on env_type
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]] && [ -z "$yaml_config" ]; then
echo -e "${RED}Error: --config is required for venv and conda environments${NC}" >&2
exit 1
fi

PYTHON_BINARY="python"
case "$env_type" in
"venv")

@@ -106,8 +120,14 @@ esac
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x

if [ -n "$yaml_config" ]; then
yaml_config_arg="--yaml-config $yaml_config"
else
yaml_config_arg=""
fi

$PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
$yaml_config_arg \
--port "$port" \
$env_vars \
$other_args

@@ -149,15 +169,26 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version')
fi

$CONTAINER_BINARY run $CONTAINER_OPTS -it \
# Build the command with optional yaml config
cmd="$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \
$env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \
--env LLAMA_STACK_PORT=$port \
--entrypoint python \
$container_image:$version_tag \
-m llama_stack.distribution.server.server \
--yaml-config /app/config.yaml \
$other_args
-m llama_stack.distribution.server.server"

# Add yaml config if provided, otherwise use default
if [ -n "$yaml_config" ]; then
cmd="$cmd -v $yaml_config:/app/run.yaml --yaml-config /app/run.yaml"
else
cmd="$cmd --yaml-config /app/run.yaml"
fi

# Add any other args
cmd="$cmd $other_args"

# Execute the command
eval $cmd
fi
@@ -14,3 +14,5 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"

RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"

EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"
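A rough illustration of where the new constant resolves, assuming `LLAMA_STACK_CONFIG_DIR` keeps its usual default of `~/.llama` when the corresponding environment variable is unset (the default itself is an assumption, consistent with the `~/.llama/providers.d` paths used later in this change):

```python
# Illustrative resolution of EXTERNAL_PROVIDERS_DIR; not the actual config_dirs.py code.
import os
from pathlib import Path

LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama")))
EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"

print(EXTERNAL_PROVIDERS_DIR)  # e.g. /home/user/.llama/providers.d
```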
@@ -22,8 +22,10 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType

def formulate_run_args(image_type, image_name, config, template_name) -> list:
env_name = ""
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
env_name = f"distribution-{template_name}" if template_name else config.container_image
if image_type == LlamaStackImageType.CONTAINER.value:
env_name = (
f"distribution-{template_name}" if template_name else (config.container_image if config else image_name)
)
elif image_type == LlamaStackImageType.CONDA.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
env_name = image_name or current_conda_env
@@ -162,7 +162,7 @@ def _process_vllm_chat_completion_end_of_stream(
finish_reason: str | None,
last_chunk_content: str | None,
current_event_type: ChatCompletionResponseEventType,
tool_call_buf: UnparseableToolCall,
tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
) -> list[OpenAIChatCompletionChunk]:
chunks = []

@@ -171,9 +171,8 @@ def _process_vllm_chat_completion_end_of_stream(
else:
stop_reason = StopReason.end_of_message

if tool_call_buf.tool_name:
# at least one tool call request is received

tool_call_bufs = tool_call_bufs or {}
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
args_str = tool_call_buf.arguments or "{}"
try:
args = json.loads(args_str)

@@ -225,8 +224,14 @@ def _process_vllm_chat_completion_end_of_stream(
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
event_type = ChatCompletionResponseEventType.start
tool_call_buf = UnparseableToolCall()
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
event_type = ChatCompletionResponseEventType.progress
tool_call_bufs: dict[str, UnparseableToolCall] = {}
end_of_stream_processed = False

async for chunk in stream:

@@ -235,17 +240,22 @@ async def _process_vllm_chat_completion_stream_response(
return
choice = chunk.choices[0]
if choice.delta.tool_calls:
tool_call = convert_tool_call(choice.delta.tool_calls[0])
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
# TODO: remove str() when dict type for 'arguments' is no longer allowed
tool_call_buf.arguments += str(tool_call.arguments)
for delta_tool_call in choice.delta.tool_calls:
tool_call = convert_tool_call(delta_tool_call)
if delta_tool_call.index not in tool_call_bufs:
tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
tool_call_buf = tool_call_bufs[delta_tool_call.index]
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
tool_call_buf.arguments += (
tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
)
if choice.finish_reason:
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=choice.finish_reason,
last_chunk_content=choice.delta.content,
current_event_type=event_type,
tool_call_buf=tool_call_buf,
tool_call_bufs=tool_call_bufs,
)
for c in chunks:
yield c

@@ -266,7 +276,7 @@ async def _process_vllm_chat_completion_stream_response(
# the stream ended without a chunk containing finish_reason - we have to generate the
# respective completion chunks manually
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_buf=tool_call_buf
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
)
for c in chunks:
yield c
@@ -531,13 +531,19 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
tool_name = tc.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value

# arguments_json can be None, so attempt it first and fall back to arguments
if hasattr(tc, "arguments_json") and tc.arguments_json:
arguments = tc.arguments_json
else:
arguments = json.dumps(tc.arguments)
result["tool_calls"].append(
{
"id": tc.call_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
"arguments": arguments,
},
}
)
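The new branch above prefers a pre-serialized `arguments_json` only when it is actually set, and otherwise serializes the parsed `arguments` dict. A tiny sketch of the same fallback, using a hypothetical tool-call object rather than the real ToolCall type:

```python
import json


def tool_call_arguments(tc) -> str:
    # Prefer arguments_json only when present and non-empty; otherwise
    # serialize the parsed arguments dict (mirrors the diff above).
    if getattr(tc, "arguments_json", None):
        return tc.arguments_json
    return json.dumps(tc.arguments)


class FakeToolCall:  # illustrative stand-in
    arguments = {"city": "Paris"}
    arguments_json = None


print(tool_call_arguments(FakeToolCall()))  # '{"city": "Paris"}'
```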
@@ -152,46 +152,6 @@
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"dev": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"fireworks": [
"aiosqlite",
"autoevals",

@@ -642,6 +602,46 @@
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"starter": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"tgi": [
"aiohttp",
"aiosqlite",
@@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .dev import get_distribution_template # noqa: F401
from .starter import get_distribution_template # noqa: F401
@@ -1,6 +1,6 @@
version: '2'
distribution_spec:
description: Distribution for running e2e tests in CI
description: Quick start template for running Llama Stack with several popular providers
providers:
inference:
- remote::openai
@@ -1,5 +1,5 @@
version: '2'
image_name: dev
image_name: starter
apis:
- agents
- datasetio

@@ -46,7 +46,7 @@ providers:
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/sqlite_vec.db
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
provider_type: remote::chromadb
config:

@@ -71,14 +71,14 @@ providers:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/trace_store.db
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

@@ -86,7 +86,7 @@ providers:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/meta_reference_eval.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface

@@ -94,14 +94,14 @@ providers:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/huggingface_datasetio.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/localfs_datasetio.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic

@@ -132,7 +132,7 @@ providers:
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db
models:
- metadata: {}
model_id: openai/gpt-4o
@@ -46,6 +46,7 @@ from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOC
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,

@@ -53,7 +54,7 @@ from llama_stack.templates.template import (
)


def get_inference_providers() -> tuple[list[Provider], list[ModelInput]]:
def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
# in this template, we allow each API key to be optional
providers = [
(

@@ -119,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
"remote::model-context-protocol",
],
}
name = "dev"
name = "starter"

vector_io_providers = [
Provider(

@@ -171,7 +172,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Distribution for running e2e tests in CI",
description="Quick start template for running Llama Stack with several popular providers",
container_image=None,
template_path=None,
providers=providers,
@@ -5,13 +5,13 @@ We use shadcdn/ui [Shadcn UI](https://ui.shadcn.com/) for the UI components.

## Getting Started
## Install the NPM dependencies

First, install dependencies:

```bash
npm install
npm install next react react-dom
```

Run the development server:
Then, run the development server:

```bash
npm run dev
@@ -45,10 +45,10 @@ export function AppSidebar() {
{logItems.map((item) => (
<SidebarMenuItem key={item.title}>
<SidebarMenuButton asChild>
<a href={item.url}>
<Link href={item.url}>
<item.icon />
<span>{item.title}</span>
</a>
</Link>
</SidebarMenuButton>
</SidebarMenuItem>
))}
@@ -304,7 +304,6 @@ exclude = [
"^llama_stack/strong_typing/inspection\\.py$",
"^llama_stack/strong_typing/schema\\.py$",
"^llama_stack/strong_typing/serializer\\.py$",
"^llama_stack/templates/dev/dev\\.py$",
"^llama_stack/templates/groq/groq\\.py$",
"^llama_stack/templates/llama_api/llama_api\\.py$",
"^llama_stack/templates/sambanova/sambanova\\.py$",
@@ -6,4 +6,4 @@ distribution_spec:
- remote::custom_ollama
image_type: container
image_name: ci-test
external_providers_dir: /tmp/providers.d
external_providers_dir: ~/.llama/providers.d
@@ -91,4 +91,4 @@ tool_groups:
provider_id: wolfram-alpha
server:
port: 8321
external_providers_dir: /tmp/providers.d
external_providers_dir: ~/.llama/providers.d
@@ -266,6 +266,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
assert found_tool_execution


@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = {
**agent_config,

@@ -346,7 +347,7 @@ def test_custom_tool(llama_stack_client, agent_config):
messages=[
{
"role": "user",
"content": "What is the boiling point of polyjuice?",
"content": "What is the boiling point of the liquid polyjuice in celsius?",
},
],
session_id=session_id,

@@ -420,7 +421,7 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
messages=[
{
"role": "user",
"content": "What is the boiling point of polyjuice?",
"content": "What is the boiling point of the liquid polyjuice in celsius?",
},
],
session_id=session_id,

@@ -674,8 +675,8 @@ def test_create_turn_response(llama_stack_client, agent_config, client_tools):


def test_multi_tool_calls(llama_stack_client, agent_config):
if "gpt" not in agent_config["model"]:
pytest.xfail("Only tested on GPT models")
if "gpt" not in agent_config["model"] and "llama-4" not in agent_config["model"].lower():
pytest.xfail("Only tested on GPT and Llama 4 models")

agent_config = {
**agent_config,

@@ -689,23 +690,34 @@ def test_multi_tool_calls(llama_stack_client, agent_config):
messages=[
{
"role": "user",
"content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?",
"content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?.\nUse the tool responses to answer the question.",
},
],
session_id=session_id,
stream=False,
)
steps = response.steps
assert len(steps) == 7
assert steps[0].step_type == "shield_call"
assert steps[1].step_type == "inference"
assert steps[2].step_type == "shield_call"
assert steps[3].step_type == "tool_execution"
assert steps[4].step_type == "shield_call"
assert steps[5].step_type == "inference"
assert steps[6].step_type == "shield_call"

tool_execution_step = steps[3]
has_input_shield = agent_config.get("input_shields")
has_output_shield = agent_config.get("output_shields")
assert len(steps) == 3 + (2 if has_input_shield else 0) + (2 if has_output_shield else 0)
if has_input_shield:
assert steps[0].step_type == "shield_call"
steps.pop(0)
assert steps[0].step_type == "inference"
if has_output_shield:
assert steps[1].step_type == "shield_call"
steps.pop(1)
assert steps[1].step_type == "tool_execution"
tool_execution_step = steps[1]
if has_input_shield:
assert steps[2].step_type == "shield_call"
steps.pop(2)
assert steps[2].step_type == "inference"
if has_output_shield:
assert steps[3].step_type == "shield_call"
steps.pop(3)

assert len(tool_execution_step.tool_calls) == 2
assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point")
assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point")
@@ -24,6 +24,12 @@ from openai.types.chat.chat_completion_chunk import (
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
)
from openai.types.model import Model as OpenAIModel

from llama_stack.apis.inference import (

@@ -206,8 +212,164 @@ async def test_tool_call_delta_empty_tool_call_buf():
yield chunk

chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 1
assert chunks[0].event.stop_reason == StopReason.end_of_turn
assert len(chunks) == 2
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "complete"
assert chunks[1].event.stop_reason == StopReason.end_of_turn


@pytest.mark.asyncio
async def test_tool_call_delta_streaming_arguments_dict():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments="",
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk

chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "complete"


@pytest.mark.asyncio
async def test_multiple_tool_calls():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=2,
function=OpenAIChoiceDeltaToolCallFunction(
name="multiple",
arguments='{"first_number": 4, "second_number": 7}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk

chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 4
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "progress"
assert chunks[2].event.delta.type == "tool_call"
assert chunks[2].event.delta.parse_status.value == "succeeded"
assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
assert chunks[3].event.event_type.value == "complete"


@pytest.mark.asyncio

@@ -231,7 +393,8 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
yield chunk

chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 0
assert len(chunks) == 1
assert chunks[0].event.event_type.value == "start"


def test_chat_completion_doesnt_block_event_loop(caplog):

@@ -369,7 +532,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
yield chunk

chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name

@@ -422,7 +585,7 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
yield chunk

chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name

@@ -471,7 +634,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
yield chunk

chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name