Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-05 18:22:41 +00:00)

Merge branch 'main' into testpypi-workflow
Commit f9c309d05c
141 changed files with 6551 additions and 3032 deletions
@@ -23,6 +23,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentencepiece",
@@ -54,6 +55,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentencepiece",
@@ -86,6 +88,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentencepiece",
@@ -116,6 +119,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentencepiece",
@@ -148,6 +152,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentencepiece",
@@ -181,6 +186,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentencepiece",
@@ -213,6 +219,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentencepiece",
@@ -247,6 +254,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentence-transformers",
@@ -286,6 +294,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentence-transformers",
@@ -319,6 +328,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentencepiece",
@@ -352,6 +362,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentencepiece",
@@ -385,6 +396,7 @@
 "psycopg2-binary",
 "pypdf",
 "redis",
+"requests",
 "scikit-learn",
 "scipy",
 "sentencepiece",
@@ -85,7 +85,7 @@ services:
       - SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
       - SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
     ports:
-      - "${LLAMASTACK_PORT:-5001}:${LLAMASTACK_PORT:-5001}"
+      - "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
     # Hack: wait for vLLM server to start before starting docker
     entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 5001"
     deploy:
File diff suppressed because one or more lines are too long
@@ -486,13 +486,22 @@ class Generator:
         parameters = path_parameters + query_parameters
         parameters += [
             Parameter(
-                name="X-LlamaStack-ProviderData",
+                name="X-LlamaStack-Provider-Data",
                 in_=ParameterLocation.Header,
                 description="JSON-encoded provider data which will be made available to the adapter servicing the API",
                 required=False,
                 schema=self.schema_builder.classdef_to_ref(str),
             )
         ]
+        parameters += [
+            Parameter(
+                name="X-LlamaStack-Client-Version",
+                in_=ParameterLocation.Header,
+                description="Version of the client making the request. This is used to ensure that the client and server are compatible.",
+                required=False,
+                schema=self.schema_builder.classdef_to_ref(str),
+            )
+        ]
 
         # data passed in payload
         if op.request_params:
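As an illustration only (not part of the diff): a client request might attach the renamed provider-data header and the new client-version header. The header names and the `/toolgroups/list` route appear in this changeset; the base URL, version string, and payload values below are assumptions.

```python
import json

import requests  # the same dependency this commit adds to the distribution dependency lists

BASE_URL = "http://localhost:5001"  # assumed local server on the documented default port

headers = {
    # JSON-encoded provider data made available to the adapter servicing the API
    "X-LlamaStack-Provider-Data": json.dumps({"example_api_key": "secret"}),
    # client version; the server middleware added below compares only the major.minor parts
    "X-LlamaStack-Client-Version": "0.0.63",
}

resp = requests.get(f"{BASE_URL}/toolgroups/list", headers=headers, timeout=30)
print(resp.status_code, resp.text)
```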
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -19,6 +19,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
 | safety | `remote::bedrock` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
@@ -26,7 +27,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 
 ### Models
 
@@ -9,13 +9,14 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
 | memory | `inline::meta-reference` |
 | safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 ### Environment Variables
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
 
 ### Models
 
@@ -22,13 +22,14 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 ### Environment Variables
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `FIREWORKS_API_KEY`: Fireworks.AI API Key (default: ``)
 
 ### Models
 
@@ -22,6 +22,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.
@@ -30,7 +31,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
 - `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
@@ -22,6 +22,7 @@ The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.
@@ -32,7 +33,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
 
@@ -22,13 +22,14 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.### Environment Variables
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
 - `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)
@@ -18,6 +18,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
 | memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
 | safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.
@@ -26,9 +27,9 @@ You can use this distribution if you have GPUs and want to run an independent vL
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `INFERENCE_MODEL`: Inference model loaded into the vLLM server (default: `meta-llama/Llama-3.2-3B-Instruct`)
-- `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100}/v1`)
+- `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100/v1`)
 - `MAX_TOKENS`: Maximum number of tokens for generation (default: `4096`)
 - `SAFETY_VLLM_URL`: URL of the vLLM server with the safety model (default: `http://host.docker.internal:5101/v1`)
 - `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
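A hedged sketch of wiring these variables up from Python: the defaults mirror the remote-vllm table above and the server module and flags come from the docker-compose entrypoint shown earlier, but launching this way (rather than via `llama stack run` or docker compose) and the config path are assumptions.

```python
import os
import subprocess

# Defaults mirror the remote-vllm documentation above.
env = {
    "LLAMA_STACK_PORT": os.getenv("LLAMA_STACK_PORT", "5001"),
    "INFERENCE_MODEL": os.getenv("INFERENCE_MODEL", "meta-llama/Llama-3.2-3B-Instruct"),
    "VLLM_URL": os.getenv("VLLM_URL", "http://host.docker.internal:5100/v1"),
    "MAX_TOKENS": os.getenv("MAX_TOKENS", "4096"),
    "SAFETY_VLLM_URL": os.getenv("SAFETY_VLLM_URL", "http://host.docker.internal:5101/v1"),
    "SAFETY_MODEL": os.getenv("SAFETY_MODEL", "meta-llama/Llama-Guard-3-1B"),
}

# Module path and flags taken from the docker-compose entrypoint earlier in this diff;
# the yaml_config path is a placeholder.
subprocess.run(
    [
        "python", "-m", "llama_stack.distribution.server.server",
        "--yaml_config", "/root/llamastack-run-remote-vllm.yaml",
        "--port", env["LLAMA_STACK_PORT"],
    ],
    env={**os.environ, **env},
    check=True,
)
```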
@@ -23,6 +23,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 You can use this distribution if you have GPUs and want to run an independent TGI server container for running inference.
@@ -31,7 +32,7 @@ You can use this distribution if you have GPUs and want to run an independent TG
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080}/v1`)
 - `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`)
@@ -22,13 +22,14 @@ The `llamastack/distribution-together` distribution consists of the following pr
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 ### Environment Variables
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `TOGETHER_API_KEY`: Together.AI API Key (default: ``)
 
 ### Models
 
@@ -89,7 +89,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
 ```
 ...
 Build Successful! Next steps:
-  1. Set the environment variables: LLAMASTACK_PORT, OLLAMA_URL, INFERENCE_MODEL, SAFETY_MODEL
+  1. Set the environment variables: LLAMA_STACK_PORT, OLLAMA_URL, INFERENCE_MODEL, SAFETY_MODEL
   2. `llama stack run /Users/<username>/.llama/distributions/llamastack-ollama/ollama-run.yaml
 ```
 
@@ -18,15 +18,11 @@ from typing import (
     Union,
 )
 
-from llama_models.llama3.api.datatypes import ToolParamDefinition
-
-from llama_models.schema_utils import json_schema_type, webmethod
-
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import InterleavedContent, URL
-from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
 from llama_stack.apis.inference import (
     CompletionMessage,
     SamplingParams,
@@ -40,166 +36,18 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.memory import MemoryBank
 from llama_stack.apis.safety import SafetyViolation
+from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
 
-@json_schema_type
 class Attachment(BaseModel):
     content: InterleavedContent | URL
     mime_type: str
 
 
-class AgentTool(Enum):
-    brave_search = "brave_search"
-    wolfram_alpha = "wolfram_alpha"
-    photogen = "photogen"
-    code_interpreter = "code_interpreter"
-    function_call = "function_call"
-    memory = "memory"
+class Document(BaseModel):
+    content: InterleavedContent | URL
+    mime_type: str
 
 
-class ToolDefinitionCommon(BaseModel):
-    input_shields: Optional[List[str]] = Field(default_factory=list)
-    output_shields: Optional[List[str]] = Field(default_factory=list)
-
-
-class SearchEngineType(Enum):
-    bing = "bing"
-    brave = "brave"
-    tavily = "tavily"
-
-
-@json_schema_type
-class SearchToolDefinition(ToolDefinitionCommon):
-    # NOTE: brave_search is just a placeholder since model always uses
-    # brave_search as tool call name
-    type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
-    api_key: str
-    engine: SearchEngineType = SearchEngineType.brave
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class WolframAlphaToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
-    api_key: str
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class PhotogenToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class CodeInterpreterToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
-    enable_inline_code_execution: bool = True
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class FunctionCallToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
-    function_name: str
-    description: str
-    parameters: Dict[str, ToolParamDefinition]
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-class _MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-
-
-class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["vector"] = "vector"
-
-
-class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyvalue"] = "keyvalue"
-    keys: List[str]  # what keys to focus on
-
-
-class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyword"] = "keyword"
-
-
-class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["graph"] = "graph"
-    entities: List[str]  # what entities to focus on
-
-
-MemoryBankConfig = Annotated[
-    Union[
-        AgentVectorMemoryBankConfig,
-        AgentKeyValueMemoryBankConfig,
-        AgentKeywordMemoryBankConfig,
-        AgentGraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryQueryGenerator(Enum):
-    default = "default"
-    llm = "llm"
-    custom = "custom"
-
-
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.default.value] = (
-        MemoryQueryGenerator.default.value
-    )
-    sep: str = " "
-
-
-class LLMMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
-    model: str
-    template: str
-
-
-class CustomMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-
-
-MemoryQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultMemoryQueryGeneratorConfig,
-        LLMMemoryQueryGeneratorConfig,
-        CustomMemoryQueryGeneratorConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-@json_schema_type
-class MemoryToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.memory.value] = AgentTool.memory.value
-    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
-    # This config defines how a query is generated using the messages
-    # for memory bank retrieval.
-    query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig()
-    )
-    max_tokens_in_context: int = 4096
-    max_chunks: int = 10
-
-
-AgentToolDefinition = Annotated[
-    Union[
-        SearchToolDefinition,
-        WolframAlphaToolDefinition,
-        PhotogenToolDefinition,
-        CodeInterpreterToolDefinition,
-        FunctionCallToolDefinition,
-        MemoryToolDefinition,
-    ],
-    Field(discriminator="type"),
-]
-
-
 class StepCommon(BaseModel):
@@ -289,13 +137,27 @@ class Session(BaseModel):
     memory_bank: Optional[MemoryBank] = None
 
 
+class AgentToolGroupWithArgs(BaseModel):
+    name: str
+    args: Dict[str, Any]
+
+
+AgentToolGroup = register_schema(
+    Union[
+        str,
+        AgentToolGroupWithArgs,
+    ],
+    name="AgentTool",
+)
+
+
 class AgentConfigCommon(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()
 
     input_shields: Optional[List[str]] = Field(default_factory=list)
     output_shields: Optional[List[str]] = Field(default_factory=list)
-    tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
+    toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
+    client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
@@ -340,6 +202,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
         AgentTurnResponseEventType.step_complete.value
     )
     step_type: StepType
+    step_id: str
    step_details: Step
 
 
@@ -413,7 +276,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
             ToolResponseMessage,
         ]
     ]
-    attachments: Optional[List[Attachment]] = None
+
+    documents: Optional[List[Document]] = None
+    toolgroups: Optional[List[AgentToolGroup]] = None
 
     stream: Optional[bool] = False
 
@@ -450,8 +315,9 @@ class Agents(Protocol):
                 ToolResponseMessage,
             ]
         ],
-        attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
+        documents: Optional[List[Document]] = None,
+        toolgroups: Optional[List[AgentToolGroup]] = None,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
 
     @webmethod(route="/agents/turn/get")
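A rough sketch of what the reshaped agent surface looks like to a caller. Only the field names (`toolgroups`, `client_tools`, `documents`, the shields, and the sampling params) come from the diff above; the import path, the `model`/`instructions` fields, and the toolgroup identifiers are assumptions.

```python
from llama_stack.apis.agents import (  # import path assumed
    AgentConfig,
    AgentToolGroupWithArgs,
    Document,
)

config = AgentConfig(
    model="meta-llama/Llama-3.2-3B-Instruct",     # assumed field
    instructions="You are a helpful assistant.",  # assumed field
    toolgroups=[
        "builtin::memory",  # plain string form of AgentToolGroup (hypothetical id)
        AgentToolGroupWithArgs(name="builtin::rag", args={"top_k": 3}),  # hypothetical id and args
    ],
    client_tools=[],  # ToolDef entries executed on the client side
    input_shields=["meta-llama/Llama-Guard-3-1B"],
    output_shields=["meta-llama/Llama-Guard-3-1B"],
)

# Turn creation now carries documents/toolgroups instead of attachments:
doc = Document(content="...contents of a file to ground the turn...", mime_type="text/plain")
```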
@@ -7,7 +7,6 @@
 from typing import List, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
-
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.inference import (
@@ -44,9 +43,7 @@ class BatchChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     logprobs: Optional[LogProbConfig] = None
 
 
@@ -75,6 +72,6 @@ class BatchInference(Protocol):
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = list,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchChatCompletionResponse: ...
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 from enum import Enum
-
 from typing import (
     Any,
     AsyncIterator,
@@ -26,16 +25,12 @@ from llama_models.llama3.api.datatypes import (
     ToolDefinition,
     ToolPromptFormat,
 )
-
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
-
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import InterleavedContent
-
 from llama_stack.apis.models import Model
-
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
 
@@ -256,9 +251,7 @@ class ChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     response_format: Optional[ResponseFormat] = None
 
     stream: Optional[bool] = False
@@ -289,9 +282,7 @@ class BatchChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     logprobs: Optional[LogProbConfig] = None
 
 
@@ -334,7 +325,7 @@ class Inference(Protocol):
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Annotated, Any, Dict, List, Literal, Optional, Union
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional
 
-from llama_models.llama3.api.datatypes import ToolPromptFormat
-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
+from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable
 
@@ -21,59 +21,48 @@ class ToolParameter(BaseModel):
     name: str
     parameter_type: str
     description: str
+    required: bool = Field(default=True)
+    default: Optional[Any] = None
+
+
+@json_schema_type
+class ToolHost(Enum):
+    distribution = "distribution"
+    client = "client"
+    model_context_protocol = "model_context_protocol"
 
 
 @json_schema_type
 class Tool(Resource):
     type: Literal[ResourceType.tool.value] = ResourceType.tool.value
-    tool_group: str
+    toolgroup_id: str
+    tool_host: ToolHost
     description: str
     parameters: List[ToolParameter]
-    provider_id: Optional[str] = None
     metadata: Optional[Dict[str, Any]] = None
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
 
 
 @json_schema_type
 class ToolDef(BaseModel):
     name: str
-    description: str
-    parameters: List[ToolParameter]
-    metadata: Dict[str, Any]
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    description: Optional[str] = None
+    parameters: Optional[List[ToolParameter]] = None
+    metadata: Optional[Dict[str, Any]] = None
 
 
 @json_schema_type
-class MCPToolGroupDef(BaseModel):
-    """
-    A tool group that is defined by in a model context protocol server.
-    Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
-    """
-
-    type: Literal["model_context_protocol"] = "model_context_protocol"
-    endpoint: URL
-
-
-@json_schema_type
-class UserDefinedToolGroupDef(BaseModel):
-    type: Literal["user_defined"] = "user_defined"
-    tools: List[ToolDef]
-
-
-ToolGroupDef = register_schema(
-    Annotated[
-        Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
-    ],
-    name="ToolGroup",
-)
+class ToolGroupInput(BaseModel):
+    toolgroup_id: str
+    provider_id: str
+    args: Optional[Dict[str, Any]] = None
+    mcp_endpoint: Optional[URL] = None
 
 
 class ToolGroup(Resource):
     type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
+    mcp_endpoint: Optional[URL] = None
+    args: Optional[Dict[str, Any]] = None
 
 
 @json_schema_type
@@ -85,6 +74,7 @@ class ToolInvocationResult(BaseModel):
 
 class ToolStore(Protocol):
     def get_tool(self, tool_name: str) -> Tool: ...
+    def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
 
 
 @runtime_checkable
@@ -93,9 +83,10 @@ class ToolGroups(Protocol):
     @webmethod(route="/toolgroups/register", method="POST")
     async def register_tool_group(
         self,
-        tool_group_id: str,
-        tool_group: ToolGroupDef,
-        provider_id: Optional[str] = None,
+        toolgroup_id: str,
+        provider_id: str,
+        mcp_endpoint: Optional[URL] = None,
+        args: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Register a tool group"""
         ...
@@ -103,7 +94,7 @@ class ToolGroups(Protocol):
     @webmethod(route="/toolgroups/get", method="GET")
     async def get_tool_group(
         self,
-        tool_group_id: str,
+        toolgroup_id: str,
     ) -> ToolGroup: ...
 
     @webmethod(route="/toolgroups/list", method="GET")
@@ -130,8 +121,11 @@ class ToolGroups(Protocol):
 class ToolRuntime(Protocol):
     tool_store: ToolStore
 
-    @webmethod(route="/tool-runtime/discover", method="POST")
-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
+    # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
+    @webmethod(route="/tool-runtime/list-tools", method="GET")
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]: ...
 
     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(
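A hedged sketch of the reshaped tools surface: `ToolDef` now treats description, parameters, and metadata as optional, and `ToolGroupInput` is the shape a run config supplies. The field names come from the diff; the identifiers and the `ToolParameter` import path are assumptions.

```python
from llama_stack.apis.tools import ToolDef, ToolGroupInput, ToolParameter  # ToolParameter path assumed

# description / parameters / metadata are all optional now:
ping = ToolDef(name="ping")

lookup = ToolDef(
    name="lookup_user",
    description="Look up a user by id",
    parameters=[
        # `required` defaults to True and `default` to None, per the new ToolParameter fields
        ToolParameter(name="user_id", parameter_type="string", description="The user's id"),
    ],
)

# What a StackRunConfig lists under tool_groups (see the datatypes change later in this diff):
group = ToolGroupInput(
    toolgroup_id="builtin::memory",  # hypothetical identifier
    provider_id="memory-runtime",    # hypothetical provider id
)
```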
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import argparse
+import os
 from pathlib import Path
 
 from llama_stack.cli.subcommand import Subcommand
@@ -34,7 +35,7 @@ class StackRun(Subcommand):
             "--port",
             type=int,
             help="Port to run the server on. Defaults to 5000",
-            default=5000,
+            default=int(os.getenv("LLAMA_STACK_PORT", 5000)),
         )
         self.parser.add_argument(
             "--disable-ipv6",
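With the new default, the explicit `--port` flag still wins, the `LLAMA_STACK_PORT` environment variable is used when no flag is given, and `5000` remains the fallback. A minimal, self-contained sketch of that precedence:

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--port",
    type=int,
    help="Port to run the server on. Defaults to 5000",
    default=int(os.getenv("LLAMA_STACK_PORT", 5000)),
)

print(parser.parse_args([]).port)                  # LLAMA_STACK_PORT if set, otherwise 5000
print(parser.parse_args(["--port", "8321"]).port)  # explicit flag always wins: 8321
```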
@@ -20,7 +20,7 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
 from llama_stack.apis.shields import Shield, ShieldInput
-from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime
+from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
 from llama_stack.providers.datatypes import Api, ProviderSpec
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig
@@ -161,6 +161,7 @@ a default SQLite store will be used.""",
     datasets: List[DatasetInput] = Field(default_factory=list)
     scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
     eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
+    tool_groups: List[ToolGroupInput] = Field(default_factory=list)
 
 
 class BuildConfig(BaseModel):
@@ -267,6 +267,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 self.config, self.custom_provider_registry
             )
         except ModuleNotFoundError as _e:
+            cprint(_e.msg, "red")
             cprint(
                 "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
                 "yellow",
@@ -40,8 +40,8 @@ class NeedsRequestProviderData:
 
 def set_request_provider_data(headers: Dict[str, str]):
     keys = [
-        "X-LlamaStack-ProviderData",
-        "x-llamastack-providerdata",
+        "X-LlamaStack-Provider-Data",
+        "x-llamastack-provider-data",
     ]
     for key in keys:
         val = headers.get(key, None)
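A standalone sketch (not the library's own code) of the lookup behavior above: either spelling of the renamed header is accepted and its value is treated as JSON-encoded provider data. The example key is hypothetical.

```python
import json
from typing import Dict, Optional


def find_provider_data(headers: Dict[str, str]) -> Optional[dict]:
    keys = [
        "X-LlamaStack-Provider-Data",
        "x-llamastack-provider-data",
    ]
    for key in keys:
        val = headers.get(key, None)
        if val:
            return json.loads(val)  # JSON-encoded provider data, per the OpenAPI description above
    return None


print(find_provider_data({"x-llamastack-provider-data": '{"example_api_key": "secret"}'}))
```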
@@ -5,9 +5,7 @@
 # the root directory of this source tree.
 import importlib
 import inspect
-
 import logging
-
 from typing import Any, Dict, List, Set
 
 from llama_stack.apis.agents import Agents
@@ -28,7 +26,6 @@ from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.distribution.client import get_client_impl
-
 from llama_stack.distribution.datatypes import (
     AutoRoutedProviderSpec,
     Provider,
@@ -38,7 +35,6 @@ from llama_stack.distribution.datatypes import (
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
-
 from llama_stack.providers.datatypes import (
     Api,
     DatasetsProtocolPrivate,
@@ -6,7 +6,7 @@
 
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, URL
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.eval import (
     AppEvalTaskConfig,
@@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime
+from llama_stack.apis.tools import ToolDef, ToolRuntime
 from llama_stack.providers.datatypes import RoutingTable
 
 
@@ -127,7 +127,7 @@ class InferenceRouter(Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -417,7 +417,9 @@ class ToolRuntimeRouter(ToolRuntime):
             args=args,
         )
 
-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
-        return await self.routing_table.get_provider_impl(
-            tool_group.name
-        ).discover_tools(tool_group)
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
+            tool_group_id, mcp_endpoint
+        )
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import parse_obj_as
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
from llama_stack.apis.common.type_system import ParamType
|
||||||
|
@ -26,20 +26,12 @@ from llama_stack.apis.scoring_functions import (
|
||||||
ScoringFunctions,
|
ScoringFunctions,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield, Shields
|
from llama_stack.apis.shields import Shield, Shields
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
|
||||||
MCPToolGroupDef,
|
|
||||||
Tool,
|
|
||||||
ToolGroup,
|
|
||||||
ToolGroupDef,
|
|
||||||
ToolGroups,
|
|
||||||
UserDefinedToolGroupDef,
|
|
||||||
)
|
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
RoutableObject,
|
RoutableObject,
|
||||||
RoutableObjectWithProvider,
|
RoutableObjectWithProvider,
|
||||||
RoutedProtocol,
|
RoutedProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||||
|
|
||||||
|
@ -361,7 +353,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
memory_bank_data["embedding_dimension"] = model.metadata[
|
memory_bank_data["embedding_dimension"] = model.metadata[
|
||||||
"embedding_dimension"
|
"embedding_dimension"
|
||||||
]
|
]
|
||||||
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
|
memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
|
||||||
await self.register_object(memory_bank)
|
await self.register_object(memory_bank)
|
||||||
return memory_bank
|
return memory_bank
|
||||||
|
|
||||||
|
@@ -496,54 +488,44 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
     async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
         tools = await self.get_all_with_type("tool")
         if tool_group_id:
-            tools = [tool for tool in tools if tool.tool_group == tool_group_id]
+            tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
         return tools
 
     async def list_tool_groups(self) -> List[ToolGroup]:
         return await self.get_all_with_type("tool_group")
 
-    async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
-        return await self.get_object_by_identifier("tool_group", tool_group_id)
+    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
+        return await self.get_object_by_identifier("tool_group", toolgroup_id)
 
     async def get_tool(self, tool_name: str) -> Tool:
         return await self.get_object_by_identifier("tool", tool_name)
 
     async def register_tool_group(
         self,
-        tool_group_id: str,
-        tool_group: ToolGroupDef,
-        provider_id: Optional[str] = None,
+        toolgroup_id: str,
+        provider_id: str,
+        mcp_endpoint: Optional[URL] = None,
+        args: Optional[Dict[str, Any]] = None,
     ) -> None:
         tools = []
-        tool_defs = []
-        if provider_id is None:
-            if len(self.impls_by_provider_id.keys()) > 1:
-                raise ValueError(
-                    f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
-                )
-            provider_id = list(self.impls_by_provider_id.keys())[0]
-
-        if isinstance(tool_group, MCPToolGroupDef):
-            tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
-                tool_group
-            )
-        elif isinstance(tool_group, UserDefinedToolGroupDef):
-            tool_defs = tool_group.tools
-        else:
-            raise ValueError(f"Unknown tool group: {tool_group}")
+        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
+            toolgroup_id, mcp_endpoint
+        )
+        tool_host = (
+            ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
+        )
 
         for tool_def in tool_defs:
             tools.append(
                 Tool(
                     identifier=tool_def.name,
-                    tool_group=tool_group_id,
-                    description=tool_def.description,
-                    parameters=tool_def.parameters,
+                    toolgroup_id=toolgroup_id,
+                    description=tool_def.description or "",
+                    parameters=tool_def.parameters or [],
                     provider_id=provider_id,
-                    tool_prompt_format=tool_def.tool_prompt_format,
                     provider_resource_id=tool_def.name,
                     metadata=tool_def.metadata,
+                    tool_host=tool_host,
                 )
             )
         for tool in tools:
@@ -561,9 +543,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
 
         await self.dist_registry.register(
             ToolGroup(
-                identifier=tool_group_id,
+                identifier=toolgroup_id,
                 provider_id=provider_id,
-                provider_resource_id=tool_group_id,
+                provider_resource_id=toolgroup_id,
+                mcp_endpoint=mcp_endpoint,
+                args=args,
             )
         )
 
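After this change a tool group is registered against an explicit provider, with an optional MCP endpoint, instead of passing a `ToolGroupDef` union. A rough caller sketch under stated assumptions: `routing_table` is a hypothetical `ToolGroupsRoutingTable` instance, and the provider ids and the `mcp::filesystem` group id are made up for illustration; only the `register_tool_group` signature comes from the hunk above.

from llama_stack.apis.common.content_types import URL

async def register_examples(routing_table) -> None:
    # Provider-hosted group: tools are discovered via the provider's list_runtime_tools().
    await routing_table.register_tool_group(
        toolgroup_id="builtin::memory",
        provider_id="memory-runtime",
    )
    # MCP-backed group: the endpoint and per-group args are stored with the ToolGroup.
    await routing_table.register_tool_group(
        toolgroup_id="mcp::filesystem",
        provider_id="model-context-protocol",
        mcp_endpoint=URL(uri="http://localhost:8000/sse"),
        args={"root": "/tmp"},
    )
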
@@ -16,6 +16,8 @@ import traceback
 import warnings
 
 from contextlib import asynccontextmanager
 
+from importlib.metadata import version as parse_version
 from pathlib import Path
 from typing import Any, Union
 
@@ -228,6 +230,52 @@ class TracingMiddleware:
             await end_trace()
 
 
+class ClientVersionMiddleware:
+    def __init__(self, app):
+        self.app = app
+        self.server_version = parse_version("llama-stack")
+
+    async def __call__(self, scope, receive, send):
+        if scope["type"] == "http":
+            headers = dict(scope.get("headers", []))
+            client_version = headers.get(b"x-llamastack-client-version", b"").decode()
+            if client_version:
+                try:
+                    client_version_parts = tuple(
+                        map(int, client_version.split(".")[:2])
+                    )
+                    server_version_parts = tuple(
+                        map(int, self.server_version.split(".")[:2])
+                    )
+                    if client_version_parts != server_version_parts:
+
+                        async def send_version_error(send):
+                            await send(
+                                {
+                                    "type": "http.response.start",
+                                    "status": 426,
+                                    "headers": [[b"content-type", b"application/json"]],
+                                }
+                            )
+                            error_msg = json.dumps(
+                                {
+                                    "error": {
+                                        "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please upgrade your client."
+                                    }
+                                }
+                            ).encode()
+                            await send(
+                                {"type": "http.response.body", "body": error_msg}
+                            )
+
+                        return await send_version_error(send)
+                except (ValueError, IndexError):
+                    # If version parsing fails, let the request through
+                    pass
+
+        return await self.app(scope, receive, send)
+
+
 def main():
     """Start the LlamaStack server."""
     parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
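The new middleware rejects requests whose `x-llamastack-client-version` header does not match the server's major.minor version, answering with HTTP 426 and a JSON error. A hedged client-side sketch: the header name, the 426 status, and the error payload shape come from the hunk above, while the endpoint path and the use of httpx are assumptions for illustration only.

import httpx

def check_server_compatibility(base_url: str, client_version: str) -> bool:
    # Send the client's version with every request; the server compares major.minor.
    resp = httpx.get(
        f"{base_url}/alpha/health",
        headers={"x-llamastack-client-version": client_version},
    )
    if resp.status_code == 426:
        # Server and client disagree on major.minor; caller should upgrade the client.
        print(resp.json()["error"]["message"])
        return False
    return True
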
@ -242,7 +290,7 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port",
|
"--port",
|
||||||
type=int,
|
type=int,
|
||||||
default=int(os.getenv("LLAMASTACK_PORT", 5000)),
|
default=int(os.getenv("LLAMA_STACK_PORT", 5000)),
|
||||||
help="Port to listen on",
|
help="Port to listen on",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -291,6 +339,7 @@ def main():
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
app.add_middleware(TracingMiddleware)
|
app.add_middleware(TracingMiddleware)
|
||||||
|
app.add_middleware(ClientVersionMiddleware)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
impls = asyncio.run(construct_stack(config))
|
impls = asyncio.run(construct_stack(config))
|
||||||
|
|
|
@ -12,7 +12,6 @@ from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
|
@ -33,14 +32,13 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||||
from llama_stack.apis.shields import Shields
|
from llama_stack.apis.shields import Shields
|
||||||
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
||||||
from llama_stack.apis.telemetry import Telemetry
|
from llama_stack.apis.telemetry import Telemetry
|
||||||
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||||
from llama_stack.distribution.store.registry import create_dist_registry
|
from llama_stack.distribution.store.registry import create_dist_registry
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
LLAMA_STACK_API_VERSION = "alpha"
|
LLAMA_STACK_API_VERSION = "alpha"
|
||||||
|
@ -65,6 +63,8 @@ class LlamaStack(
|
||||||
Models,
|
Models,
|
||||||
Shields,
|
Shields,
|
||||||
Inspect,
|
Inspect,
|
||||||
|
ToolGroups,
|
||||||
|
ToolRuntime,
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -81,6 +81,7 @@ RESOURCES = [
|
||||||
"list_scoring_functions",
|
"list_scoring_functions",
|
||||||
),
|
),
|
||||||
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
||||||
|
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -90,6 +90,6 @@ $DOCKER_BINARY run $DOCKER_OPTS -it \
|
||||||
$env_vars \
|
$env_vars \
|
||||||
-v "$yaml_config:/app/config.yaml" \
|
-v "$yaml_config:/app/config.yaml" \
|
||||||
$mounts \
|
$mounts \
|
||||||
--env LLAMASTACK_PORT=$port \
|
--env LLAMA_STACK_PORT=$port \
|
||||||
--entrypoint='["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]' \
|
--entrypoint='["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]' \
|
||||||
$docker_image:$version_tag
|
$docker_image:$version_tag
|
||||||
|
|
|
@ -12,7 +12,6 @@ import pydantic
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
|
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
|
||||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
@ -36,7 +35,7 @@ class DistributionRegistry(Protocol):
|
||||||
|
|
||||||
|
|
||||||
REGISTER_PREFIX = "distributions:registry"
|
REGISTER_PREFIX = "distributions:registry"
|
||||||
KEY_VERSION = "v3"
|
KEY_VERSION = "v4"
|
||||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,8 @@ async def get_provider_impl(
|
||||||
deps[Api.memory],
|
deps[Api.memory],
|
||||||
deps[Api.safety],
|
deps[Api.safety],
|
||||||
deps[Api.memory_banks],
|
deps[Api.memory_banks],
|
||||||
|
deps[Api.tool_runtime],
|
||||||
|
deps[Api.tool_groups],
|
||||||
)
|
)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -4,8 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import copy
|
import copy
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
@ -13,16 +13,16 @@ import secrets
|
||||||
import string
|
import string
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import AsyncGenerator, Dict, List, Optional, Tuple
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
AgentTool,
|
AgentToolGroup,
|
||||||
|
AgentToolGroupWithArgs,
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
AgentTurnResponseEvent,
|
AgentTurnResponseEvent,
|
||||||
AgentTurnResponseEventType,
|
AgentTurnResponseEventType,
|
||||||
|
@ -33,25 +33,14 @@ from llama_stack.apis.agents import (
|
||||||
AgentTurnResponseTurnCompletePayload,
|
AgentTurnResponseTurnCompletePayload,
|
||||||
AgentTurnResponseTurnStartPayload,
|
AgentTurnResponseTurnStartPayload,
|
||||||
Attachment,
|
Attachment,
|
||||||
CodeInterpreterToolDefinition,
|
Document,
|
||||||
FunctionCallToolDefinition,
|
|
||||||
InferenceStep,
|
InferenceStep,
|
||||||
MemoryRetrievalStep,
|
|
||||||
MemoryToolDefinition,
|
|
||||||
PhotogenToolDefinition,
|
|
||||||
SearchToolDefinition,
|
|
||||||
ShieldCallStep,
|
ShieldCallStep,
|
||||||
StepType,
|
StepType,
|
||||||
ToolExecutionStep,
|
ToolExecutionStep,
|
||||||
Turn,
|
Turn,
|
||||||
WolframAlphaToolDefinition,
|
|
||||||
)
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
|
||||||
InterleavedContent,
|
|
||||||
TextContentItem,
|
|
||||||
URL,
|
|
||||||
)
|
)
|
||||||
|
from llama_stack.apis.common.content_types import TextContentItem, URL
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionResponseEventType,
|
ChatCompletionResponseEventType,
|
||||||
CompletionMessage,
|
CompletionMessage,
|
||||||
|
@ -62,32 +51,20 @@ from llama_stack.apis.inference import (
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
ToolCallDelta,
|
ToolCallDelta,
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
ToolChoice,
|
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolResponse,
|
ToolResponse,
|
||||||
ToolResponseMessage,
|
ToolResponseMessage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
|
from llama_stack.apis.memory import Memory, MemoryBankDocument
|
||||||
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
|
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
|
||||||
from llama_stack.providers.utils.telemetry import tracing
|
from llama_stack.providers.utils.telemetry import tracing
|
||||||
|
|
||||||
from .persistence import AgentPersistence
|
from .persistence import AgentPersistence
|
||||||
from .rag.context_retriever import generate_rag_query
|
|
||||||
from .safety import SafetyException, ShieldRunnerMixin
|
from .safety import SafetyException, ShieldRunnerMixin
|
||||||
from .tools.base import BaseTool
|
|
||||||
from .tools.builtin import (
|
|
||||||
CodeInterpreterTool,
|
|
||||||
interpret_content_as_attachment,
|
|
||||||
PhotogenTool,
|
|
||||||
SearchTool,
|
|
||||||
WolframAlphaTool,
|
|
||||||
)
|
|
||||||
from .tools.safety import SafeTool
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -98,6 +75,12 @@ def make_random_string(length: int = 8):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||||
|
MEMORY_QUERY_TOOL = "query_memory"
|
||||||
|
WEB_SEARCH_TOOL = "web_search"
|
||||||
|
MEMORY_GROUP = "builtin::memory"
|
||||||
|
|
||||||
|
|
||||||
class ChatAgent(ShieldRunnerMixin):
|
class ChatAgent(ShieldRunnerMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -108,6 +91,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
memory_api: Memory,
|
memory_api: Memory,
|
||||||
memory_banks_api: MemoryBanks,
|
memory_banks_api: MemoryBanks,
|
||||||
safety_api: Safety,
|
safety_api: Safety,
|
||||||
|
tool_runtime_api: ToolRuntime,
|
||||||
|
tool_groups_api: ToolGroups,
|
||||||
persistence_store: KVStore,
|
persistence_store: KVStore,
|
||||||
):
|
):
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
|
@ -118,29 +103,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
self.memory_banks_api = memory_banks_api
|
self.memory_banks_api = memory_banks_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||||
|
self.tool_runtime_api = tool_runtime_api
|
||||||
builtin_tools = []
|
self.tool_groups_api = tool_groups_api
|
||||||
for tool_defn in agent_config.tools:
|
|
||||||
if isinstance(tool_defn, WolframAlphaToolDefinition):
|
|
||||||
tool = WolframAlphaTool(tool_defn.api_key)
|
|
||||||
elif isinstance(tool_defn, SearchToolDefinition):
|
|
||||||
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
|
|
||||||
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
|
||||||
tool = CodeInterpreterTool()
|
|
||||||
elif isinstance(tool_defn, PhotogenToolDefinition):
|
|
||||||
tool = PhotogenTool(dump_dir=self.tempdir)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
builtin_tools.append(
|
|
||||||
SafeTool(
|
|
||||||
tool,
|
|
||||||
safety_api,
|
|
||||||
tool_defn.input_shields,
|
|
||||||
tool_defn.output_shields,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.tools_dict = {t.get_name(): t for t in builtin_tools}
|
|
||||||
|
|
||||||
ShieldRunnerMixin.__init__(
|
ShieldRunnerMixin.__init__(
|
||||||
self,
|
self,
|
||||||
|
@ -228,9 +192,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
session_id=request.session_id,
|
session_id=request.session_id,
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
input_messages=messages,
|
input_messages=messages,
|
||||||
attachments=request.attachments or [],
|
|
||||||
sampling_params=self.agent_config.sampling_params,
|
sampling_params=self.agent_config.sampling_params,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
|
documents=request.documents,
|
||||||
|
toolgroups_for_turn=request.toolgroups,
|
||||||
):
|
):
|
||||||
if isinstance(chunk, CompletionMessage):
|
if isinstance(chunk, CompletionMessage):
|
||||||
log.info(
|
log.info(
|
||||||
|
@ -278,9 +243,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
session_id: str,
|
session_id: str,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
input_messages: List[Message],
|
input_messages: List[Message],
|
||||||
attachments: List[Attachment],
|
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
documents: Optional[List[Document]] = None,
|
||||||
|
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||||
|
@ -297,7 +263,13 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
yield res
|
yield res
|
||||||
|
|
||||||
async for res in self._run(
|
async for res in self._run(
|
||||||
session_id, turn_id, input_messages, attachments, sampling_params, stream
|
session_id,
|
||||||
|
turn_id,
|
||||||
|
input_messages,
|
||||||
|
sampling_params,
|
||||||
|
stream,
|
||||||
|
documents,
|
||||||
|
toolgroups_for_turn,
|
||||||
):
|
):
|
||||||
if isinstance(res, bool):
|
if isinstance(res, bool):
|
||||||
return
|
return
|
||||||
|
@ -353,6 +325,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.shield_call.value,
|
step_type=StepType.shield_call.value,
|
||||||
|
step_id=step_id,
|
||||||
step_details=ShieldCallStep(
|
step_details=ShieldCallStep(
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
|
@ -373,6 +346,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.shield_call.value,
|
step_type=StepType.shield_call.value,
|
||||||
|
step_id=step_id,
|
||||||
step_details=ShieldCallStep(
|
step_details=ShieldCallStep(
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
|
@ -388,73 +362,116 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
session_id: str,
|
session_id: str,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
input_messages: List[Message],
|
input_messages: List[Message],
|
||||||
attachments: List[Attachment],
|
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
documents: Optional[List[Document]] = None,
|
||||||
|
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
enabled_tools = set(t.type for t in self.agent_config.tools)
|
toolgroup_args = {}
|
||||||
need_rag_context = await self._should_retrieve_context(
|
for toolgroup in self.agent_config.toolgroups:
|
||||||
input_messages, attachments
|
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||||
|
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||||
|
if toolgroups_for_turn:
|
||||||
|
for toolgroup in toolgroups_for_turn:
|
||||||
|
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||||
|
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||||
|
|
||||||
|
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||||
|
if documents:
|
||||||
|
await self.handle_documents(
|
||||||
|
session_id, documents, input_messages, tool_defs
|
||||||
)
|
)
|
||||||
if need_rag_context:
|
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
|
||||||
|
memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
|
||||||
|
if memory_tool_group is None:
|
||||||
|
raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
|
||||||
|
with tracing.span(MEMORY_QUERY_TOOL) as span:
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepStartPayload(
|
payload=AgentTurnResponseStepStartPayload(
|
||||||
step_type=StepType.memory_retrieval.value,
|
step_type=StepType.tool_execution.value,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
query_args = {
|
||||||
|
"messages": [msg.content for msg in input_messages],
|
||||||
|
**toolgroup_args.get(memory_tool_group, {}),
|
||||||
|
}
|
||||||
|
|
||||||
# TODO: find older context from the session and either replace it
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
# or append with a sliding window. this is really a very simplistic implementation
|
# if the session has a memory bank id, let the memory tool use it
|
||||||
with tracing.span("retrieve_rag_context") as span:
|
if session_info.memory_bank_id:
|
||||||
rag_context, bank_ids = await self._retrieve_context(
|
if "memory_bank_ids" not in query_args:
|
||||||
session_id, input_messages, attachments
|
query_args["memory_bank_ids"] = []
|
||||||
)
|
query_args["memory_bank_ids"].append(session_info.memory_bank_id)
|
||||||
span.set_attribute(
|
|
||||||
"input", [m.model_dump_json() for m in input_messages]
|
|
||||||
)
|
|
||||||
span.set_attribute("output", rag_context)
|
|
||||||
span.set_attribute("bank_ids", bank_ids)
|
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
step_type=StepType.memory_retrieval.value,
|
step_type=StepType.tool_execution.value,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
step_details=MemoryRetrievalStep(
|
tool_call_delta=ToolCallDelta(
|
||||||
turn_id=turn_id,
|
parse_status=ToolCallParseStatus.success,
|
||||||
step_id=step_id,
|
content=ToolCall(
|
||||||
memory_bank_ids=bank_ids,
|
call_id="",
|
||||||
inserted_context=rag_context or "",
|
tool_name=MEMORY_QUERY_TOOL,
|
||||||
|
arguments={},
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
result = await self.tool_runtime_api.invoke_tool(
|
||||||
|
tool_name=MEMORY_QUERY_TOOL,
|
||||||
|
args=query_args,
|
||||||
|
)
|
||||||
|
|
||||||
if rag_context:
|
yield AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
|
step_type=StepType.tool_execution.value,
|
||||||
|
step_id=step_id,
|
||||||
|
step_details=ToolExecutionStep(
|
||||||
|
step_id=step_id,
|
||||||
|
turn_id=turn_id,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
call_id="",
|
||||||
|
tool_name=MEMORY_QUERY_TOOL,
|
||||||
|
arguments={},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
tool_responses=[
|
||||||
|
ToolResponse(
|
||||||
|
call_id="",
|
||||||
|
tool_name=MEMORY_QUERY_TOOL,
|
||||||
|
content=result.content,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
span.set_attribute(
|
||||||
|
"input", [m.model_dump_json() for m in input_messages]
|
||||||
|
)
|
||||||
|
span.set_attribute("output", result.content)
|
||||||
|
span.set_attribute("error_code", result.error_code)
|
||||||
|
span.set_attribute("error_message", result.error_message)
|
||||||
|
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
|
||||||
|
if result.error_code == 0:
|
||||||
last_message = input_messages[-1]
|
last_message = input_messages[-1]
|
||||||
last_message.context = rag_context
|
last_message.context = result.content
|
||||||
|
|
||||||
elif attachments and AgentTool.code_interpreter.value in enabled_tools:
|
|
||||||
urls = [a.content for a in attachments if isinstance(a.content, URL)]
|
|
||||||
# TODO: we need to migrate URL away from str type
|
|
||||||
pattern = re.compile("^(https?://|file://|data:)")
|
|
||||||
urls += [
|
|
||||||
URL(uri=a.content) for a in attachments if pattern.match(a.content)
|
|
||||||
]
|
|
||||||
msg = await attachment_message(self.tempdir, urls)
|
|
||||||
input_messages.append(msg)
|
|
||||||
|
|
||||||
output_attachments = []
|
output_attachments = []
|
||||||
|
|
||||||
n_iter = 0
|
n_iter = 0
|
||||||
|
# Build a map of custom tools to their definitions for faster lookup
|
||||||
|
client_tools = {}
|
||||||
|
for tool in self.agent_config.client_tools:
|
||||||
|
client_tools[tool.name] = tool
|
||||||
while True:
|
while True:
|
||||||
msg = input_messages[-1]
|
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
|
@ -473,7 +490,11 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
async for chunk in await self.inference_api.chat_completion(
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
tools=self._get_tools(),
|
tools=[
|
||||||
|
tool
|
||||||
|
for tool in tool_defs.values()
|
||||||
|
if tool_to_group.get(tool.tool_name, None) != MEMORY_GROUP
|
||||||
|
],
|
||||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
@ -572,9 +593,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
||||||
if len(output_attachments) > 0:
|
if len(output_attachments) > 0:
|
||||||
if isinstance(message.content, list):
|
if isinstance(message.content, list):
|
||||||
message.content += attachments
|
message.content += output_attachments
|
||||||
else:
|
else:
|
||||||
message.content = [message.content] + attachments
|
message.content = [message.content] + output_attachments
|
||||||
yield message
|
yield message
|
||||||
else:
|
else:
|
||||||
log.info(f"Partial message: {str(message)}")
|
log.info(f"Partial message: {str(message)}")
|
||||||
|
@ -582,9 +603,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
else:
|
else:
|
||||||
log.info(f"{str(message)}")
|
log.info(f"{str(message)}")
|
||||||
tool_call = message.tool_calls[0]
|
tool_call = message.tool_calls[0]
|
||||||
|
if tool_call.tool_name in client_tools:
|
||||||
name = tool_call.tool_name
|
|
||||||
if not isinstance(name, BuiltinTool) or name not in enabled_tools:
|
|
||||||
yield message
|
yield message
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -607,16 +626,22 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tool_name = tool_call.tool_name
|
||||||
|
if isinstance(tool_name, BuiltinTool):
|
||||||
|
tool_name = tool_name.value
|
||||||
with tracing.span(
|
with tracing.span(
|
||||||
"tool_execution",
|
"tool_execution",
|
||||||
{
|
{
|
||||||
"tool_name": tool_call.tool_name,
|
"tool_name": tool_name,
|
||||||
"input": message.model_dump_json(),
|
"input": message.model_dump_json(),
|
||||||
},
|
},
|
||||||
) as span:
|
) as span:
|
||||||
result_messages = await execute_tool_call_maybe(
|
result_messages = await execute_tool_call_maybe(
|
||||||
self.tools_dict,
|
self.tool_runtime_api,
|
||||||
|
session_id,
|
||||||
[message],
|
[message],
|
||||||
|
toolgroup_args,
|
||||||
|
tool_to_group,
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
len(result_messages) == 1
|
len(result_messages) == 1
|
||||||
|
@ -628,6 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.tool_execution.value,
|
step_type=StepType.tool_execution.value,
|
||||||
|
step_id=step_id,
|
||||||
step_details=ToolExecutionStep(
|
step_details=ToolExecutionStep(
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
|
@ -647,7 +673,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||||
# but that needs a lot more refactoring of Tool code potentially
|
# but that needs a lot more refactoring of Tool code potentially
|
||||||
|
|
||||||
if out_attachment := interpret_content_as_attachment(
|
if out_attachment := _interpret_content_as_attachment(
|
||||||
result_message.content
|
result_message.content
|
||||||
):
|
):
|
||||||
# NOTE: when we push this message back to the model, the model may ignore the
|
# NOTE: when we push this message back to the model, the model may ignore the
|
||||||
|
@ -659,6 +685,150 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
n_iter += 1
|
n_iter += 1
|
||||||
|
|
||||||
|
async def _get_tool_defs(
|
||||||
|
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||||
|
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
|
||||||
|
# Determine which tools to include
|
||||||
|
agent_config_toolgroups = set(
|
||||||
|
(
|
||||||
|
toolgroup.name
|
||||||
|
if isinstance(toolgroup, AgentToolGroupWithArgs)
|
||||||
|
else toolgroup
|
||||||
|
)
|
||||||
|
for toolgroup in self.agent_config.toolgroups
|
||||||
|
)
|
||||||
|
toolgroups_for_turn_set = (
|
||||||
|
agent_config_toolgroups
|
||||||
|
if toolgroups_for_turn is None
|
||||||
|
else {
|
||||||
|
(
|
||||||
|
toolgroup.name
|
||||||
|
if isinstance(toolgroup, AgentToolGroupWithArgs)
|
||||||
|
else toolgroup
|
||||||
|
)
|
||||||
|
for toolgroup in toolgroups_for_turn
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_def_map = {}
|
||||||
|
tool_to_group = {}
|
||||||
|
|
||||||
|
for tool_def in self.agent_config.client_tools:
|
||||||
|
if tool_def_map.get(tool_def.name, None):
|
||||||
|
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||||
|
tool_def_map[tool_def.name] = ToolDefinition(
|
||||||
|
tool_name=tool_def.name,
|
||||||
|
description=tool_def.description,
|
||||||
|
parameters={
|
||||||
|
param.name: ToolParamDefinition(
|
||||||
|
param_type=param.parameter_type,
|
||||||
|
description=param.description,
|
||||||
|
required=param.required,
|
||||||
|
default=param.default,
|
||||||
|
)
|
||||||
|
for param in tool_def.parameters
|
||||||
|
},
|
||||||
|
)
|
||||||
|
tool_to_group[tool_def.name] = "__client_tools__"
|
||||||
|
for toolgroup_name in agent_config_toolgroups:
|
||||||
|
if toolgroup_name not in toolgroups_for_turn_set:
|
||||||
|
continue
|
||||||
|
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
|
||||||
|
for tool_def in tools:
|
||||||
|
if (
|
||||||
|
toolgroup_name.startswith("builtin")
|
||||||
|
and toolgroup_name != MEMORY_GROUP
|
||||||
|
):
|
||||||
|
tool_name = tool_def.identifier
|
||||||
|
built_in_type = BuiltinTool.brave_search
|
||||||
|
if tool_name == "web_search":
|
||||||
|
built_in_type = BuiltinTool.brave_search
|
||||||
|
else:
|
||||||
|
built_in_type = BuiltinTool(tool_name)
|
||||||
|
|
||||||
|
if tool_def_map.get(built_in_type, None):
|
||||||
|
raise ValueError(f"Tool {built_in_type} already exists")
|
||||||
|
|
||||||
|
tool_def_map[built_in_type] = ToolDefinition(
|
||||||
|
tool_name=built_in_type
|
||||||
|
)
|
||||||
|
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
||||||
|
continue
|
||||||
|
|
||||||
|
if tool_def_map.get(tool_def.identifier, None):
|
||||||
|
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||||
|
tool_def_map[tool_def.identifier] = ToolDefinition(
|
||||||
|
tool_name=tool_def.identifier,
|
||||||
|
description=tool_def.description,
|
||||||
|
parameters={
|
||||||
|
param.name: ToolParamDefinition(
|
||||||
|
param_type=param.parameter_type,
|
||||||
|
description=param.description,
|
||||||
|
required=param.required,
|
||||||
|
default=param.default,
|
||||||
|
)
|
||||||
|
for param in tool_def.parameters
|
||||||
|
},
|
||||||
|
)
|
||||||
|
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||||
|
|
||||||
|
return tool_def_map, tool_to_group
|
||||||
|
|
||||||
|
async def handle_documents(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
documents: List[Document],
|
||||||
|
input_messages: List[Message],
|
||||||
|
tool_defs: Dict[str, ToolDefinition],
|
||||||
|
) -> None:
|
||||||
|
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
||||||
|
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
|
||||||
|
content_items = []
|
||||||
|
url_items = []
|
||||||
|
pattern = re.compile("^(https?://|file://|data:)")
|
||||||
|
for d in documents:
|
||||||
|
if isinstance(d.content, URL):
|
||||||
|
url_items.append(d.content)
|
||||||
|
elif pattern.match(d.content):
|
||||||
|
url_items.append(URL(uri=d.content))
|
||||||
|
else:
|
||||||
|
content_items.append(d)
|
||||||
|
|
||||||
|
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
|
||||||
|
if code_interpreter_tool:
|
||||||
|
for c in content_items:
|
||||||
|
temp_file_path = os.path.join(
|
||||||
|
self.tempdir, f"{make_random_string()}.txt"
|
||||||
|
)
|
||||||
|
with open(temp_file_path, "w") as temp_file:
|
||||||
|
temp_file.write(c.content)
|
||||||
|
url_items.append(URL(uri=f"file://{temp_file_path}"))
|
||||||
|
|
||||||
|
if memory_tool and code_interpreter_tool:
|
||||||
|
# if both memory and code_interpreter are available, we download the URLs
|
||||||
|
# and attach the data to the last message.
|
||||||
|
msg = await attachment_message(self.tempdir, url_items)
|
||||||
|
input_messages.append(msg)
|
||||||
|
# Since memory is present, add all the data to the memory bank
|
||||||
|
await self.add_to_session_memory_bank(session_id, documents)
|
||||||
|
elif code_interpreter_tool:
|
||||||
|
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||||
|
# and attach the path to them as a message to inference with the
|
||||||
|
# assumption that the model invokes the code_interpreter tool with the path
|
||||||
|
msg = await attachment_message(self.tempdir, url_items)
|
||||||
|
input_messages.append(msg)
|
||||||
|
elif memory_tool:
|
||||||
|
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||||
|
await self.add_to_session_memory_bank(session_id, documents)
|
||||||
|
else:
|
||||||
|
# if no memory or code_interpreter tool is available,
|
||||||
|
# we try to load the data from the URLs and content items as a message to inference
|
||||||
|
# and add it to the last message's context
|
||||||
|
input_messages[-1].context = "\n".join(
|
||||||
|
[doc.content for doc in content_items]
|
||||||
|
+ await load_data_from_urls(url_items)
|
||||||
|
)
|
||||||
|
|
||||||
async def _ensure_memory_bank(self, session_id: str) -> str:
|
async def _ensure_memory_bank(self, session_id: str) -> str:
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
if session_info is None:
|
if session_info is None:
|
||||||
|
@ -679,41 +849,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
return bank_id
|
return bank_id
|
||||||
|
|
||||||
async def _should_retrieve_context(
|
async def add_to_session_memory_bank(
|
||||||
self, messages: List[Message], attachments: List[Attachment]
|
self, session_id: str, data: List[Document]
|
||||||
) -> bool:
|
) -> None:
|
||||||
enabled_tools = set(t.type for t in self.agent_config.tools)
|
|
||||||
if attachments:
|
|
||||||
if (
|
|
||||||
AgentTool.code_interpreter.value in enabled_tools
|
|
||||||
and self.agent_config.tool_choice == ToolChoice.required
|
|
||||||
):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return AgentTool.memory.value in enabled_tools
|
|
||||||
|
|
||||||
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
|
|
||||||
for t in self.agent_config.tools:
|
|
||||||
if t.type == AgentTool.memory.value:
|
|
||||||
return t
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _retrieve_context(
|
|
||||||
self, session_id: str, messages: List[Message], attachments: List[Attachment]
|
|
||||||
) -> Tuple[Optional[InterleavedContent], List[int]]: # (rag_context, bank_ids)
|
|
||||||
bank_ids = []
|
|
||||||
|
|
||||||
memory = self._memory_tool_definition()
|
|
||||||
assert memory is not None, "Memory tool not configured"
|
|
||||||
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
|
|
||||||
|
|
||||||
if attachments:
|
|
||||||
bank_id = await self._ensure_memory_bank(session_id)
|
bank_id = await self._ensure_memory_bank(session_id)
|
||||||
bank_ids.append(bank_id)
|
|
||||||
|
|
||||||
documents = [
|
documents = [
|
||||||
MemoryBankDocument(
|
MemoryBankDocument(
|
||||||
document_id=str(uuid.uuid4()),
|
document_id=str(uuid.uuid4()),
|
||||||
|
@ -721,87 +860,28 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
mime_type=a.mime_type,
|
mime_type=a.mime_type,
|
||||||
metadata={},
|
metadata={},
|
||||||
)
|
)
|
||||||
for a in attachments
|
for a in data
|
||||||
]
|
]
|
||||||
with tracing.span("insert_documents"):
|
await self.memory_api.insert_documents(
|
||||||
await self.memory_api.insert_documents(bank_id, documents)
|
|
||||||
else:
|
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
|
||||||
if session_info.memory_bank_id:
|
|
||||||
bank_ids.append(session_info.memory_bank_id)
|
|
||||||
|
|
||||||
if not bank_ids:
|
|
||||||
# this can happen if the per-session memory bank is not yet populated
|
|
||||||
# (i.e., no prior turns uploaded an Attachment)
|
|
||||||
return None, []
|
|
||||||
|
|
||||||
query = await generate_rag_query(
|
|
||||||
memory.query_generator_config, messages, inference_api=self.inference_api
|
|
||||||
)
|
|
||||||
tasks = [
|
|
||||||
self.memory_api.query_documents(
|
|
||||||
bank_id=bank_id,
|
bank_id=bank_id,
|
||||||
query=query,
|
documents=documents,
|
||||||
params={
|
|
||||||
"max_chunks": 5,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
for bank_id in bank_ids
|
|
||||||
]
|
|
||||||
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
|
|
||||||
chunks = [c for r in results for c in r.chunks]
|
|
||||||
scores = [s for r in results for s in r.scores]
|
|
||||||
|
|
||||||
if not chunks:
|
|
||||||
return None, bank_ids
|
|
||||||
|
|
||||||
# sort by score
|
|
||||||
chunks, scores = zip(
|
|
||||||
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tokens = 0
|
|
||||||
picked = []
|
|
||||||
for c in chunks[: memory.max_chunks]:
|
|
||||||
tokens += c.token_count
|
|
||||||
if tokens > memory.max_tokens_in_context:
|
|
||||||
log.error(
|
|
||||||
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
|
|
||||||
)
|
|
||||||
break
|
|
||||||
picked.append(f"id:{c.document_id}; content:{c.content}")
|
|
||||||
|
|
||||||
return (
|
async def load_data_from_urls(urls: List[URL]) -> List[str]:
|
||||||
concat_interleaved_content(
|
data = []
|
||||||
[
|
for url in urls:
|
||||||
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
uri = url.uri
|
||||||
*picked,
|
if uri.startswith("file://"):
|
||||||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
filepath = uri[len("file://") :]
|
||||||
]
|
with open(filepath, "r") as f:
|
||||||
),
|
data.append(f.read())
|
||||||
bank_ids,
|
elif uri.startswith("http"):
|
||||||
)
|
async with httpx.AsyncClient() as client:
|
||||||
|
r = await client.get(uri)
|
||||||
def _get_tools(self) -> List[ToolDefinition]:
|
resp = r.text
|
||||||
ret = []
|
data.append(resp)
|
||||||
for t in self.agent_config.tools:
|
return data
|
||||||
if isinstance(t, SearchToolDefinition):
|
|
||||||
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
|
|
||||||
elif isinstance(t, WolframAlphaToolDefinition):
|
|
||||||
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
|
|
||||||
elif isinstance(t, PhotogenToolDefinition):
|
|
||||||
ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
|
|
||||||
elif isinstance(t, CodeInterpreterToolDefinition):
|
|
||||||
ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
|
|
||||||
elif isinstance(t, FunctionCallToolDefinition):
|
|
||||||
ret.append(
|
|
||||||
ToolDefinition(
|
|
||||||
tool_name=t.function_name,
|
|
||||||
description=t.description,
|
|
||||||
parameters=t.parameters,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
|
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
|
||||||
|
@ -839,7 +919,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
||||||
|
|
||||||
|
|
||||||
async def execute_tool_call_maybe(
|
async def execute_tool_call_maybe(
|
||||||
tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
|
tool_runtime_api: ToolRuntime,
|
||||||
|
session_id: str,
|
||||||
|
messages: List[CompletionMessage],
|
||||||
|
toolgroup_args: Dict[str, Dict[str, Any]],
|
||||||
|
tool_to_group: Dict[str, str],
|
||||||
) -> List[ToolResponseMessage]:
|
) -> List[ToolResponseMessage]:
|
||||||
# While Tools.run interface takes a list of messages,
|
# While Tools.run interface takes a list of messages,
|
||||||
# All tools currently only run on a single message
|
# All tools currently only run on a single message
|
||||||
|
@ -851,11 +935,45 @@ async def execute_tool_call_maybe(
|
||||||
|
|
||||||
tool_call = message.tool_calls[0]
|
tool_call = message.tool_calls[0]
|
||||||
name = tool_call.tool_name
|
name = tool_call.tool_name
|
||||||
assert isinstance(name, BuiltinTool)
|
group_name = tool_to_group.get(name, None)
|
||||||
|
if group_name is None:
|
||||||
|
raise ValueError(f"Tool {name} not found in any tool group")
|
||||||
|
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||||
|
tool_call_args = tool_call.arguments
|
||||||
|
tool_call_args.update(toolgroup_args.get(group_name, {}))
|
||||||
|
if isinstance(name, BuiltinTool):
|
||||||
|
if name == BuiltinTool.brave_search:
|
||||||
|
name = WEB_SEARCH_TOOL
|
||||||
|
else:
|
||||||
name = name.value
|
name = name.value
|
||||||
|
|
||||||
assert name in tools_dict, f"Tool {name} not found"
|
result = await tool_runtime_api.invoke_tool(
|
||||||
tool = tools_dict[name]
|
tool_name=name,
|
||||||
result_messages = await tool.run(messages)
|
args=dict(
|
||||||
return result_messages
|
session_id=session_id,
|
||||||
|
**tool_call_args,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
ToolResponseMessage(
|
||||||
|
call_id=tool_call.call_id,
|
||||||
|
tool_name=tool_call.tool_name,
|
||||||
|
content=result.content,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _interpret_content_as_attachment(
|
||||||
|
content: str,
|
||||||
|
) -> Optional[Attachment]:
|
||||||
|
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
|
||||||
|
if match:
|
||||||
|
snippet = match.group(1)
|
||||||
|
data = json.loads(snippet)
|
||||||
|
return Attachment(
|
||||||
|
url=URL(uri="file://" + data["filepath"]),
|
||||||
|
mime_type=data["mimetype"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
|
@ -19,17 +19,17 @@ from llama_stack.apis.agents import (
|
||||||
Agents,
|
Agents,
|
||||||
AgentSessionCreateResponse,
|
AgentSessionCreateResponse,
|
||||||
AgentStepResponse,
|
AgentStepResponse,
|
||||||
|
AgentToolGroup,
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
Attachment,
|
Document,
|
||||||
Session,
|
Session,
|
||||||
Turn,
|
Turn,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
|
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
|
||||||
from llama_stack.apis.memory import Memory
|
from llama_stack.apis.memory import Memory
|
||||||
from llama_stack.apis.memory_banks import MemoryBanks
|
from llama_stack.apis.memory_banks import MemoryBanks
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||||
|
|
||||||
from .agent_instance import ChatAgent
|
from .agent_instance import ChatAgent
|
||||||
|
@ -47,12 +47,16 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
memory_api: Memory,
|
memory_api: Memory,
|
||||||
safety_api: Safety,
|
safety_api: Safety,
|
||||||
memory_banks_api: MemoryBanks,
|
memory_banks_api: MemoryBanks,
|
||||||
|
tool_runtime_api: ToolRuntime,
|
||||||
|
tool_groups_api: ToolGroups,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
self.memory_api = memory_api
|
self.memory_api = memory_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.memory_banks_api = memory_banks_api
|
self.memory_banks_api = memory_banks_api
|
||||||
|
self.tool_runtime_api = tool_runtime_api
|
||||||
|
self.tool_groups_api = tool_groups_api
|
||||||
|
|
||||||
self.in_memory_store = InmemoryKVStoreImpl()
|
self.in_memory_store = InmemoryKVStoreImpl()
|
||||||
self.tempdir = tempfile.mkdtemp()
|
self.tempdir = tempfile.mkdtemp()
|
||||||
|
@ -112,6 +116,8 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
safety_api=self.safety_api,
|
safety_api=self.safety_api,
|
||||||
memory_api=self.memory_api,
|
memory_api=self.memory_api,
|
||||||
memory_banks_api=self.memory_banks_api,
|
memory_banks_api=self.memory_banks_api,
|
||||||
|
tool_runtime_api=self.tool_runtime_api,
|
||||||
|
tool_groups_api=self.tool_groups_api,
|
||||||
persistence_store=(
|
persistence_store=(
|
||||||
self.persistence_store
|
self.persistence_store
|
||||||
if agent_config.enable_session_persistence
|
if agent_config.enable_session_persistence
|
||||||
|
@ -141,15 +147,17 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
ToolResponseMessage,
|
ToolResponseMessage,
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
attachments: Optional[List[Attachment]] = None,
|
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||||
|
documents: Optional[List[Document]] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
request = AgentTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
attachments=attachments,
|
|
||||||
stream=True,
|
stream=True,
|
||||||
|
toolgroups=toolgroups,
|
||||||
|
documents=documents,
|
||||||
)
|
)
|
||||||
if stream:
|
if stream:
|
||||||
return self._create_agent_turn_streaming(request)
|
return self._create_agent_turn_streaming(request)
|
||||||
|
|
|
@ -8,13 +8,11 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.agents import Turn
|
from llama_stack.apis.agents import Turn
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
|
@ -1,93 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import (
|
|
||||||
Attachment,
|
|
||||||
BuiltinTool,
|
|
||||||
CompletionMessage,
|
|
||||||
StopReason,
|
|
||||||
ToolCall,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..tools.builtin import CodeInterpreterTool
|
|
||||||
|
|
||||||
|
|
||||||
class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase):
|
|
||||||
async def test_matplotlib(self):
|
|
||||||
tool = CodeInterpreterTool()
|
|
||||||
code = """
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
x = np.array([1, 1])
|
|
||||||
y = np.array([0, 10])
|
|
||||||
|
|
||||||
plt.plot(x, y)
|
|
||||||
plt.title('x = 1')
|
|
||||||
plt.xlabel('x')
|
|
||||||
plt.ylabel('y')
|
|
||||||
plt.grid(True)
|
|
||||||
plt.axvline(x=1, color='r')
|
|
||||||
plt.show()
|
|
||||||
"""
|
|
||||||
message = CompletionMessage(
|
|
||||||
role="assistant",
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
ToolCall(
|
|
||||||
call_id="call_id",
|
|
||||||
tool_name=BuiltinTool.code_interpreter,
|
|
||||||
arguments={"code": code},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
stop_reason=StopReason.end_of_message,
|
|
||||||
)
|
|
||||||
ret = await tool.run([message])
|
|
||||||
|
|
||||||
self.assertEqual(len(ret), 1)
|
|
||||||
|
|
||||||
output = ret[0].content
|
|
||||||
self.assertIsInstance(output, Attachment)
|
|
||||||
self.assertEqual(output.mime_type, "image/png")
|
|
||||||
|
|
||||||
async def test_path_unlink(self):
|
|
||||||
tool = CodeInterpreterTool()
|
|
||||||
code = """
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
dpath = Path(os.environ["MPLCONFIGDIR"])
|
|
||||||
with open(dpath / "test", "w") as f:
|
|
||||||
f.write("hello")
|
|
||||||
|
|
||||||
Path(dpath / "test").unlink()
|
|
||||||
print("_OK_")
|
|
||||||
"""
|
|
||||||
message = CompletionMessage(
|
|
||||||
role="assistant",
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
ToolCall(
|
|
||||||
call_id="call_id",
|
|
||||||
tool_name=BuiltinTool.code_interpreter,
|
|
||||||
arguments={"code": code},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
stop_reason=StopReason.end_of_message,
|
|
||||||
)
|
|
||||||
ret = await tool.run([message])
|
|
||||||
|
|
||||||
self.assertEqual(len(ret), 1)
|
|
||||||
|
|
||||||
output = ret[0].content
|
|
||||||
self.assertTrue("_OK_" in output)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
|
@ -4,21 +4,26 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import tempfile
|
||||||
from typing import AsyncIterator, List, Optional, Union
|
from typing import AsyncIterator, List, Optional, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
|
AgentToolGroupWithArgs,
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
AgentTurnResponseTurnCompletePayload,
|
AgentTurnResponseTurnCompletePayload,
|
||||||
|
StepType,
|
||||||
)
|
)
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionResponseEvent,
|
ChatCompletionResponseEvent,
|
||||||
ChatCompletionResponseStreamChunk,
|
ChatCompletionResponseStreamChunk,
|
||||||
CompletionMessage,
|
CompletionMessage,
|
||||||
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
@ -27,13 +32,24 @@ from llama_stack.apis.inference import (
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.memory import MemoryBank
|
from llama_stack.apis.memory import MemoryBank
|
||||||
|
from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank
|
||||||
from llama_stack.apis.safety import RunShieldResponse
|
from llama_stack.apis.safety import RunShieldResponse
|
||||||
|
from llama_stack.apis.tools import (
|
||||||
from ..agents import (
|
Tool,
|
||||||
AGENT_INSTANCES_BY_ID,
|
ToolDef,
|
||||||
MetaReferenceAgentsImpl,
|
ToolGroup,
|
||||||
MetaReferenceInferenceConfig,
|
ToolHost,
|
||||||
|
ToolInvocationResult,
|
||||||
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
|
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||||
|
MEMORY_QUERY_TOOL,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.inline.agents.meta_reference.agents import (
|
||||||
|
MetaReferenceAgentsImpl,
|
||||||
|
MetaReferenceAgentsImplConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
class MockInferenceAPI:
|
class MockInferenceAPI:
|
||||||
|
@@ -48,10 +64,10 @@ class MockInferenceAPI:
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncIterator[
-        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:
-        if stream:
+        async def stream_response():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type="start",
@@ -65,19 +81,7 @@ class MockInferenceAPI:
                     delta="AI is a fascinating field...",
                 )
             )
-            # yield ChatCompletionResponseStreamChunk(
-            #     event=ChatCompletionResponseEvent(
-            #         event_type="progress",
-            #         delta=ToolCallDelta(
-            #             content=ToolCall(
-            #                 call_id="123",
-            #                 tool_name=BuiltinTool.brave_search.value,
-            #                 arguments={"query": "AI history"},
-            #             ),
-            #             parse_status="success",
-            #         ),
-            #     )
-            # )
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type="complete",
@@ -85,12 +89,17 @@ class MockInferenceAPI:
                     stop_reason="end_of_turn",
                 )
             )
 
+        if stream:
+            return stream_response()
         else:
-            yield ChatCompletionResponse(
+            return ChatCompletionResponse(
                 completion_message=CompletionMessage(
-                    role="assistant", content="Mock response", stop_reason="end_of_turn"
+                    role="assistant",
+                    content="Mock response",
+                    stop_reason="end_of_turn",
                 ),
-                logprobs=[0.1, 0.2, 0.3] if logprobs else None,
+                logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
             )
@@ -165,6 +174,98 @@ class MockMemoryAPI:
         self.documents[bank_id].pop(doc_id, None)
 
 
+class MockToolGroupsAPI:
+    async def register_tool_group(
+        self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None
+    ) -> None:
+        pass
+
+    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
+        return ToolGroup(
+            identifier=toolgroup_id,
+            provider_resource_id=toolgroup_id,
+        )
+
+    async def list_tool_groups(self) -> List[ToolGroup]:
+        return []
+
+    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
+        if tool_group_id == MEMORY_TOOLGROUP:
+            return [
+                Tool(
+                    identifier=MEMORY_QUERY_TOOL,
+                    provider_resource_id=MEMORY_QUERY_TOOL,
+                    toolgroup_id=MEMORY_TOOLGROUP,
+                    tool_host=ToolHost.client,
+                    description="Mock tool",
+                    provider_id="builtin::memory",
+                    parameters=[],
+                )
+            ]
+        if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
+            return [
+                Tool(
+                    identifier="code_interpreter",
+                    provider_resource_id="code_interpreter",
+                    toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
+                    tool_host=ToolHost.client,
+                    description="Mock tool",
+                    provider_id="builtin::code_interpreter",
+                    parameters=[],
+                )
+            ]
+        return []
+
+    async def get_tool(self, tool_name: str) -> Tool:
+        return Tool(
+            identifier=tool_name,
+            provider_resource_id=tool_name,
+            toolgroup_id="mock_group",
+            tool_host=ToolHost.client,
+            description="Mock tool",
+            provider_id="mock_provider",
+            parameters=[],
+        )
+
+    async def unregister_tool_group(self, tool_group_id: str) -> None:
+        pass
+
+
+class MockToolRuntimeAPI:
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        return []
+
+    async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
+        return ToolInvocationResult(content={"result": "Mock tool result"})
+
+
+class MockMemoryBanksAPI:
+    async def list_memory_banks(self) -> List[MemoryBank]:
+        return []
+
+    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
+        return None
+
+    async def register_memory_bank(
+        self,
+        memory_bank_id: str,
+        params: BankParams,
+        provider_id: Optional[str] = None,
+        provider_memory_bank_id: Optional[str] = None,
+    ) -> MemoryBank:
+        return VectorMemoryBank(
+            identifier=memory_bank_id,
+            provider_resource_id=provider_memory_bank_id or memory_bank_id,
+            embedding_model="mock_model",
+            chunk_size_in_tokens=512,
+        )
+
+    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
+        pass
+
+
 @pytest.fixture
 def mock_inference_api():
     return MockInferenceAPI()
@@ -181,64 +282,107 @@ def mock_memory_api():
 
 
 @pytest.fixture
-async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
+def mock_tool_groups_api():
+    return MockToolGroupsAPI()
+
+
+@pytest.fixture
+def mock_tool_runtime_api():
+    return MockToolRuntimeAPI()
+
+
+@pytest.fixture
+def mock_memory_banks_api():
+    return MockMemoryBanksAPI()
+
+
+@pytest.fixture
+async def get_agents_impl(
+    mock_inference_api,
+    mock_safety_api,
+    mock_memory_api,
+    mock_memory_banks_api,
+    mock_tool_runtime_api,
+    mock_tool_groups_api,
+):
+    sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     impl = MetaReferenceAgentsImpl(
-        config=MetaReferenceInferenceConfig(),
+        config=MetaReferenceAgentsImplConfig(
+            persistence_store=SqliteKVStoreConfig(
+                db_name=sqlite_file.name,
+            ),
+        ),
         inference_api=mock_inference_api,
         safety_api=mock_safety_api,
         memory_api=mock_memory_api,
+        memory_banks_api=mock_memory_banks_api,
+        tool_runtime_api=mock_tool_runtime_api,
+        tool_groups_api=mock_tool_groups_api,
     )
     await impl.initialize()
+    return impl
+
+
+@pytest.fixture
+async def get_chat_agent(get_agents_impl):
+    impl = await get_agents_impl
     agent_config = AgentConfig(
         model="test_model",
         instructions="You are a helpful assistant.",
-        sampling_params=SamplingParams(),
-        tools=[
-            # SearchToolDefinition(
-            #     name="brave_search",
-            #     api_key="test_key",
-            # ),
-        ],
+        toolgroups=[],
         tool_choice=ToolChoice.auto,
         enable_session_persistence=False,
-        input_shields=[],
-        output_shields=[],
+        input_shields=["test_shield"],
     )
     response = await impl.create_agent(agent_config)
-    agent = AGENT_INSTANCES_BY_ID[response.agent_id]
-    return agent
+    return await impl.get_agent(response.agent_id)
+
+
+MEMORY_TOOLGROUP = "builtin::memory"
+CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
+
+
+@pytest.fixture
+async def get_chat_agent_with_tools(get_agents_impl, request):
+    impl = await get_agents_impl
+    toolgroups = request.param
+    agent_config = AgentConfig(
+        model="test_model",
+        instructions="You are a helpful assistant.",
+        toolgroups=toolgroups,
+        tool_choice=ToolChoice.auto,
+        enable_session_persistence=False,
+        input_shields=["test_shield"],
+    )
+    response = await impl.create_agent(agent_config)
+    return await impl.get_agent(response.agent_id)
 
 
 @pytest.mark.asyncio
-async def test_chat_agent_create_session(chat_agent):
-    session = chat_agent.create_session("Test Session")
-    assert session.session_name == "Test Session"
-    assert session.turns == []
-    assert session.session_id in chat_agent.sessions
-
-
-@pytest.mark.asyncio
-async def test_chat_agent_create_and_execute_turn(chat_agent):
-    session = chat_agent.create_session("Test Session")
+async def test_chat_agent_create_and_execute_turn(get_chat_agent):
+    chat_agent = await get_chat_agent
+    session_id = await chat_agent.create_session("Test Session")
     request = AgentTurnCreateRequest(
-        agent_id="random",
-        session_id=session.session_id,
+        agent_id=chat_agent.agent_id,
+        session_id=session_id,
         messages=[UserMessage(content="Hello")],
+        stream=True,
     )
 
     responses = []
     async for response in chat_agent.create_and_execute_turn(request):
         responses.append(response)
 
-    print(responses)
     assert len(responses) > 0
-    assert len(responses) == 4  # TurnStart, StepStart, StepComplete, TurnComplete
+    assert (
+        len(responses) == 7
+    )  # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
     assert responses[0].event.payload.turn_id is not None
 
 
 @pytest.mark.asyncio
-async def test_run_multiple_shields_wrapper(chat_agent):
+async def test_run_multiple_shields_wrapper(get_chat_agent):
+    chat_agent = await get_chat_agent
     messages = [UserMessage(content="Test message")]
     shields = ["test_shield"]
 
@@ -254,69 +398,95 @@ async def test_run_multiple_shields_wrapper(chat_agent):
 
     assert len(responses) == 2  # StepStart, StepComplete
     assert responses[0].event.payload.step_type.value == "shield_call"
-    assert not responses[1].event.payload.step_details.response.is_violation
+    assert not responses[1].event.payload.step_details.violation
 
 
 @pytest.mark.asyncio
-@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily")
-async def test_chat_agent_complex_turn(chat_agent):
-    # Setup
-    session = chat_agent.create_session("Test Session")
+async def test_chat_agent_complex_turn(get_chat_agent):
+    chat_agent = await get_chat_agent
+    session_id = await chat_agent.create_session("Test Session")
     request = AgentTurnCreateRequest(
-        agent_id="random",
-        session_id=session.session_id,
+        agent_id=chat_agent.agent_id,
+        session_id=session_id,
         messages=[UserMessage(content="Tell me about AI and then use a tool.")],
         stream=True,
     )
 
-    # Execute the turn
     responses = []
     async for response in chat_agent.create_and_execute_turn(request):
         responses.append(response)
 
-    # Assertions
     assert len(responses) > 0
 
-    # Check for the presence of different step types
     step_types = [
         response.event.payload.step_type
         for response in responses
         if hasattr(response.event.payload, "step_type")
     ]
 
-    assert "shield_call" in step_types, "Shield call step is missing"
-    assert "inference" in step_types, "Inference step is missing"
-    assert "tool_execution" in step_types, "Tool execution step is missing"
+    assert StepType.shield_call in step_types, "Shield call step is missing"
+    assert StepType.inference in step_types, "Inference step is missing"
 
-    # Check for the presence of start and complete events
     event_types = [
         response.event.payload.event_type
         for response in responses
         if hasattr(response.event.payload, "event_type")
     ]
-    assert "start" in event_types, "Start event is missing"
-    assert "complete" in event_types, "Complete event is missing"
+    assert "turn_start" in event_types, "Start event is missing"
+    assert "turn_complete" in event_types, "Complete event is missing"
 
-    # Check for the presence of tool call
-    tool_calls = [
-        response.event.payload.tool_call
-        for response in responses
-        if hasattr(response.event.payload, "tool_call")
-    ]
-    assert any(
-        tool_call
-        for tool_call in tool_calls
-        if tool_call and tool_call.content.get("name") == "memory"
-    ), "Memory tool call is missing"
-
-    # Check for the final turn complete event
     assert any(
         isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
         for response in responses
     ), "Turn complete event is missing"
+    turn_complete_payload = next(
+        response.event.payload
+        for response in responses
+        if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
+    )
+    turn = turn_complete_payload.turn
+    assert turn.input_messages == request.messages, "Input messages do not match"
 
-    # Verify the turn was added to the session
-    assert len(session.turns) == 1, "Turn was not added to the session"
-    assert (
-        session.turns[0].input_messages == request.messages
-    ), "Input messages do not match"
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "toolgroups, expected_memory, expected_code_interpreter",
+    [
+        ([], False, False),  # no tools
+        ([MEMORY_TOOLGROUP], True, False),  # memory only
+        ([CODE_INTERPRETER_TOOLGROUP], False, True),  # code interpreter only
+        ([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True),  # all tools
+    ],
+)
+async def test_chat_agent_tools(
+    get_agents_impl, toolgroups, expected_memory, expected_code_interpreter
+):
+    impl = await get_agents_impl
+    agent_config = AgentConfig(
+        model="test_model",
+        instructions="You are a helpful assistant.",
+        toolgroups=toolgroups,
+        tool_choice=ToolChoice.auto,
+        enable_session_persistence=False,
+        input_shields=["test_shield"],
+    )
+    response = await impl.create_agent(agent_config)
+    chat_agent = await impl.get_agent(response.agent_id)
+
+    tool_defs, _ = await chat_agent._get_tool_defs()
+    if expected_memory:
+        assert MEMORY_QUERY_TOOL in tool_defs
+    if expected_code_interpreter:
+        assert BuiltinTool.code_interpreter in tool_defs
+    if expected_memory and expected_code_interpreter:
+        # override the tools for turn
+        new_tool_defs, _ = await chat_agent._get_tool_defs(
+            toolgroups_for_turn=[
+                AgentToolGroupWithArgs(
+                    name=MEMORY_TOOLGROUP,
+                    args={"memory_banks": ["test_memory_bank"]},
+                )
+            ]
+        )
+        assert MEMORY_QUERY_TOOL in new_tool_defs
+        assert BuiltinTool.code_interpreter not in new_tool_defs
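Note: the get_chat_agent_with_tools fixture added above consumes request.param, but no test in this hunk drives it yet. A minimal sketch of how such a fixture is typically exercised via pytest's indirect parametrization (the test name and the chosen toolgroup list are illustrative assumptions, not part of this diff):

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "get_chat_agent_with_tools",
    [[MEMORY_TOOLGROUP]],  # forwarded to the fixture as request.param
    indirect=True,
)
async def test_memory_toolgroup_is_exposed(get_chat_agent_with_tools):
    chat_agent = await get_chat_agent_with_tools
    tool_defs, _ = await chat_agent._get_tool_defs()
    assert MEMORY_QUERY_TOOL in tool_defs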
@ -1,20 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def get_name(self) -> str:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def run(self, messages: List[Message]) -> List[Message]:
|
|
||||||
raise NotImplementedError
|
|
|
@ -1,396 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from abc import abstractmethod
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from .ipython_tool.code_execution import (
|
|
||||||
CodeExecutionContext,
|
|
||||||
CodeExecutionRequest,
|
|
||||||
CodeExecutor,
|
|
||||||
TOOLS_ATTACHMENT_KEY_REGEX,
|
|
||||||
)
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
|
||||||
|
|
||||||
from .base import BaseTool
|
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
|
|
||||||
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
|
|
||||||
if match:
|
|
||||||
snippet = match.group(1)
|
|
||||||
data = json.loads(snippet)
|
|
||||||
return Attachment(
|
|
||||||
url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class SingleMessageBuiltinTool(BaseTool):
|
|
||||||
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
|
||||||
assert len(messages) == 1, f"Expected single message, got {len(messages)}"
|
|
||||||
|
|
||||||
message = messages[0]
|
|
||||||
assert len(message.tool_calls) == 1, "Expected a single tool call"
|
|
||||||
|
|
||||||
tool_call = messages[0].tool_calls[0]
|
|
||||||
|
|
||||||
query = tool_call.arguments["query"]
|
|
||||||
response: str = await self.run_impl(query)
|
|
||||||
|
|
||||||
message = ToolResponseMessage(
|
|
||||||
call_id=tool_call.call_id,
|
|
||||||
tool_name=tool_call.tool_name,
|
|
||||||
content=response,
|
|
||||||
)
|
|
||||||
return [message]
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def run_impl(self, query: str) -> str:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class PhotogenTool(SingleMessageBuiltinTool):
|
|
||||||
def __init__(self, dump_dir: str) -> None:
|
|
||||||
self.dump_dir = dump_dir
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return BuiltinTool.photogen.value
|
|
||||||
|
|
||||||
async def run_impl(self, query: str) -> str:
|
|
||||||
"""
|
|
||||||
Implement this to give the model an ability to generate images.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
info = {
|
|
||||||
"filepath": str(image_filepath),
|
|
||||||
"mimetype": "image/png",
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class SearchTool(SingleMessageBuiltinTool):
|
|
||||||
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
self.engine_type = engine
|
|
||||||
if engine == SearchEngineType.bing:
|
|
||||||
self.engine = BingSearch(api_key, **kwargs)
|
|
||||||
elif engine == SearchEngineType.brave:
|
|
||||||
self.engine = BraveSearch(api_key, **kwargs)
|
|
||||||
elif engine == SearchEngineType.tavily:
|
|
||||||
self.engine = TavilySearch(api_key, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown search engine: {engine}")
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return BuiltinTool.brave_search.value
|
|
||||||
|
|
||||||
async def run_impl(self, query: str) -> str:
|
|
||||||
return await self.engine.search(query)
|
|
||||||
|
|
||||||
|
|
||||||
class BingSearch:
|
|
||||||
def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
self.top_k = top_k
|
|
||||||
|
|
||||||
async def search(self, query: str) -> str:
|
|
||||||
url = "https://api.bing.microsoft.com/v7.0/search"
|
|
||||||
headers = {
|
|
||||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
|
||||||
}
|
|
||||||
params = {
|
|
||||||
"count": self.top_k,
|
|
||||||
"textDecorations": True,
|
|
||||||
"textFormat": "HTML",
|
|
||||||
"q": query,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.get(url=url, params=params, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
clean = self._clean_response(response.json())
|
|
||||||
return json.dumps(clean)
|
|
||||||
|
|
||||||
def _clean_response(self, search_response):
|
|
||||||
clean_response = []
|
|
||||||
query = search_response["queryContext"]["originalQuery"]
|
|
||||||
if "webPages" in search_response:
|
|
||||||
pages = search_response["webPages"]["value"]
|
|
||||||
for p in pages:
|
|
||||||
selected_keys = {"name", "url", "snippet"}
|
|
||||||
clean_response.append(
|
|
||||||
{k: v for k, v in p.items() if k in selected_keys}
|
|
||||||
)
|
|
||||||
if "news" in search_response:
|
|
||||||
clean_news = []
|
|
||||||
news = search_response["news"]["value"]
|
|
||||||
for n in news:
|
|
||||||
selected_keys = {"name", "url", "description"}
|
|
||||||
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
|
|
||||||
|
|
||||||
clean_response.append(clean_news)
|
|
||||||
|
|
||||||
return {"query": query, "top_k": clean_response}
|
|
||||||
|
|
||||||
|
|
||||||
class BraveSearch:
|
|
||||||
def __init__(self, api_key: str) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
|
|
||||||
async def search(self, query: str) -> str:
|
|
||||||
url = "https://api.search.brave.com/res/v1/web/search"
|
|
||||||
headers = {
|
|
||||||
"X-Subscription-Token": self.api_key,
|
|
||||||
"Accept-Encoding": "gzip",
|
|
||||||
"Accept": "application/json",
|
|
||||||
}
|
|
||||||
payload = {"q": query}
|
|
||||||
response = requests.get(url=url, params=payload, headers=headers)
|
|
||||||
return json.dumps(self._clean_brave_response(response.json()))
|
|
||||||
|
|
||||||
def _clean_brave_response(self, search_response, top_k=3):
|
|
||||||
query = None
|
|
||||||
clean_response = []
|
|
||||||
if "query" in search_response:
|
|
||||||
if "original" in search_response["query"]:
|
|
||||||
query = search_response["query"]["original"]
|
|
||||||
if "mixed" in search_response:
|
|
||||||
mixed_results = search_response["mixed"]
|
|
||||||
for m in mixed_results["main"][:top_k]:
|
|
||||||
r_type = m["type"]
|
|
||||||
results = search_response[r_type]["results"]
|
|
||||||
if r_type == "web":
|
|
||||||
# For web data - add a single output from the search
|
|
||||||
idx = m["index"]
|
|
||||||
selected_keys = [
|
|
||||||
"type",
|
|
||||||
"title",
|
|
||||||
"url",
|
|
||||||
"description",
|
|
||||||
"date",
|
|
||||||
"extra_snippets",
|
|
||||||
]
|
|
||||||
cleaned = {
|
|
||||||
k: v for k, v in results[idx].items() if k in selected_keys
|
|
||||||
}
|
|
||||||
elif r_type == "faq":
|
|
||||||
# For faw data - take a list of all the questions & answers
|
|
||||||
selected_keys = ["type", "question", "answer", "title", "url"]
|
|
||||||
cleaned = []
|
|
||||||
for q in results:
|
|
||||||
cleaned.append(
|
|
||||||
{k: v for k, v in q.items() if k in selected_keys}
|
|
||||||
)
|
|
||||||
elif r_type == "infobox":
|
|
||||||
idx = m["index"]
|
|
||||||
selected_keys = [
|
|
||||||
"type",
|
|
||||||
"title",
|
|
||||||
"url",
|
|
||||||
"description",
|
|
||||||
"long_desc",
|
|
||||||
]
|
|
||||||
cleaned = {
|
|
||||||
k: v for k, v in results[idx].items() if k in selected_keys
|
|
||||||
}
|
|
||||||
elif r_type == "videos":
|
|
||||||
selected_keys = [
|
|
||||||
"type",
|
|
||||||
"url",
|
|
||||||
"title",
|
|
||||||
"description",
|
|
||||||
"date",
|
|
||||||
]
|
|
||||||
cleaned = []
|
|
||||||
for q in results:
|
|
||||||
cleaned.append(
|
|
||||||
{k: v for k, v in q.items() if k in selected_keys}
|
|
||||||
)
|
|
||||||
elif r_type == "locations":
|
|
||||||
# For faw data - take a list of all the questions & answers
|
|
||||||
selected_keys = [
|
|
||||||
"type",
|
|
||||||
"title",
|
|
||||||
"url",
|
|
||||||
"description",
|
|
||||||
"coordinates",
|
|
||||||
"postal_address",
|
|
||||||
"contact",
|
|
||||||
"rating",
|
|
||||||
"distance",
|
|
||||||
"zoom_level",
|
|
||||||
]
|
|
||||||
cleaned = []
|
|
||||||
for q in results:
|
|
||||||
cleaned.append(
|
|
||||||
{k: v for k, v in q.items() if k in selected_keys}
|
|
||||||
)
|
|
||||||
elif r_type == "news":
|
|
||||||
# For faw data - take a list of all the questions & answers
|
|
||||||
selected_keys = [
|
|
||||||
"type",
|
|
||||||
"title",
|
|
||||||
"url",
|
|
||||||
"description",
|
|
||||||
]
|
|
||||||
cleaned = []
|
|
||||||
for q in results:
|
|
||||||
cleaned.append(
|
|
||||||
{k: v for k, v in q.items() if k in selected_keys}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cleaned = []
|
|
||||||
|
|
||||||
clean_response.append(cleaned)
|
|
||||||
|
|
||||||
return {"query": query, "top_k": clean_response}
|
|
||||||
|
|
||||||
|
|
||||||
class TavilySearch:
|
|
||||||
def __init__(self, api_key: str) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
|
|
||||||
async def search(self, query: str) -> str:
|
|
||||||
response = requests.post(
|
|
||||||
"https://api.tavily.com/search",
|
|
||||||
json={"api_key": self.api_key, "query": query},
|
|
||||||
)
|
|
||||||
return json.dumps(self._clean_tavily_response(response.json()))
|
|
||||||
|
|
||||||
def _clean_tavily_response(self, search_response, top_k=3):
|
|
||||||
return {"query": search_response["query"], "top_k": search_response["results"]}
|
|
||||||
|
|
||||||
|
|
||||||
class WolframAlphaTool(SingleMessageBuiltinTool):
|
|
||||||
def __init__(self, api_key: str) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
self.url = "https://api.wolframalpha.com/v2/query"
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return BuiltinTool.wolfram_alpha.value
|
|
||||||
|
|
||||||
async def run_impl(self, query: str) -> str:
|
|
||||||
params = {
|
|
||||||
"input": query,
|
|
||||||
"appid": self.api_key,
|
|
||||||
"format": "plaintext",
|
|
||||||
"output": "json",
|
|
||||||
}
|
|
||||||
response = requests.get(
|
|
||||||
self.url,
|
|
||||||
params=params,
|
|
||||||
)
|
|
||||||
|
|
||||||
return json.dumps(self._clean_wolfram_alpha_response(response.json()))
|
|
||||||
|
|
||||||
def _clean_wolfram_alpha_response(self, wa_response):
|
|
||||||
remove = {
|
|
||||||
"queryresult": [
|
|
||||||
"datatypes",
|
|
||||||
"error",
|
|
||||||
"timedout",
|
|
||||||
"timedoutpods",
|
|
||||||
"numpods",
|
|
||||||
"timing",
|
|
||||||
"parsetiming",
|
|
||||||
"parsetimedout",
|
|
||||||
"recalculate",
|
|
||||||
"id",
|
|
||||||
"host",
|
|
||||||
"server",
|
|
||||||
"related",
|
|
||||||
"version",
|
|
||||||
{
|
|
||||||
"pods": [
|
|
||||||
"scanner",
|
|
||||||
"id",
|
|
||||||
"error",
|
|
||||||
"expressiontypes",
|
|
||||||
"states",
|
|
||||||
"infos",
|
|
||||||
"position",
|
|
||||||
"numsubpods",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"assumptions",
|
|
||||||
],
|
|
||||||
}
|
|
||||||
for main_key in remove:
|
|
||||||
for key_to_remove in remove[main_key]:
|
|
||||||
try:
|
|
||||||
if key_to_remove == "assumptions":
|
|
||||||
if "assumptions" in wa_response[main_key]:
|
|
||||||
del wa_response[main_key][key_to_remove]
|
|
||||||
if isinstance(key_to_remove, dict):
|
|
||||||
for sub_key in key_to_remove:
|
|
||||||
if sub_key == "pods":
|
|
||||||
for i in range(len(wa_response[main_key][sub_key])):
|
|
||||||
if (
|
|
||||||
wa_response[main_key][sub_key][i]["title"]
|
|
||||||
== "Result"
|
|
||||||
):
|
|
||||||
del wa_response[main_key][sub_key][i + 1 :]
|
|
||||||
break
|
|
||||||
sub_items = wa_response[main_key][sub_key]
|
|
||||||
for i in range(len(sub_items)):
|
|
||||||
for sub_key_to_remove in key_to_remove[sub_key]:
|
|
||||||
if sub_key_to_remove in sub_items[i]:
|
|
||||||
del sub_items[i][sub_key_to_remove]
|
|
||||||
elif key_to_remove in wa_response[main_key]:
|
|
||||||
del wa_response[main_key][key_to_remove]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return wa_response
|
|
||||||
|
|
||||||
|
|
||||||
class CodeInterpreterTool(BaseTool):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
ctx = CodeExecutionContext(
|
|
||||||
matplotlib_dump_dir=tempfile.mkdtemp(),
|
|
||||||
)
|
|
||||||
self.code_executor = CodeExecutor(ctx)
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return BuiltinTool.code_interpreter.value
|
|
||||||
|
|
||||||
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
|
||||||
message = messages[0]
|
|
||||||
assert len(message.tool_calls) == 1, "Expected a single tool call"
|
|
||||||
|
|
||||||
tool_call = messages[0].tool_calls[0]
|
|
||||||
script = tool_call.arguments["code"]
|
|
||||||
|
|
||||||
req = CodeExecutionRequest(scripts=[script])
|
|
||||||
res = self.code_executor.execute(req)
|
|
||||||
|
|
||||||
pieces = [res["process_status"]]
|
|
||||||
for out_type in ["stdout", "stderr"]:
|
|
||||||
res_out = res[out_type]
|
|
||||||
if res_out != "":
|
|
||||||
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
|
|
||||||
if out_type == "stderr":
|
|
||||||
log.error(f"ipython tool error: ↓\n{res_out}")
|
|
||||||
|
|
||||||
message = ToolResponseMessage(
|
|
||||||
call_id=tool_call.call_id,
|
|
||||||
tool_name=tool_call.tool_name,
|
|
||||||
content="\n".join(pieces),
|
|
||||||
)
|
|
||||||
return [message]
|
|
|
@ -1,42 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
|
||||||
from llama_stack.apis.safety import Safety
|
|
||||||
|
|
||||||
from ..safety import ShieldRunnerMixin
|
|
||||||
from .builtin import BaseTool
|
|
||||||
|
|
||||||
|
|
||||||
class SafeTool(BaseTool, ShieldRunnerMixin):
|
|
||||||
"""A tool that makes other tools safety enabled"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tool: BaseTool,
|
|
||||||
safety_api: Safety,
|
|
||||||
input_shields: List[str] = None,
|
|
||||||
output_shields: List[str] = None,
|
|
||||||
):
|
|
||||||
self._tool = tool
|
|
||||||
ShieldRunnerMixin.__init__(
|
|
||||||
self, safety_api, input_shields=input_shields, output_shields=output_shields
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return self._tool.get_name()
|
|
||||||
|
|
||||||
async def run(self, messages: List[Message]) -> List[Message]:
|
|
||||||
if self.input_shields:
|
|
||||||
await self.run_multiple_shields(messages, self.input_shields)
|
|
||||||
# run the underlying tool
|
|
||||||
res = await self._tool.run(messages)
|
|
||||||
if self.output_shields:
|
|
||||||
await self.run_multiple_shields(messages, self.output_shields)
|
|
||||||
|
|
||||||
return res
|
|
|
@@ -5,5 +5,14 @@
 # the root directory of this source tree.
 from pydantic import BaseModel
 
+from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
 
-class LocalFSDatasetIOConfig(BaseModel): ...
+
+class LocalFSDatasetIOConfig(BaseModel):
+    kvstore: KVStoreConfig = SqliteKVStoreConfig(
+        db_path=(RUNTIME_BASE_DIR / "localfs_datasetio.db").as_posix()
+    )  # Uses SQLite config specific to localfs storage
@@ -18,10 +18,14 @@ from llama_stack.apis.datasets import Dataset
 
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from .config import LocalFSDatasetIOConfig
 
 
+DATASETS_PREFIX = "localfs_datasets:"
+
+
 class BaseDataset(ABC):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -86,8 +90,22 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self.config = config
         # local registry for keeping track of datasets within the provider
         self.dataset_infos = {}
+        self.kvstore = None
 
-    async def initialize(self) -> None: ...
+    async def initialize(self) -> None:
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+        # Load existing datasets from kvstore
+        start_key = DATASETS_PREFIX
+        end_key = f"{DATASETS_PREFIX}\xff"
+        stored_datasets = await self.kvstore.range(start_key, end_key)
+
+        for dataset in stored_datasets:
+            dataset = Dataset.model_validate_json(dataset)
+            dataset_impl = PandasDataframeDataset(dataset)
+            self.dataset_infos[dataset.identifier] = DatasetInfo(
+                dataset_def=dataset,
+                dataset_impl=dataset_impl,
+            )
 
     async def shutdown(self) -> None: ...
 
@@ -95,6 +113,12 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self,
         dataset: Dataset,
     ) -> None:
+        # Store in kvstore
+        key = f"{DATASETS_PREFIX}{dataset.identifier}"
+        await self.kvstore.set(
+            key=key,
+            value=dataset.json(),
+        )
         dataset_impl = PandasDataframeDataset(dataset)
         self.dataset_infos[dataset.identifier] = DatasetInfo(
             dataset_def=dataset,
@@ -102,6 +126,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         )
 
     async def unregister_dataset(self, dataset_id: str) -> None:
+        key = f"{DATASETS_PREFIX}{dataset_id}"
+        await self.kvstore.delete(key=key)
        del self.dataset_infos[dataset_id]
 
     async def get_rows_paginated(
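The datasetio hunks above persist each registered Dataset under the "localfs_datasets:" key prefix and replay the stored entries in initialize(). A compact sketch of that register/restore pattern, using a dict-backed stand-in for the real kvstore (DemoKVStore is illustrative and not part of llama-stack):

DATASETS_PREFIX = "localfs_datasets:"


class DemoKVStore:
    """In-memory stand-in mirroring the set/delete/range calls used above."""

    def __init__(self):
        self._data: dict[str, str] = {}

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def delete(self, key: str) -> None:
        self._data.pop(key, None)

    async def range(self, start_key: str, end_key: str) -> list[str]:
        # Every stored value whose key sorts within [start_key, end_key];
        # using "\xff" as the end suffix captures all keys sharing the prefix.
        return [v for k, v in sorted(self._data.items()) if start_key <= k <= end_key]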
@@ -6,7 +6,6 @@
 
 import asyncio
 import logging
-
 from typing import AsyncGenerator, List, Optional, Union
 
 from llama_models.llama3.api.datatypes import (
@@ -37,7 +36,6 @@ from llama_stack.apis.inference import (
     ToolCallParseStatus,
     ToolChoice,
 )
-
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -262,7 +260,7 @@ class MetaReferenceInferenceImpl(
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -22,6 +22,7 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
+
 from .config import SentenceTransformersInferenceConfig
 
 log = logging.getLogger(__name__)
@@ -67,7 +68,7 @@ class SentenceTransformersInferenceImpl(
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -10,10 +10,8 @@ import uuid
 from typing import AsyncGenerator, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
-
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams as VLLMSamplingParams
@@ -36,7 +34,6 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
-
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
@@ -50,7 +47,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import VLLMConfig
 
-
 log = logging.getLogger(__name__)
 
 
@@ -146,7 +142,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -156,7 +156,7 @@ class BraintrustScoringImpl(
         provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.openai_api_key:
             raise ValueError(
-                'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
+                'Pass OpenAI API Key in the header X-LlamaStack-Provider-Data as { "openai_api_key": <your api key>}'
             )
         self.config.openai_api_key = provider_data.openai_api_key
 
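For reference, a hedged sketch of supplying that provider-data header from a client. The header name and JSON shape come from the error message above; the server URL, route, and request body are illustrative assumptions only:

import json

import requests

headers = {"X-LlamaStack-Provider-Data": json.dumps({"openai_api_key": "sk-..."})}
# Hypothetical endpoint and payload; adjust both to your deployment.
response = requests.post(
    "http://localhost:5000/alpha/scoring/score",
    headers=headers,
    json={"input_rows": [], "scoring_functions": []},
)
response.raise_for_status()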
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .code_interpreter import CodeInterpreterToolRuntimeImpl
+from .config import CodeInterpreterToolConfig
+
+__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
+
+
+async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
+    impl = CodeInterpreterToolRuntimeImpl(config)
+    await impl.initialize()
+    return impl
@@ -0,0 +1,75 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import logging
+import tempfile
+from typing import Any, Dict, List, Optional
+
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.tools import (
+    Tool,
+    ToolDef,
+    ToolInvocationResult,
+    ToolParameter,
+    ToolRuntime,
+)
+from llama_stack.providers.datatypes import ToolsProtocolPrivate
+
+from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
+from .config import CodeInterpreterToolConfig
+
+log = logging.getLogger(__name__)
+
+
+class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+    def __init__(self, config: CodeInterpreterToolConfig):
+        self.config = config
+        ctx = CodeExecutionContext(
+            matplotlib_dump_dir=tempfile.mkdtemp(),
+        )
+        self.code_executor = CodeExecutor(ctx)
+
+    async def initialize(self):
+        pass
+
+    async def register_tool(self, tool: Tool):
+        pass
+
+    async def unregister_tool(self, tool_id: str) -> None:
+        return
+
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        return [
+            ToolDef(
+                name="code_interpreter",
+                description="Execute code",
+                parameters=[
+                    ToolParameter(
+                        name="code",
+                        description="The code to execute",
+                        parameter_type="string",
+                    ),
+                ],
+            )
+        ]
+
+    async def invoke_tool(
+        self, tool_name: str, args: Dict[str, Any]
+    ) -> ToolInvocationResult:
+        script = args["code"]
+        req = CodeExecutionRequest(scripts=[script])
+        res = self.code_executor.execute(req)
+        pieces = [res["process_status"]]
+        for out_type in ["stdout", "stderr"]:
+            res_out = res[out_type]
+            if res_out != "":
+                pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
+                if out_type == "stderr":
+                    log.error(f"ipython tool error: ↓\n{res_out}")
+        return ToolInvocationResult(content="\n".join(pieces))
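A short usage sketch for the runtime defined above. Construction through get_provider_impl and the invoke_tool argument shape follow this diff; the executed script is arbitrary:

import asyncio


async def demo():
    impl = await get_provider_impl(CodeInterpreterToolConfig(), _deps={})
    result = await impl.invoke_tool("code_interpreter", {"code": "print(21 * 2)"})
    # content holds the process status followed by [stdout]/[stderr] blocks
    print(result.content)


asyncio.run(demo())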
@@ -3,3 +3,9 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class CodeInterpreterToolConfig(BaseModel):
+    pass
llama_stack/providers/inline/tool_runtime/memory/__init__.py (new file)
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict
+
+from llama_stack.providers.datatypes import Api
+
+from .config import MemoryToolRuntimeConfig
+from .memory import MemoryToolRuntimeImpl
+
+
+async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
+    impl = MemoryToolRuntimeImpl(
+        config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
+    )
+    await impl.initialize()
+    return impl
llama_stack/providers/inline/tool_runtime/memory/config.py (new file)
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Annotated, List, Literal, Union
+
+from pydantic import BaseModel, Field
+
+
+class _MemoryBankConfigCommon(BaseModel):
+    bank_id: str
+
+
+class VectorMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal["vector"] = "vector"
+
+
+class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal["keyvalue"] = "keyvalue"
+    keys: List[str]  # what keys to focus on
+
+
+class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal["keyword"] = "keyword"
+
+
+class GraphMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal["graph"] = "graph"
+    entities: List[str]  # what entities to focus on
+
+
+MemoryBankConfig = Annotated[
+    Union[
+        VectorMemoryBankConfig,
+        KeyValueMemoryBankConfig,
+        KeywordMemoryBankConfig,
+        GraphMemoryBankConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
+class MemoryQueryGenerator(Enum):
+    default = "default"
+    llm = "llm"
+    custom = "custom"
+
+
+class DefaultMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.default.value] = (
+        MemoryQueryGenerator.default.value
+    )
+    sep: str = " "
+
+
+class LLMMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
+    model: str
+    template: str
+
+
+class CustomMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
+
+
+MemoryQueryGeneratorConfig = Annotated[
+    Union[
+        DefaultMemoryQueryGeneratorConfig,
+        LLMMemoryQueryGeneratorConfig,
+        CustomMemoryQueryGeneratorConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
+class MemoryToolConfig(BaseModel):
+    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+
+
+class MemoryToolRuntimeConfig(BaseModel):
+    # This config defines how a query is generated using the messages
+    # for memory bank retrieval.
+    query_generator_config: MemoryQueryGeneratorConfig = Field(
+        default=DefaultMemoryQueryGeneratorConfig()
+    )
+    max_tokens_in_context: int = 4096
+    max_chunks: int = 5
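A small construction example for the config models above; the model id and Jinja template are placeholders:

config = MemoryToolRuntimeConfig(
    query_generator_config=LLMMemoryQueryGeneratorConfig(
        model="Llama3.1-8B-Instruct",
        template="Write a search query for: {{ messages }}",
    ),
    max_tokens_in_context=2048,
    max_chunks=3,
)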
@@ -4,25 +4,29 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
 from typing import List
 
 from jinja2 import Template
+from pydantic import BaseModel
 
-from llama_stack.apis.agents import (
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import UserMessage
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+
+from .config import (
     DefaultMemoryQueryGeneratorConfig,
     LLMMemoryQueryGeneratorConfig,
     MemoryQueryGenerator,
     MemoryQueryGeneratorConfig,
 )
-from llama_stack.apis.inference import Message, UserMessage
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)
 
 
 async def generate_rag_query(
     config: MemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
     """
@@ -40,21 +44,26 @@ async def generate_rag_query(
 
 async def default_rag_query_generator(
     config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
+    return config.sep.join(interleaved_content_as_str(m) for m in messages)
 
 
 async def llm_rag_query_generator(
     config: LLMMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]
 
-    m_dict = {"messages": [m.model_dump() for m in messages]}
+    m_dict = {
+        "messages": [
+            message.model_dump() if isinstance(message, BaseModel) else message
+            for message in messages
+        ]
+    }
 
     template = Template(config.template)
     content = template.render(m_dict)
llama_stack/providers/inline/tool_runtime/memory/memory.py (new file, 146 lines)
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import logging
import secrets
import string
from typing import Any, Dict, List, Optional

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.tools import (
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content

from .config import MemoryToolConfig, MemoryToolRuntimeConfig
from .context_retriever import generate_rag_query

log = logging.getLogger(__name__)


def make_random_string(length: int = 8):
    return "".join(
        secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
    )


class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
    def __init__(
        self,
        config: MemoryToolRuntimeConfig,
        memory_api: Memory,
        memory_banks_api: MemoryBanks,
        inference_api: Inference,
    ):
        self.config = config
        self.memory_api = memory_api
        self.memory_banks_api = memory_banks_api
        self.inference_api = inference_api

    async def initialize(self):
        pass

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return [
            ToolDef(
                name="query_memory",
                description="Retrieve context from memory",
                parameters=[
                    ToolParameter(
                        name="messages",
                        description="The input messages to search for",
                        parameter_type="array",
                    ),
                ],
            )
        ]

    async def _retrieve_context(
        self, input_messages: List[InterleavedContent], bank_ids: List[str]
    ) -> Optional[List[InterleavedContent]]:
        if not bank_ids:
            return None
        query = await generate_rag_query(
            self.config.query_generator_config,
            input_messages,
            inference_api=self.inference_api,
        )
        tasks = [
            self.memory_api.query_documents(
                bank_id=bank_id,
                query=query,
                params={
                    "max_chunks": self.config.max_chunks,
                },
            )
            for bank_id in bank_ids
        ]
        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
        chunks = [c for r in results for c in r.chunks]
        scores = [s for r in results for s in r.scores]

        if not chunks:
            return None

        # sort by score
        chunks, scores = zip(
            *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
        )

        tokens = 0
        picked = []
        for c in chunks[: self.config.max_chunks]:
            tokens += c.token_count
            if tokens > self.config.max_tokens_in_context:
                log.error(
                    f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                )
                break
            picked.append(f"id:{c.document_id}; content:{c.content}")

        return [
            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
            *picked,
            "\n=== END-RETRIEVED-CONTEXT ===\n",
        ]

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        tool = await self.tool_store.get_tool(tool_name)
        tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
        final_args = tool_group.args or {}
        final_args.update(args)
        config = MemoryToolConfig()
        if tool.metadata and tool.metadata.get("config") is not None:
            config = MemoryToolConfig(**tool.metadata["config"])
        if "memory_bank_ids" in final_args:
            bank_ids = final_args["memory_bank_ids"]
        else:
            bank_ids = [
                bank_config.bank_id for bank_config in config.memory_bank_configs
            ]
        if "messages" not in final_args:
            raise ValueError("messages are required")
        context = await self._retrieve_context(
            final_args["messages"],
            bank_ids,
        )
        if context is None:
            context = []
        return ToolInvocationResult(
            content=concat_interleaved_content(context), error_code=0
        )
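Illustrative note (not part of the commit): a minimal sketch of exercising the new memory tool runtime directly; the bank id and the wiring of memory_api / memory_banks_api / inference_api are assumptions, not values from this diff.

    # Hypothetical sketch; assumes an already-constructed MemoryToolRuntimeImpl.
    async def demo_query_memory(impl: "MemoryToolRuntimeImpl") -> None:
        tools = await impl.list_runtime_tools()
        print([t.name for t in tools])  # expected to include "query_memory"
        context = await impl._retrieve_context(
            input_messages=["What is Llama Stack?"],
            bank_ids=["example-bank"],  # placeholder bank id
        )
        print(context)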
@@ -35,6 +35,8 @@ def available_providers() -> List[ProviderSpec]:
                 Api.safety,
                 Api.memory,
                 Api.memory_banks,
+                Api.tool_runtime,
+                Api.tool_groups,
             ],
         ),
         remote_provider_spec(
@@ -19,11 +19,58 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.tool_runtime,
-            provider_type="inline::brave-search",
+            provider_type="inline::memory-runtime",
             pip_packages=[],
-            module="llama_stack.providers.inline.tool_runtime.brave_search",
-            config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
-            provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
+            module="llama_stack.providers.inline.tool_runtime.memory",
+            config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
+            api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
+        ),
+        InlineProviderSpec(
+            api=Api.tool_runtime,
+            provider_type="inline::code-interpreter",
+            pip_packages=[],
+            module="llama_stack.providers.inline.tool_runtime.code_interpreter",
+            config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="brave-search",
+                module="llama_stack.providers.remote.tool_runtime.brave_search",
+                config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="bing-search",
+                module="llama_stack.providers.remote.tool_runtime.bing_search",
+                config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="tavily-search",
+                module="llama_stack.providers.remote.tool_runtime.tavily_search",
+                config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="wolfram-alpha",
+                module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
+                config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
+            ),
         ),
         remote_provider_spec(
             api=Api.tool_runtime,
@@ -10,7 +10,6 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -30,7 +29,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -47,7 +45,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

-
 MODEL_ALIASES = [
     build_model_alias(
         "meta.llama3-1-8b-instruct-v1:0",
@@ -101,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
@@ -7,11 +7,8 @@
 from typing import AsyncGenerator, List, Optional, Union

 from cerebras.cloud.sdk import AsyncCerebras
-
 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -29,7 +26,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -48,7 +44,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import CerebrasImplConfig

-
 model_aliases = [
     build_model_alias(
         "llama3.1-8b",
@@ -130,7 +125,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -7,11 +7,8 @@
 from typing import AsyncGenerator, List, Optional

 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from openai import OpenAI

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -28,7 +25,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -44,7 +40,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import DatabricksImplConfig

-
 model_aliases = [
     build_model_alias(
         "databricks-meta-llama-3-1-70b-instruct",
@@ -91,7 +86,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks
 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

@@ -52,7 +51,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import FireworksImplConfig

-
 MODEL_ALIASES = [
     build_model_alias(
         "fireworks/llama-v3p1-8b-instruct",
@@ -118,7 +116,7 @@ class FireworksInferenceAdapter(
         provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.fireworks_api_key:
             raise ValueError(
-                'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
+                'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
             )
         return provider_data.fireworks_api_key

@@ -198,7 +196,7 @@ class FireworksInferenceAdapter(
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -33,6 +33,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias_with_just_provider_model_id,
     ModelRegistryHelper,
 )
+
 from .groq_utils import (
     convert_chat_completion_request,
     convert_chat_completion_response,
@@ -94,9 +95,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[
-            ToolPromptFormat
-        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
@@ -145,6 +144,6 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
         provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.groq_api_key:
             raise ValueError(
-                'Pass Groq API Key in the header X-LlamaStack-ProviderData as { "groq_api_key": "<your api key>" }'
+                'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
             )
         return Groq(api_key=provider_data.groq_api_key)
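Illustrative note (not part of the commit): the adapters above all switch the tool_prompt_format default from ToolPromptFormat.json to None so the implementation can tell whether the caller actually specified a format. A minimal sketch of that pattern, with hypothetical names:

    def resolve_tool_prompt_format(requested, model_default):
        # None means the caller did not specify a format; only then use the model's default.
        return model_default if requested is None else requested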
@@ -175,9 +175,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[
-            ToolPromptFormat
-        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
@@ -9,7 +9,6 @@ from typing import AsyncGenerator, List, Optional, Union

 import httpx
 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
@@ -35,7 +34,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     build_model_alias_with_just_provider_model_id,
@@ -222,7 +220,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -30,13 +30,11 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
-
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
-
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
@@ -205,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -7,11 +7,8 @@
 from typing import AsyncGenerator, List, Optional, Union

 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from together import Together

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -53,7 +50,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import TogetherImplConfig

-
 MODEL_ALIASES = [
     build_model_alias(
         "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
@@ -139,7 +135,7 @@ class TogetherInferenceAdapter(
         provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.together_api_key:
             raise ValueError(
-                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+                'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
             )
         together_api_key = provider_data.together_api_key
         return Together(api_key=together_api_key)
@@ -188,7 +184,7 @@ class TogetherInferenceAdapter(
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -10,7 +10,6 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
-
 from openai import OpenAI

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -33,7 +32,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -54,7 +52,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import VLLMInferenceAdapterConfig

-
 log = logging.getLogger(__name__)


@@ -105,7 +102,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .bing_search import BingSearchToolRuntimeImpl
from .config import BingSearchToolConfig

__all__ = ["BingSearchToolConfig", "BingSearchToolRuntimeImpl"]
from pydantic import BaseModel


class BingSearchToolProviderDataValidator(BaseModel):
    api_key: str


async def get_adapter_impl(config: BingSearchToolConfig, _deps):
    impl = BingSearchToolRuntimeImpl(config)
    await impl.initialize()
    return impl
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Any, Dict, List, Optional

import requests

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    Tool,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate

from .config import BingSearchToolConfig


class BingSearchToolRuntimeImpl(
    ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
    def __init__(self, config: BingSearchToolConfig):
        self.config = config
        self.url = "https://api.bing.microsoft.com/v7.0/search"

    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool):
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        return

    def _get_api_key(self) -> str:
        if self.config.api_key:
            return self.config.api_key

        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.api_key:
            raise ValueError(
                'Pass Bing Search API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
            )
        return provider_data.api_key

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return [
            ToolDef(
                name="web_search",
                description="Search the web using Bing Search API",
                parameters=[
                    ToolParameter(
                        name="query",
                        description="The query to search for",
                        parameter_type="string",
                    )
                ],
            )
        ]

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        api_key = self._get_api_key()
        headers = {
            "Ocp-Apim-Subscription-Key": api_key,
        }
        params = {
            "count": self.config.top_k,
            "textDecorations": True,
            "textFormat": "HTML",
            "q": args["query"],
        }

        response = requests.get(
            url=self.url,
            params=params,
            headers=headers,
        )
        response.raise_for_status()

        return ToolInvocationResult(
            content=json.dumps(self._clean_response(response.json()))
        )

    def _clean_response(self, search_response):
        clean_response = []
        query = search_response["queryContext"]["originalQuery"]
        if "webPages" in search_response:
            pages = search_response["webPages"]["value"]
            for p in pages:
                selected_keys = {"name", "url", "snippet"}
                clean_response.append(
                    {k: v for k, v in p.items() if k in selected_keys}
                )
        if "news" in search_response:
            clean_news = []
            news = search_response["news"]["value"]
            for n in news:
                selected_keys = {"name", "url", "description"}
                clean_news.append({k: v for k, v in n.items() if k in selected_keys})

            clean_response.append(clean_news)

        return {"query": query, "top_k": clean_response}
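Illustrative note (not part of the commit): a sketch of driving the Bing runtime directly; the API key is a placeholder and the import path follows the module registered above.

    import asyncio

    from llama_stack.providers.remote.tool_runtime.bing_search import (
        BingSearchToolConfig,
        BingSearchToolRuntimeImpl,
    )

    async def demo() -> None:
        # key value is a placeholder, not a real credential
        impl = BingSearchToolRuntimeImpl(BingSearchToolConfig(api_key="YOUR_BING_KEY", top_k=3))
        await impl.initialize()
        result = await impl.invoke_tool("web_search", {"query": "llama stack"})
        print(result.content)

    asyncio.run(demo())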
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel


class BingSearchToolConfig(BaseModel):
    """Configuration for Bing Search Tool Runtime"""

    api_key: Optional[str] = None
    top_k: int = 3
@@ -14,7 +14,7 @@ class BraveSearchToolProviderDataValidator(BaseModel):
     api_key: str


-async def get_provider_impl(config: BraveSearchToolConfig, _deps):
+async def get_adapter_impl(config: BraveSearchToolConfig, _deps):
     impl = BraveSearchToolRuntimeImpl(config)
     await impl.initialize()
     return impl
@@ -4,11 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 import requests
+from llama_models.llama3.api.datatypes import BuiltinTool

-from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.tools import (
+    Tool,
+    ToolDef,
+    ToolInvocationResult,
+    ToolParameter,
+    ToolRuntime,
+)
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ToolsProtocolPrivate

@@ -25,8 +33,7 @@ class BraveSearchToolRuntimeImpl(
         pass

     async def register_tool(self, tool: Tool):
-        if tool.identifier != "brave_search":
-            raise ValueError(f"Tool identifier {tool.identifier} is not supported")
+        pass

     async def unregister_tool(self, tool_id: str) -> None:
         return
@@ -38,12 +45,27 @@ class BraveSearchToolRuntimeImpl(
         provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.api_key:
             raise ValueError(
-                'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
+                'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
             )
         return provider_data.api_key

-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
-        raise NotImplementedError("Brave search tool group not supported")
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        return [
+            ToolDef(
+                name="web_search",
+                description="Search the web for information",
+                parameters=[
+                    ToolParameter(
+                        name="query",
+                        description="The query to search for",
+                        parameter_type="string",
+                    )
+                ],
+                built_in_type=BuiltinTool.brave_search,
+            )
+        ]

     async def invoke_tool(
         self, tool_name: str, args: Dict[str, Any]
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Optional
+from typing import Any, Dict, Optional

 from pydantic import BaseModel, Field

@@ -18,3 +18,10 @@ class BraveSearchToolConfig(BaseModel):
         default=3,
         description="The maximum number of results to return",
     )
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+        return {
+            "api_key": "${env.BRAVE_SEARCH_API_KEY:}",
+            "max_results": 3,
+        }
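Illustrative note (not part of the commit): sample_run_config is a classmethod, so distribution templates can embed the default settings without instantiating the config; the directory argument below is a placeholder.

    BraveSearchToolConfig.sample_run_config(__distro_dir__="~/.llama/distributions/example")
    # -> {"api_key": "${env.BRAVE_SEARCH_API_KEY:}", "max_results": 3}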
@@ -4,22 +4,21 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse

+from mcp import ClientSession
+from mcp.client.sse import sse_client
+
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
-    MCPToolGroupDef,
     ToolDef,
-    ToolGroupDef,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.providers.datatypes import ToolsProtocolPrivate

-from mcp import ClientSession
-from mcp.client.sse import sse_client
-
 from .config import ModelContextProtocolConfig


@@ -30,12 +29,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
     async def initialize(self):
         pass

-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
-        if not isinstance(tool_group, MCPToolGroupDef):
-            raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        if mcp_endpoint is None:
+            raise ValueError("mcp_endpoint is required")

         tools = []
-        async with sse_client(tool_group.endpoint.uri) as streams:
+        async with sse_client(mcp_endpoint.uri) as streams:
             async with ClientSession(*streams) as session:
                 await session.initialize()
                 tools_result = await session.list_tools()
@@ -57,7 +58,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
                         description=tool.description,
                         parameters=parameters,
                         metadata={
-                            "endpoint": tool_group.endpoint.uri,
+                            "endpoint": mcp_endpoint.uri,
                         },
                     )
                 )
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from .config import TavilySearchToolConfig
from .tavily_search import TavilySearchToolRuntimeImpl


class TavilySearchToolProviderDataValidator(BaseModel):
    api_key: str


async def get_adapter_impl(config: TavilySearchToolConfig, _deps):
    impl = TavilySearchToolRuntimeImpl(config)
    await impl.initialize()
    return impl
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field


class TavilySearchToolConfig(BaseModel):
    api_key: Optional[str] = Field(
        default=None,
        description="The Tavily Search API Key",
    )
    max_results: int = Field(
        default=3,
        description="The maximum number of results to return",
    )

    @classmethod
    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
        return {
            "api_key": "${env.TAVILY_SEARCH_API_KEY:}",
            "max_results": 3,
        }
@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Any, Dict, List, Optional

import requests

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    Tool,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate

from .config import TavilySearchToolConfig


class TavilySearchToolRuntimeImpl(
    ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
    def __init__(self, config: TavilySearchToolConfig):
        self.config = config

    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool):
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        return

    def _get_api_key(self) -> str:
        if self.config.api_key:
            return self.config.api_key

        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.api_key:
            raise ValueError(
                'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
            )
        return provider_data.api_key

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return [
            ToolDef(
                name="web_search",
                description="Search the web for information",
                parameters=[
                    ToolParameter(
                        name="query",
                        description="The query to search for",
                        parameter_type="string",
                    )
                ],
            )
        ]

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        api_key = self._get_api_key()
        response = requests.post(
            "https://api.tavily.com/search",
            json={"api_key": api_key, "query": args["query"]},
        )

        return ToolInvocationResult(
            content=json.dumps(self._clean_tavily_response(response.json()))
        )

    def _clean_tavily_response(self, search_response, top_k=3):
        return {"query": search_response["query"], "top_k": search_response["results"]}
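Illustrative note (not part of the commit): when TavilySearchToolConfig.api_key is left unset, _get_api_key above expects the key per request via the provider-data header named in the error message; a sketch of building that header value, with a placeholder key.

    import json

    provider_data_header = {
        "X-LlamaStack-Provider-Data": json.dumps({"api_key": "tvly-PLACEHOLDER"})
    }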
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from .config import WolframAlphaToolConfig
from .wolfram_alpha import WolframAlphaToolRuntimeImpl

__all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"]


class WolframAlphaToolProviderDataValidator(BaseModel):
    api_key: str


async def get_adapter_impl(config: WolframAlphaToolConfig, _deps):
    impl = WolframAlphaToolRuntimeImpl(config)
    await impl.initialize()
    return impl
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel


class WolframAlphaToolConfig(BaseModel):
    """Configuration for WolframAlpha Tool Runtime"""

    api_key: Optional[str] = None
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Any, Dict, List, Optional

import requests

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    Tool,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate

from .config import WolframAlphaToolConfig


class WolframAlphaToolRuntimeImpl(
    ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
    def __init__(self, config: WolframAlphaToolConfig):
        self.config = config
        self.url = "https://api.wolframalpha.com/v2/query"

    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool):
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        return

    def _get_api_key(self) -> str:
        if self.config.api_key:
            return self.config.api_key

        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.api_key:
            raise ValueError(
                'Pass WolframAlpha API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
            )
        return provider_data.api_key

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return [
            ToolDef(
                name="wolfram_alpha",
                description="Query WolframAlpha for computational knowledge",
                parameters=[
                    ToolParameter(
                        name="query",
                        description="The query to compute",
                        parameter_type="string",
                    )
                ],
            )
        ]

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        api_key = self._get_api_key()
        params = {
            "input": args["query"],
            "appid": api_key,
            "format": "plaintext",
            "output": "json",
        }
        response = requests.get(
            self.url,
            params=params,
        )

        return ToolInvocationResult(
            content=json.dumps(self._clean_wolfram_alpha_response(response.json()))
        )

    def _clean_wolfram_alpha_response(self, wa_response):
        remove = {
            "queryresult": [
                "datatypes",
                "error",
                "timedout",
                "timedoutpods",
                "numpods",
                "timing",
                "parsetiming",
                "parsetimedout",
                "recalculate",
                "id",
                "host",
                "server",
                "related",
                "version",
                {
                    "pods": [
                        "scanner",
                        "id",
                        "error",
                        "expressiontypes",
                        "states",
                        "infos",
                        "position",
                        "numsubpods",
                    ]
                },
                "assumptions",
            ],
        }
        for main_key in remove:
            for key_to_remove in remove[main_key]:
                try:
                    if key_to_remove == "assumptions":
                        if "assumptions" in wa_response[main_key]:
                            del wa_response[main_key][key_to_remove]
                    if isinstance(key_to_remove, dict):
                        for sub_key in key_to_remove:
                            if sub_key == "pods":
                                for i in range(len(wa_response[main_key][sub_key])):
                                    if (
                                        wa_response[main_key][sub_key][i]["title"]
                                        == "Result"
                                    ):
                                        del wa_response[main_key][sub_key][i + 1 :]
                                        break
                            sub_items = wa_response[main_key][sub_key]
                            for i in range(len(sub_items)):
                                for sub_key_to_remove in key_to_remove[sub_key]:
                                    if sub_key_to_remove in sub_items[i]:
                                        del sub_items[i][sub_key_to_remove]
                    elif key_to_remove in wa_response[main_key]:
                        del wa_response[main_key][key_to_remove]
                except KeyError:
                    pass
        return wa_response
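Illustrative note (not part of the commit): a minimal check of the stripping behaviour of _clean_wolfram_alpha_response; the response shape below is fabricated and only the keys named in the `remove` map matter.

    from llama_stack.providers.remote.tool_runtime.wolfram_alpha import (
        WolframAlphaToolConfig,
        WolframAlphaToolRuntimeImpl,
    )

    impl = WolframAlphaToolRuntimeImpl(WolframAlphaToolConfig())
    raw = {"queryresult": {"numpods": 1, "timing": 0.42, "success": True, "pods": []}}
    cleaned = impl._clean_wolfram_alpha_response(raw)
    assert "numpods" not in cleaned["queryresult"]   # stripped
    assert "timing" not in cleaned["queryresult"]    # stripped
    assert cleaned["queryresult"]["success"] is True  # kept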
@@ -7,13 +7,12 @@
 import pytest

 from ..conftest import get_provider_fixture_overrides
-
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
+from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
 from .fixtures import AGENTS_FIXTURES

-
 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
@@ -21,6 +20,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "safety": "llama_guard",
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory_and_search",
         },
         id="meta_reference",
         marks=pytest.mark.meta_reference,
@@ -31,6 +31,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "safety": "llama_guard",
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory_and_search",
         },
         id="ollama",
         marks=pytest.mark.ollama,
@@ -42,6 +43,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             # make this work with Weaviate which is what the together distro supports
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory_and_search",
         },
         id="together",
         marks=pytest.mark.together,
@@ -52,6 +54,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "safety": "llama_guard",
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory_and_search",
         },
         id="fireworks",
         marks=pytest.mark.fireworks,
@@ -62,6 +65,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
            "safety": "remote",
            "memory": "remote",
            "agents": "remote",
+            "tool_runtime": "memory_and_search",
         },
         id="remote",
         marks=pytest.mark.remote,
@@ -117,6 +121,7 @@ def pytest_generate_tests(metafunc):
         "safety": SAFETY_FIXTURES,
         "memory": MEMORY_FIXTURES,
         "agents": AGENTS_FIXTURES,
+        "tool_runtime": TOOL_RUNTIME_FIXTURES,
     }
     combinations = (
         get_provider_fixture_overrides(metafunc.config, available_fixtures)
@ -11,13 +11,12 @@ import pytest_asyncio
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelInput, ModelType
|
from llama_stack.apis.models import ModelInput, ModelType
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
from llama_stack.providers.inline.agents.meta_reference import (
|
from llama_stack.providers.inline.agents.meta_reference import (
|
||||||
MetaReferenceAgentsImplConfig,
|
MetaReferenceAgentsImplConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,12 +58,18 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def agents_stack(request, inference_model, safety_shield):
|
async def agents_stack(
|
||||||
|
request,
|
||||||
|
inference_model,
|
||||||
|
safety_shield,
|
||||||
|
tool_group_input_memory,
|
||||||
|
tool_group_input_tavily_search,
|
||||||
|
):
|
||||||
fixture_dict = request.param
|
fixture_dict = request.param
|
||||||
|
|
||||||
providers = {}
|
providers = {}
|
||||||
provider_data = {}
|
provider_data = {}
|
||||||
for key in ["inference", "safety", "memory", "agents"]:
|
for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
|
||||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||||
providers[key] = fixture.providers
|
providers[key] = fixture.providers
|
||||||
if key == "inference":
|
if key == "inference":
|
||||||
|
@ -113,10 +118,11 @@ async def agents_stack(request, inference_model, safety_shield):
|
||||||
)
|
)
|
||||||
|
|
||||||
test_stack = await construct_stack_for_test(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.agents, Api.inference, Api.safety, Api.memory],
|
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
|
||||||
providers,
|
providers,
|
||||||
provider_data,
|
provider_data,
|
||||||
models=models,
|
models=models,
|
||||||
shields=[safety_shield] if safety_shield else [],
|
shields=[safety_shield] if safety_shield else [],
|
||||||
|
tool_groups=[tool_group_input_memory, tool_group_input_tavily_search],
|
||||||
)
|
)
|
||||||
return test_stack
|
return test_stack
|
||||||
|
|
@@ -5,22 +5,17 @@
 # the root directory of this source tree.

 import os
-from typing import Dict, List

 import pytest
 from llama_models.llama3.api.datatypes import BuiltinTool

 from llama_stack.apis.agents import (
     AgentConfig,
-    AgentTool,
     AgentTurnResponseEventType,
     AgentTurnResponseStepCompletePayload,
     AgentTurnResponseStreamChunk,
     AgentTurnResponseTurnCompletePayload,
-    Attachment,
-    MemoryToolDefinition,
-    SearchEngineType,
-    SearchToolDefinition,
+    Document,
     ShieldCallStep,
     StepType,
     ToolChoice,
@@ -35,7 +30,6 @@ from llama_stack.providers.datatypes import Api
 #
 # pytest -v -s llama_stack/providers/tests/agents/test_agents.py
 #   -m "meta_reference"

 from .fixtures import pick_inference_model
 from .utils import create_agent_session

@@ -51,7 +45,7 @@ def common_params(inference_model):
         sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
         input_shields=[],
         output_shields=[],
-        tools=[],
+        toolgroups=[],
         max_infer_iters=5,
     )

@@ -88,73 +82,6 @@ def query_attachment_messages():
     ]


-async def create_agent_turn_with_search_tool(
-    agents_stack: Dict[str, object],
-    search_query_messages: List[object],
-    common_params: Dict[str, str],
-    search_tool_definition: SearchToolDefinition,
-) -> None:
-    """
-    Create an agent turn with a search tool.
-
-    Args:
-        agents_stack (Dict[str, object]): The agents stack.
-        search_query_messages (List[object]): The search query messages.
-        common_params (Dict[str, str]): The common parameters.
-        search_tool_definition (SearchToolDefinition): The search tool definition.
-    """
-
-    # Create an agent with the search tool
-    agent_config = AgentConfig(
-        **{
-            **common_params,
-            "tools": [search_tool_definition],
-        }
-    )
-
-    agent_id, session_id = await create_agent_session(
-        agents_stack.impls[Api.agents], agent_config
-    )
-    turn_request = dict(
-        agent_id=agent_id,
-        session_id=session_id,
-        messages=search_query_messages,
-        stream=True,
-    )
-
-    turn_response = [
-        chunk
-        async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
-            **turn_request
-        )
-    ]
-
-    assert len(turn_response) > 0
-    assert all(
-        isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
-    )
-
-    check_event_types(turn_response)
-
-    # Check for tool execution events
-    tool_execution_events = [
-        chunk
-        for chunk in turn_response
-        if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
-        and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
-    ]
-    assert len(tool_execution_events) > 0, "No tool execution events found"
-
-    # Check the tool execution details
-    tool_execution = tool_execution_events[0].event.payload.step_details
-    assert isinstance(tool_execution, ToolExecutionStep)
-    assert len(tool_execution.tool_calls) > 0
-    assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
-    assert len(tool_execution.tool_responses) > 0
-
-    check_turn_complete_event(turn_response, session_id, search_query_messages)
-
-
 class TestAgents:
     @pytest.mark.asyncio
     async def test_agent_turns_with_safety(
@@ -227,7 +154,7 @@ class TestAgents:
         check_turn_complete_event(turn_response, session_id, sample_messages)

     @pytest.mark.asyncio
-    async def test_rag_agent_as_attachments(
+    async def test_rag_agent(
         self,
         agents_stack,
         attachment_message,
@@ -243,29 +170,17 @@ class TestAgents:
             "qat_finetune.rst",
             "lora_finetune.rst",
         ]
-        attachments = [
-            Attachment(
+        documents = [
+            Document(
                 content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
                 mime_type="text/plain",
             )
             for i, url in enumerate(urls)
         ]

         agent_config = AgentConfig(
             **{
                 **common_params,
-                "tools": [
-                    MemoryToolDefinition(
-                        memory_bank_configs=[],
-                        query_generator_config={
-                            "type": "default",
-                            "sep": " ",
-                        },
-                        max_tokens_in_context=4096,
-                        max_chunks=10,
-                    ),
-                ],
+                "toolgroups": ["builtin::memory"],
                 "tool_choice": ToolChoice.auto,
             }
         )
@@ -275,7 +190,7 @@ class TestAgents:
             agent_id=agent_id,
             session_id=session_id,
             messages=attachment_message,
-            attachments=attachments,
+            documents=documents,
             stream=True,
         )
         turn_response = [
@@ -298,22 +213,6 @@ class TestAgents:

         assert len(turn_response) > 0

-    @pytest.mark.asyncio
-    async def test_create_agent_turn_with_brave_search(
-        self, agents_stack, search_query_messages, common_params
-    ):
-        if "BRAVE_SEARCH_API_KEY" not in os.environ:
-            pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
-
-        search_tool_definition = SearchToolDefinition(
-            type=AgentTool.brave_search.value,
-            api_key=os.environ["BRAVE_SEARCH_API_KEY"],
-            engine=SearchEngineType.brave,
-        )
-        await create_agent_turn_with_search_tool(
-            agents_stack, search_query_messages, common_params, search_tool_definition
-        )
-
     @pytest.mark.asyncio
     async def test_create_agent_turn_with_tavily_search(
         self, agents_stack, search_query_messages, common_params
@@ -321,14 +220,57 @@ class TestAgents:
         if "TAVILY_SEARCH_API_KEY" not in os.environ:
             pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")

-        search_tool_definition = SearchToolDefinition(
-            type=AgentTool.brave_search.value,  # place holder only
-            api_key=os.environ["TAVILY_SEARCH_API_KEY"],
-            engine=SearchEngineType.tavily,
+        # Create an agent with the toolgroup
+        agent_config = AgentConfig(
+            **{
+                **common_params,
+                "toolgroups": ["builtin::web_search"],
+            }
         )
-        await create_agent_turn_with_search_tool(
-            agents_stack, search_query_messages, common_params, search_tool_definition
+        agent_id, session_id = await create_agent_session(
+            agents_stack.impls[Api.agents], agent_config
         )
+        turn_request = dict(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=search_query_messages,
+            stream=True,
+        )
+
+        turn_response = [
+            chunk
+            async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
+                **turn_request
+            )
+        ]
+
+        assert len(turn_response) > 0
+        assert all(
+            isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
+        )
+
+        check_event_types(turn_response)
+
+        # Check for tool execution events
+        tool_execution_events = [
+            chunk
+            for chunk in turn_response
+            if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
+            and chunk.event.payload.step_details.step_type
+            == StepType.tool_execution.value
+        ]
+        assert len(tool_execution_events) > 0, "No tool execution events found"
+
+        # Check the tool execution details
+        tool_execution = tool_execution_events[0].event.payload.step_details
+        assert isinstance(tool_execution, ToolExecutionStep)
+        assert len(tool_execution.tool_calls) > 0
+        actual_tool_name = tool_execution.tool_calls[0].tool_name
+        assert actual_tool_name == BuiltinTool.brave_search
+        assert len(tool_execution.tool_responses) > 0
+
+        check_turn_complete_event(turn_response, session_id, search_query_messages)


 def check_event_types(turn_response):
@@ -157,4 +157,5 @@ pytest_plugins = [
     "llama_stack.providers.tests.scoring.fixtures",
     "llama_stack.providers.tests.eval.fixtures",
     "llama_stack.providers.tests.post_training.fixtures",
+    "llama_stack.providers.tests.tools.fixtures",
 ]
@@ -19,6 +19,7 @@ from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
@@ -16,7 +16,7 @@ from llama_stack.apis.memory_banks import MemoryBankInput
 from llama_stack.apis.models import ModelInput
 from llama_stack.apis.scoring_functions import ScoringFnInput
 from llama_stack.apis.shields import ShieldInput
+from llama_stack.apis.tools import ToolGroupInput
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Provider, StackRunConfig
@@ -43,6 +43,7 @@ async def construct_stack_for_test(
     datasets: Optional[List[DatasetInput]] = None,
     scoring_fns: Optional[List[ScoringFnInput]] = None,
     eval_tasks: Optional[List[EvalTaskInput]] = None,
+    tool_groups: Optional[List[ToolGroupInput]] = None,
 ) -> TestStack:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     run_config = dict(
@@ -56,6 +57,7 @@ async def construct_stack_for_test(
         datasets=datasets or [],
         scoring_fns=scoring_fns or [],
         eval_tasks=eval_tasks or [],
+        tool_groups=tool_groups or [],
     )
     run_config = parse_and_maybe_upgrade_config(run_config)
     try:
@@ -77,7 +79,7 @@ async def construct_stack_for_test(

     if provider_data:
         set_request_provider_data(
-            {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
+            {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
         )

     return test_stack
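Note (outside the diff): a hedged sketch of how a fixture might call the extended helper above. Only the tool_groups keyword and the imports come from the changes shown; the providers, provider_data, and models arguments are placeholders supplied by the caller.

# Illustrative sketch only; provider/model values are placeholders.
from llama_stack.apis.tools import ToolGroupInput
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.resolver import construct_stack_for_test


async def build_test_stack_with_tools(providers, provider_data, models):
    web_search = ToolGroupInput(
        toolgroup_id="builtin::web_search",
        provider_id="tavily-search",
    )
    return await construct_stack_for_test(
        [Api.agents, Api.inference, Api.memory, Api.tool_runtime],
        providers,
        provider_data,
        models=models,
        tool_groups=[web_search],  # the new keyword introduced above
    )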
llama_stack/providers/tests/tools/conftest.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from ..conftest import get_provider_fixture_overrides
+from ..inference.fixtures import INFERENCE_FIXTURES
+from ..memory.fixtures import MEMORY_FIXTURES
+from ..safety.fixtures import SAFETY_FIXTURES
+from .fixtures import TOOL_RUNTIME_FIXTURES
+
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "together",
+            "safety": "llama_guard",
+            "memory": "faiss",
+            "tool_runtime": "memory_and_search",
+        },
+        id="together",
+        marks=pytest.mark.together,
+    ),
+]
+
+
+def pytest_configure(config):
+    for mark in ["together"]:
+        config.addinivalue_line(
+            "markers",
+            f"{mark}: marks tests as {mark} specific",
+        )
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default="meta-llama/Llama-3.2-3B-Instruct",
+        help="Specify the inference model to use for testing",
+    )
+    parser.addoption(
+        "--safety-shield",
+        action="store",
+        default="meta-llama/Llama-Guard-3-1B",
+        help="Specify the safety shield to use for testing",
+    )
+
+
+def pytest_generate_tests(metafunc):
+    if "tools_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "safety": SAFETY_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+            "tool_runtime": TOOL_RUNTIME_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        print(combinations)
+        metafunc.parametrize("tools_stack", combinations, indirect=True)
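Note (outside the diff): with the marker and command-line options defined in this conftest, the new tools tests could be driven programmatically roughly as below; the API key value is a placeholder and the model name is just the default declared in pytest_addoption.

# Hedged example of invoking the new provider tests via pytest's Python API.
import os
import sys

import pytest

os.environ.setdefault("TAVILY_SEARCH_API_KEY", "<your-tavily-key>")  # tests skip without it

sys.exit(
    pytest.main(
        [
            "-v",
            "-s",
            "-m",
            "together",
            "--inference-model=meta-llama/Llama-3.2-3B-Instruct",
            "llama_stack/providers/tests/tools/test_tools.py",
        ]
    )
)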
llama_stack/providers/tests/tools/fixtures.py (new file, 130 lines)
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.models import ModelInput, ModelType
+from llama_stack.apis.tools import ToolGroupInput
+from llama_stack.distribution.datatypes import Api, Provider
+from llama_stack.providers.tests.resolver import construct_stack_for_test
+
+from ..conftest import ProviderFixture
+
+
+@pytest.fixture(scope="session")
+def tool_runtime_memory_and_search() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="memory-runtime",
+                provider_type="inline::memory-runtime",
+                config={},
+            ),
+            Provider(
+                provider_id="tavily-search",
+                provider_type="remote::tavily-search",
+                config={
+                    "api_key": os.environ["TAVILY_SEARCH_API_KEY"],
+                },
+            ),
+            Provider(
+                provider_id="wolfram-alpha",
+                provider_type="remote::wolfram-alpha",
+                config={
+                    "api_key": os.environ["WOLFRAM_ALPHA_API_KEY"],
+                },
+            ),
+        ],
+    )
+
+
+@pytest.fixture(scope="session")
+def tool_group_input_memory() -> ToolGroupInput:
+    return ToolGroupInput(
+        toolgroup_id="builtin::memory",
+        provider_id="memory-runtime",
+    )
+
+
+@pytest.fixture(scope="session")
+def tool_group_input_tavily_search() -> ToolGroupInput:
+    return ToolGroupInput(
+        toolgroup_id="builtin::web_search",
+        provider_id="tavily-search",
+    )
+
+
+@pytest.fixture(scope="session")
+def tool_group_input_wolfram_alpha() -> ToolGroupInput:
+    return ToolGroupInput(
+        toolgroup_id="builtin::wolfram_alpha",
+        provider_id="wolfram-alpha",
+    )
+
+
+TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
+
+
+@pytest_asyncio.fixture(scope="session")
+async def tools_stack(
+    request,
+    inference_model,
+    tool_group_input_memory,
+    tool_group_input_tavily_search,
+    tool_group_input_wolfram_alpha,
+):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["inference", "memory", "tool_runtime"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if key == "inference":
+            providers[key].append(
+                Provider(
+                    provider_id="tools_memory_provider",
+                    provider_type="inline::sentence-transformers",
+                    config={},
+                )
+            )
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
+    inference_models = (
+        inference_model if isinstance(inference_model, list) else [inference_model]
+    )
+    models = [
+        ModelInput(
+            model_id=model,
+            model_type=ModelType.llm,
+            provider_id=providers["inference"][0].provider_id,
+        )
+        for model in inference_models
+    ]
+    models.append(
+        ModelInput(
+            model_id="all-MiniLM-L6-v2",
+            model_type=ModelType.embedding,
+            provider_id="tools_memory_provider",
+            metadata={"embedding_dimension": 384},
+        )
+    )
+
+    test_stack = await construct_stack_for_test(
+        [Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
+        providers,
+        provider_data,
+        models=models,
+        tool_groups=[
+            tool_group_input_tavily_search,
+            tool_group_input_wolfram_alpha,
+            tool_group_input_memory,
+        ],
+    )
+    return test_stack
llama_stack/providers/tests/tools/test_tools.py (new file, 127 lines)
@@ -0,0 +1,127 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+
+import pytest
+
+from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.memory import MemoryBankDocument
+from llama_stack.apis.memory_banks import VectorMemoryBankParams
+from llama_stack.apis.tools import ToolInvocationResult
+from llama_stack.providers.datatypes import Api
+
+
+@pytest.fixture
+def sample_search_query():
+    return "What are the latest developments in quantum computing?"
+
+
+@pytest.fixture
+def sample_wolfram_alpha_query():
+    return "What is the square root of 16?"
+
+
+@pytest.fixture
+def sample_documents():
+    urls = [
+        "memory_optimizations.rst",
+        "chat.rst",
+        "llama3.rst",
+        "datasets.rst",
+        "qat_finetune.rst",
+        "lora_finetune.rst",
+    ]
+    return [
+        MemoryBankDocument(
+            document_id=f"num-{i}",
+            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
+            mime_type="text/plain",
+            metadata={},
+        )
+        for i, url in enumerate(urls)
+    ]
+
+
+class TestTools:
+    @pytest.mark.asyncio
+    async def test_web_search_tool(self, tools_stack, sample_search_query):
+        """Test the web search tool functionality."""
+        if "TAVILY_SEARCH_API_KEY" not in os.environ:
+            pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
+
+        tools_impl = tools_stack.impls[Api.tool_runtime]
+
+        # Execute the tool
+        response = await tools_impl.invoke_tool(
+            tool_name="web_search", args={"query": sample_search_query}
+        )
+
+        # Verify the response
+        assert isinstance(response, ToolInvocationResult)
+        assert response.content is not None
+        assert len(response.content) > 0
+        assert isinstance(response.content, str)
+
+    @pytest.mark.asyncio
+    async def test_wolfram_alpha_tool(self, tools_stack, sample_wolfram_alpha_query):
+        """Test the wolfram alpha tool functionality."""
+        if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
+            pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
+
+        tools_impl = tools_stack.impls[Api.tool_runtime]
+
+        response = await tools_impl.invoke_tool(
+            tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query}
+        )
+
+        # Verify the response
+        assert isinstance(response, ToolInvocationResult)
+        assert response.content is not None
+        assert len(response.content) > 0
+        assert isinstance(response.content, str)
+
+    @pytest.mark.asyncio
+    async def test_memory_tool(self, tools_stack, sample_documents):
+        """Test the memory tool functionality."""
+        memory_banks_impl = tools_stack.impls[Api.memory_banks]
+        memory_impl = tools_stack.impls[Api.memory]
+        tools_impl = tools_stack.impls[Api.tool_runtime]
+
+        # Register memory bank
+        await memory_banks_impl.register_memory_bank(
+            memory_bank_id="test_bank",
+            params=VectorMemoryBankParams(
+                embedding_model="all-MiniLM-L6-v2",
+                chunk_size_in_tokens=512,
+                overlap_size_in_tokens=64,
+            ),
+            provider_id="faiss",
+        )
+
+        # Insert documents into memory
+        await memory_impl.insert_documents(
+            bank_id="test_bank",
+            documents=sample_documents,
+        )
+
+        # Execute the memory tool
+        response = await tools_impl.invoke_tool(
+            tool_name="memory",
+            args={
+                "messages": [
+                    UserMessage(
+                        content="What are the main topics covered in the documentation?",
+                    )
+                ],
+                "memory_bank_ids": ["test_bank"],
+            },
+        )
+
+        # Verify the response
+        assert isinstance(response, ToolInvocationResult)
+        assert response.content is not None
+        assert len(response.content) > 0
@@ -14,7 +14,6 @@ from typing import List, Optional, Tuple, Union

 import httpx
 from llama_models.datatypes import is_multimodal, ModelFamily
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import (
     RawContent,
@@ -41,7 +40,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContentItem,
     TextContentItem,
 )
-
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
@@ -52,7 +50,6 @@ from llama_stack.apis.inference import (
     ToolChoice,
     UserMessage,
 )
-
 from llama_stack.providers.utils.inference import supported_inference_models

 log = logging.getLogger(__name__)
@@ -361,14 +358,13 @@ def augment_messages_for_tools_llama_3_1(

     has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
     if has_custom_tools:
-        if request.tool_prompt_format == ToolPromptFormat.json:
+        fmt = request.tool_prompt_format or ToolPromptFormat.json
+        if fmt == ToolPromptFormat.json:
             tool_gen = JsonCustomToolGenerator()
-        elif request.tool_prompt_format == ToolPromptFormat.function_tag:
+        elif fmt == ToolPromptFormat.function_tag:
             tool_gen = FunctionTagCustomToolGenerator()
         else:
-            raise ValueError(
-                f"Non supported ToolPromptFormat {request.tool_prompt_format}"
-            )
+            raise ValueError(f"Non supported ToolPromptFormat {fmt}")

         custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
         custom_template = tool_gen.gen(custom_tools)
@@ -413,7 +409,8 @@ def augment_messages_for_tools_llama_3_2(

     custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
     if custom_tools:
-        if request.tool_prompt_format != ToolPromptFormat.python_list:
+        fmt = request.tool_prompt_format or ToolPromptFormat.python_list
+        if fmt != ToolPromptFormat.python_list:
             raise ValueError(
                 f"Non supported ToolPromptFormat {request.tool_prompt_format}"
             )
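Note (outside the diff): the two hunks above make tool_prompt_format optional by falling back to a per-model default (JSON for the Llama 3.1 path, python_list for the 3.2 path) when the request leaves it unset. A minimal, self-contained sketch of that fallback pattern follows; the enum and request type here are stand-ins, not the real llama_models classes.

# Hedged illustration of the "or default" fallback introduced above; types are stand-ins.
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class ToolPromptFormat(Enum):
    json = "json"
    function_tag = "function_tag"
    python_list = "python_list"


@dataclass
class FakeRequest:
    tool_prompt_format: Optional[ToolPromptFormat] = None


def pick_format_for_llama_3_1(request: FakeRequest) -> ToolPromptFormat:
    # None now resolves to the JSON default instead of failing the comparison.
    return request.tool_prompt_format or ToolPromptFormat.json


assert pick_format_for_llama_3_1(FakeRequest()) is ToolPromptFormat.json
assert (
    pick_format_for_llama_3_1(FakeRequest(ToolPromptFormat.function_tag))
    is ToolPromptFormat.function_tag
)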
@@ -9,8 +9,7 @@ from pathlib import Path
 from llama_models.sku_list import all_registered_models

 from llama_stack.apis.models import ModelInput
-from llama_stack.distribution.datatypes import Provider
-
+from llama_stack.distribution.datatypes import Provider, ToolGroupInput
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -26,6 +25,12 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "tool_runtime": [
+            "remote::brave-search",
+            "remote::tavily-search",
+            "inline::code-interpreter",
+            "inline::memory-runtime",
+        ],
     }
     name = "bedrock"
     memory_provider = Provider(
@@ -46,6 +51,20 @@ def get_distribution_template() -> DistributionTemplate:
         )
         for m in MODEL_ALIASES
     ]
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::memory",
+            provider_id="memory-runtime",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::code_interpreter",
+            provider_id="code-interpreter",
+        ),
+    ]

     return DistributionTemplate(
         name=name,
@@ -61,10 +80,11 @@ def get_distribution_template() -> DistributionTemplate:
                     "memory": [memory_provider],
                 },
                 default_models=default_models,
+                default_tool_groups=default_tool_groups,
             ),
         },
         run_config_env_vars={
-            "LLAMASTACK_PORT": (
+            "LLAMA_STACK_PORT": (
                 "5001",
                 "Port for the Llama Stack distribution server",
             ),
@@ -2,7 +2,6 @@ version: '2'
 name: bedrock
 distribution_spec:
   description: Use AWS Bedrock for running LLM inference and safety
-  docker_image: null
   providers:
     inference:
     - remote::bedrock
@@ -25,4 +24,9 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
    - inline::braintrust
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::code-interpreter
+    - inline::memory-runtime
 image_type: conda
@@ -1,6 +1,5 @@
 version: '2'
 image_name: bedrock
-docker_image: null
 conda_env: bedrock
 apis:
 - agents
@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: bedrock
@@ -65,8 +65,24 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
-  namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db
 models:
@@ -90,3 +106,10 @@ memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::memory
+  provider_id: memory-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter
@@ -2,7 +2,6 @@ version: '2'
 name: cerebras
 distribution_spec:
   description: Use Cerebras for running LLM inference
-  docker_image: null
   providers:
     inference:
     - remote::cerebras
@@ -14,4 +13,9 @@ distribution_spec:
     - inline::meta-reference
     telemetry:
     - inline::meta-reference
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::code-interpreter
+    - inline::memory-runtime
 image_type: conda
@@ -9,8 +9,12 @@ from pathlib import Path
 from llama_models.sku_list import all_registered_models

 from llama_stack.apis.models.models import ModelType
-from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.distribution.datatypes import (
+    ModelInput,
+    Provider,
+    ShieldInput,
+    ToolGroupInput,
+)
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
@@ -26,6 +30,12 @@ def get_distribution_template() -> DistributionTemplate:
         "memory": ["inline::meta-reference"],
         "agents": ["inline::meta-reference"],
         "telemetry": ["inline::meta-reference"],
+        "tool_runtime": [
+            "remote::brave-search",
+            "remote::tavily-search",
+            "inline::code-interpreter",
+            "inline::memory-runtime",
+        ],
     }

     inference_provider = Provider(
@@ -58,6 +68,20 @@ def get_distribution_template() -> DistributionTemplate:
             "embedding_dimension": 384,
         },
     )
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::memory",
+            provider_id="memory-runtime",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::code_interpreter",
+            provider_id="code-interpreter",
+        ),
+    ]

     return DistributionTemplate(
         name="cerebras",
@@ -74,10 +98,11 @@ def get_distribution_template() -> DistributionTemplate:
                 },
                 default_models=default_models + [embedding_model],
                 default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
+                default_tool_groups=default_tool_groups,
             ),
         },
         run_config_env_vars={
-            "LLAMASTACK_PORT": (
+            "LLAMA_STACK_PORT": (
                 "5001",
                 "Port for the Llama Stack distribution server",
             ),
Some files were not shown because too many files have changed in this diff.