Merge branch 'main' into vllm

Fred Reiss 2025-01-08 15:47:58 -08:00 committed by GitHub
commit 73fede90a6
175 changed files with 7948 additions and 876 deletions

.github/CODEOWNERS

@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv
+* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721


@ -84,6 +84,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
 | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
 | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
 | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
+| Groq | Hosted | | :heavy_check_mark: | | | |
 | Ollama | Single Node | | :heavy_check_mark: | | | |
 | TGI | Hosted and Single Node | | :heavy_check_mark: | | | |
 | [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | | |
@ -127,7 +128,7 @@ You have two ways to install this repository:
 conda activate stack
 cd llama-stack
-$CONDA_PREFIX/bin/pip install -e .
+pip install -e .
 ```
 ## Documentation


@ -1,9 +1,9 @@
{ {
"bedrock": [ "hf-serverless": [
"aiohttp",
"aiosqlite", "aiosqlite",
"autoevals", "autoevals",
"blobfile", "blobfile",
"boto3",
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
@ -11,6 +11,100 @@
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"huggingface_hub",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"together": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"together",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib", "matplotlib",
"nltk", "nltk",
"numpy", "numpy",
@ -63,7 +157,7 @@
"sentence-transformers --no-deps", "sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu" "torch --index-url https://download.pytorch.org/whl/cpu"
], ],
"hf-endpoint": [ "tgi": [
"aiohttp", "aiohttp",
"aiosqlite", "aiosqlite",
"autoevals", "autoevals",
@ -96,11 +190,11 @@
"sentence-transformers --no-deps", "sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu" "torch --index-url https://download.pytorch.org/whl/cpu"
], ],
"hf-serverless": [ "bedrock": [
"aiohttp",
"aiosqlite", "aiosqlite",
"autoevals", "autoevals",
"blobfile", "blobfile",
"boto3",
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
@ -108,7 +202,6 @@
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"huggingface_hub",
"matplotlib", "matplotlib",
"nltk", "nltk",
"numpy", "numpy",
@ -207,6 +300,34 @@
"sentence-transformers --no-deps", "sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu" "torch --index-url https://download.pytorch.org/whl/cpu"
], ],
"cerebras": [
"aiosqlite",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"ollama": [ "ollama": [
"aiohttp", "aiohttp",
"aiosqlite", "aiosqlite",
@ -240,7 +361,7 @@
"sentence-transformers --no-deps", "sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu" "torch --index-url https://download.pytorch.org/whl/cpu"
], ],
"tgi": [ "hf-endpoint": [
"aiohttp", "aiohttp",
"aiosqlite", "aiosqlite",
"autoevals", "autoevals",
@ -272,126 +393,5 @@
"uvicorn", "uvicorn",
"sentence-transformers --no-deps", "sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu" "torch --index-url https://download.pytorch.org/whl/cpu"
],
"together": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"together",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"cerebras": [
"aiosqlite",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
] ]
} }

docs/getting_started.ipynb (new file, 4636 lines): diff suppressed because one or more lines are too long


@ -544,7 +544,7 @@
" provider_type: inline::meta-reference\n", " provider_type: inline::meta-reference\n",
" inference:\n", " inference:\n",
" - config:\n", " - config:\n",
" api_key: 4985b03e627419b2964d34b8519ac6c4319f094d1ffb4f45514b4eb87e5427a2\n", " api_key: <...>\n",
" url: <span style=\"color: #0000ff; text-decoration-color: #0000ff; text-decoration: underline\">https://api.together.xyz/v1</span>\n", " url: <span style=\"color: #0000ff; text-decoration-color: #0000ff; text-decoration: underline\">https://api.together.xyz/v1</span>\n",
" provider_id: together\n", " provider_id: together\n",
" provider_type: remote::together\n", " provider_type: remote::together\n",
@ -663,7 +663,7 @@
" provider_type: inline::meta-reference\n", " provider_type: inline::meta-reference\n",
" inference:\n", " inference:\n",
" - config:\n", " - config:\n",
" api_key: 4985b03e627419b2964d34b8519ac6c4319f094d1ffb4f45514b4eb87e5427a2\n", " api_key: <...>\n",
" url: \u001b[4;94mhttps://api.together.xyz/v1\u001b[0m\n", " url: \u001b[4;94mhttps://api.together.xyz/v1\u001b[0m\n",
" provider_id: together\n", " provider_id: together\n",
" provider_type: remote::together\n", " provider_type: remote::together\n",


@ -338,8 +338,8 @@ distribution_spec:
inference: remote::ollama inference: remote::ollama
memory: inline::faiss memory: inline::faiss
safety: inline::llama-guard safety: inline::llama-guard
agents: meta-reference agents: inline::meta-reference
telemetry: meta-reference telemetry: inline::meta-reference
image_type: conda image_type: conda
``` ```


@ -8,10 +8,6 @@ building_distro
configuration configuration
``` ```
<!-- self_hosted_distro/index -->
<!-- remote_hosted_distro/index -->
<!-- ondevice_distro/index -->
You can instantiate a Llama Stack in one of the following ways: You can instantiate a Llama Stack in one of the following ways:
- **As a Library**: this is the simplest, especially if you are using an external inference service. See [Using Llama Stack as a Library](importing_as_library) - **As a Library**: this is the simplest, especially if you are using an external inference service. See [Using Llama Stack as a Library](importing_as_library)
- **Docker**: we provide a number of pre-built Docker containers so you can start a Llama Stack server instantly. You can also build your own custom Docker container. - **Docker**: we provide a number of pre-built Docker containers so you can start a Llama Stack server instantly. You can also build your own custom Docker container.
@ -30,11 +26,15 @@ If so, we suggest:
- {dockerhub}`distribution-ollama` ([Guide](self_hosted_distro/ollama)) - {dockerhub}`distribution-ollama` ([Guide](self_hosted_distro/ollama))
- **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest: - **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest:
- {dockerhub}`distribution-together` ([Guide](remote_hosted_distro/index)) - {dockerhub}`distribution-together` ([Guide](self_hosted_distro/together))
- {dockerhub}`distribution-fireworks` ([Guide](remote_hosted_distro/index)) - {dockerhub}`distribution-fireworks` ([Guide](self_hosted_distro/fireworks))
- **Do you want to run Llama Stack inference on your iOS / Android device** If so, we suggest: - **Do you want to run Llama Stack inference on your iOS / Android device** If so, we suggest:
- [iOS SDK](ondevice_distro/ios_sdk) - [iOS SDK](ondevice_distro/ios_sdk)
- [Android](ondevice_distro/android_sdk) - [Android](ondevice_distro/android_sdk)
- **Do you want a hosted Llama Stack endpoint?** If so, we suggest:
- [Remote-Hosted Llama Stack Endpoints](remote_hosted_distro/index)
You can also build your own [custom distribution](building_distro). You can also build your own [custom distribution](building_distro).


@ -42,6 +42,7 @@ The following models are available by default:
- `meta-llama/Llama-3.2-3B-Instruct (fireworks/llama-v3p2-3b-instruct)` - `meta-llama/Llama-3.2-3B-Instruct (fireworks/llama-v3p2-3b-instruct)`
- `meta-llama/Llama-3.2-11B-Vision-Instruct (fireworks/llama-v3p2-11b-vision-instruct)` - `meta-llama/Llama-3.2-11B-Vision-Instruct (fireworks/llama-v3p2-11b-vision-instruct)`
- `meta-llama/Llama-3.2-90B-Vision-Instruct (fireworks/llama-v3p2-90b-vision-instruct)` - `meta-llama/Llama-3.2-90B-Vision-Instruct (fireworks/llama-v3p2-90b-vision-instruct)`
- `meta-llama/Llama-3.3-70B-Instruct (fireworks/llama-v3p3-70b-instruct)`
- `meta-llama/Llama-Guard-3-8B (fireworks/llama-guard-3-8b)` - `meta-llama/Llama-Guard-3-8B (fireworks/llama-guard-3-8b)`
- `meta-llama/Llama-Guard-3-11B-Vision (fireworks/llama-guard-3-11b-vision)` - `meta-llama/Llama-Guard-3-11B-Vision (fireworks/llama-guard-3-11b-vision)`


@ -41,6 +41,7 @@ The following models are available by default:
- `meta-llama/Llama-3.2-3B-Instruct` - `meta-llama/Llama-3.2-3B-Instruct`
- `meta-llama/Llama-3.2-11B-Vision-Instruct` - `meta-llama/Llama-3.2-11B-Vision-Instruct`
- `meta-llama/Llama-3.2-90B-Vision-Instruct` - `meta-llama/Llama-3.2-90B-Vision-Instruct`
- `meta-llama/Llama-3.3-70B-Instruct`
- `meta-llama/Llama-Guard-3-8B` - `meta-llama/Llama-Guard-3-8B`
- `meta-llama/Llama-Guard-3-11B-Vision` - `meta-llama/Llama-Guard-3-11B-Vision`


@ -43,7 +43,7 @@ Configuration for this is available at `distributions/ollama/run.yaml`.
### 3. Use the Llama Stack client SDK ### 3. Use the Llama Stack client SDK
You can interact with the Llama Stack server using various client SDKs. We will use the Python SDK which you can install using: You can interact with the Llama Stack server using various client SDKs. We will use the Python SDK which you can install using the following command. Note that you must be using Python 3.10 or newer:
```bash ```bash
pip install llama-stack-client pip install llama-stack-client
``` ```
@ -51,7 +51,8 @@ pip install llama-stack-client
Let's use the `llama-stack-client` CLI to check the connectivity to the server. Let's use the `llama-stack-client` CLI to check the connectivity to the server.
```bash ```bash
llama-stack-client --endpoint http://localhost:$LLAMA_STACK_PORT models list llama-stack-client configure --endpoint http://localhost:$LLAMA_STACK_PORT
llama-stack-client models list
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ provider_resource_id ┃ metadata ┃ ┃ identifier ┃ provider_id ┃ provider_resource_id ┃ metadata ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
@ -61,7 +62,7 @@ llama-stack-client --endpoint http://localhost:$LLAMA_STACK_PORT models list
You can test basic Llama inference completion using the CLI too. You can test basic Llama inference completion using the CLI too.
```bash ```bash
llama-stack-client --endpoint http://localhost:$LLAMA_STACK_PORT \ llama-stack-client \
inference chat-completion \ inference chat-completion \
--message "hello, what model are you?" --message "hello, what model are you?"
``` ```
@ -153,10 +154,3 @@ if __name__ == "__main__":
- Learn how to [Build Llama Stacks](../distributions/index.md) - Learn how to [Build Llama Stacks](../distributions/index.md)
- See [References](../references/index.md) for more details about the llama CLI and Python SDK - See [References](../references/index.md) for more details about the llama CLI and Python SDK
- For example applications and more detailed tutorials, visit our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository. - For example applications and more detailed tutorials, visit our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository.
## Thinking out aloud here in terms of what to write in the docs
- how to get a llama stack server running
- what are all the different client sdks
- what are the components of building agents


@ -16,7 +16,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
##### Chatbot ##### Chatbot
```{eval-rst} ```{eval-rst}
.. video:: https://github.com/user-attachments/assets/6ca617e8-32ca-49b2-9774-185020ff5204 .. video:: https://github.com/user-attachments/assets/8d2ef802-5812-4a28-96e1-316038c84cbf
:autoplay: :autoplay:
:playsinline: :playsinline:
:muted: :muted:


@ -47,7 +47,7 @@ This first example walks you through how to evaluate a model candidate served by
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): Benchmark designed to access models to answer short, fact-seeking questions. - [SimpleQA](https://openai.com/index/introducing-simpleqa/): Benchmark designed to access models to answer short, fact-seeking questions.
#### 1.1 Running MMMU #### 1.1 Running MMMU
- We will use a pre-processed MMMU dataset from [llamastack/mmmu](https://huggingface.co/datasets/llamastack/mmmu). The preprocessing code is shown in in this [Github Gist](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840). The dataset is obtained by transforming the original [MMMU/MMMU](https://huggingface.co/datasets/MMMU/MMMU) dataset into correct format by `inference/chat-completion` API. - We will use a pre-processed MMMU dataset from [llamastack/mmmu](https://huggingface.co/datasets/llamastack/mmmu). The preprocessing code is shown in this [GitHub Gist](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840). The dataset is obtained by transforming the original [MMMU/MMMU](https://huggingface.co/datasets/MMMU/MMMU) dataset into correct format by `inference/chat-completion` API.
```python ```python
import datasets import datasets


@ -358,7 +358,7 @@
" if not stream:\n", " if not stream:\n",
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n", " cprint(f'> Response: {response.completion_message.content}', 'cyan')\n",
" else:\n", " else:\n",
" async for log in EventLogger().log(response):\n", " for log in EventLogger().log(response):\n",
" log.print()\n", " log.print()\n",
"\n", "\n",
"# In a Jupyter Notebook cell, use `await` to call the function\n", "# In a Jupyter Notebook cell, use `await` to call the function\n",
@ -366,16 +366,6 @@
"# To run it in a python file, use this line instead\n", "# To run it in a python file, use this line instead\n",
"# asyncio.run(run_main())\n" "# asyncio.run(run_main())\n"
] ]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9399aecc",
"metadata": {},
"outputs": [],
"source": [
"#fin"
]
} }
], ],
"metadata": { "metadata": {


@ -67,7 +67,7 @@
"from termcolor import cprint\n", "from termcolor import cprint\n",
"\n", "\n",
"from llama_stack.distribution.datatypes import RemoteProviderConfig\n", "from llama_stack.distribution.datatypes import RemoteProviderConfig\n",
"from llama_stack.apis.safety import * # noqa: F403\n", "from llama_stack.apis.safety import Safety\n",
"from llama_stack_client import LlamaStackClient\n", "from llama_stack_client import LlamaStackClient\n",
"\n", "\n",
"\n", "\n",
@ -127,7 +127,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.15" "version": "3.11.10"
} }
}, },
"nbformat": 4, "nbformat": 4,


@ -45,7 +45,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
--- ---
## Install Dependencies and Set Up Environment ## Install Dependencies and Set Up Environment
1. **Create a Conda Environment**: 1. **Create a Conda Environment**:
Create a new Conda environment with Python 3.10: Create a new Conda environment with Python 3.10:
@ -73,7 +73,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
Open a new terminal and install `llama-stack`: Open a new terminal and install `llama-stack`:
```bash ```bash
conda activate ollama conda activate ollama
pip install llama-stack==0.0.55 pip install llama-stack==0.0.61
``` ```
--- ---
@ -96,7 +96,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
3. **Set the ENV variables by exporting them to the terminal**: 3. **Set the ENV variables by exporting them to the terminal**:
```bash ```bash
export OLLAMA_URL="http://localhost:11434" export OLLAMA_URL="http://localhost:11434"
export LLAMA_STACK_PORT=5051 export LLAMA_STACK_PORT=5001
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B" export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
``` ```
@ -104,34 +104,29 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
3. **Run the Llama Stack**: 3. **Run the Llama Stack**:
Run the stack with command shared by the API from earlier: Run the stack with command shared by the API from earlier:
```bash ```bash
llama stack run ollama \ llama stack run ollama
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT
--env INFERENCE_MODEL=$INFERENCE_MODEL \ --env INFERENCE_MODEL=$INFERENCE_MODEL
--env SAFETY_MODEL=$SAFETY_MODEL \ --env SAFETY_MODEL=$SAFETY_MODEL
--env OLLAMA_URL=$OLLAMA_URL --env OLLAMA_URL=$OLLAMA_URL
``` ```
Note: Everytime you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model. Note: Everytime you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
The server will start and listen on `http://localhost:5051`. The server will start and listen on `http://localhost:5001`.
--- ---
## Test with `llama-stack-client` CLI ## Test with `llama-stack-client` CLI
After setting up the server, open a new terminal window and install the llama-stack-client package. After setting up the server, open a new terminal window and configure the llama-stack-client.
1. Install the llama-stack-client package 1. Configure the CLI to point to the llama-stack server.
```bash ```bash
conda activate ollama llama-stack-client configure --endpoint http://localhost:5001
pip install llama-stack-client
```
2. Configure the CLI to point to the llama-stack server.
```bash
llama-stack-client configure --endpoint http://localhost:5051
``` ```
**Expected Output:** **Expected Output:**
```bash ```bash
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5051 Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5001
``` ```
3. Test the CLI by running inference: 2. Test the CLI by running inference:
```bash ```bash
llama-stack-client inference chat-completion --message "Write me a 2-sentence poem about the moon" llama-stack-client inference chat-completion --message "Write me a 2-sentence poem about the moon"
``` ```
@ -153,16 +148,18 @@ After setting up the server, open a new terminal window and install the llama-st
After setting up the server, open a new terminal window and verify it's working by sending a `POST` request using `curl`: After setting up the server, open a new terminal window and verify it's working by sending a `POST` request using `curl`:
```bash ```bash
curl http://localhost:$LLAMA_STACK_PORT/inference/chat_completion \ curl http://localhost:$LLAMA_STACK_PORT/alpha/inference/chat-completion
-H "Content-Type: application/json" \ -H "Content-Type: application/json"
-d '{ -d @- <<EOF
"model": "Llama3.2-3B-Instruct", {
"model_id": "$INFERENCE_MODEL",
"messages": [ "messages": [
{"role": "system", "content": "You are a helpful assistant."}, {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write me a 2-sentence poem about the moon"} {"role": "user", "content": "Write me a 2-sentence poem about the moon"}
], ],
"sampling_params": {"temperature": 0.7, "seed": 42, "max_tokens": 512} "sampling_params": {"temperature": 0.7, "seed": 42, "max_tokens": 512}
}' }
EOF
``` ```
You can check the available models with the command `llama-stack-client models list`. You can check the available models with the command `llama-stack-client models list`.
@ -186,16 +183,12 @@ You can check the available models with the command `llama-stack-client models l
You can also interact with the Llama Stack server using a simple Python script. Below is an example: You can also interact with the Llama Stack server using a simple Python script. Below is an example:
### 1. Activate Conda Environment and Install Required Python Packages ### 1. Activate Conda Environment
The `llama-stack-client` library offers a robust and efficient python methods for interacting with the Llama Stack server.
```bash ```bash
conda activate ollama conda activate ollama
pip install llama-stack-client
``` ```
Note, the client library gets installed by default if you install the server library
### 2. Create Python Script (`test_llama_stack.py`) ### 2. Create Python Script (`test_llama_stack.py`)
```bash ```bash
touch test_llama_stack.py touch test_llama_stack.py
@ -206,19 +199,28 @@ touch test_llama_stack.py
In `test_llama_stack.py`, write the following code: In `test_llama_stack.py`, write the following code:
```python ```python
from llama_stack_client import LlamaStackClient import os
from llama_stack_client import LlamaStackClient
# Initialize the client # Get the model ID from the environment variable
client = LlamaStackClient(base_url="http://localhost:5051") INFERENCE_MODEL = os.environ.get("INFERENCE_MODEL")
# Create a chat completion request # Check if the environment variable is set
if INFERENCE_MODEL is None:
raise ValueError("The environment variable 'INFERENCE_MODEL' is not set.")
# Initialize the client
client = LlamaStackClient(base_url="http://localhost:5001")
# Create a chat completion request
response = client.inference.chat_completion( response = client.inference.chat_completion(
messages=[ messages=[
{"role": "system", "content": "You are a friendly assistant."}, {"role": "system", "content": "You are a friendly assistant."},
{"role": "user", "content": "Write a two-sentence poem about llama."} {"role": "user", "content": "Write a two-sentence poem about llama."}
], ],
model_id=MODEL_NAME, model_id=INFERENCE_MODEL,
) )
# Print the response # Print the response
print(response.completion_message.content) print(response.completion_message.content)
``` ```


@ -18,18 +18,30 @@ from typing import (
Union, Union,
) )
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.deployment_types import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.common.content_types import InterleavedContent, URL from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
from llama_stack.apis.inference import (
CompletionMessage,
SamplingParams,
ToolCall,
ToolCallDelta,
ToolChoice,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import SafetyViolation
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@json_schema_type @json_schema_type


@ -6,13 +6,14 @@
from typing import Optional from typing import Optional
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.llama3.api.tool_utils import ToolUtils from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_stack.apis.inference import ToolResponseMessage
class LogEvent: class LogEvent:
def __init__( def __init__(


@ -10,8 +10,16 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import (
from llama_stack.apis.inference import * # noqa: F403 CompletionMessage,
InterleavedContent,
LogProbConfig,
Message,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
@json_schema_type @json_schema_type


@ -4,11 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import base64
from typing import Annotated, List, Literal, Optional, Union from typing import Annotated, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type, register_schema from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, field_serializer, model_validator
@json_schema_type @json_schema_type
@ -27,6 +28,12 @@ class _URLOrData(BaseModel):
return values return values
return {"url": values} return {"url": values}
@field_serializer("data")
def serialize_data(self, data: Optional[bytes], _info):
if data is None:
return None
return base64.b64encode(data).decode("utf-8")
@json_schema_type @json_schema_type
class ImageContentItem(_URLOrData): class ImageContentItem(_URLOrData):
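For reference, a minimal standalone sketch (not part of this diff; the `Thumbnail` model name is made up for illustration) of how a pydantic `field_serializer` like the one added above turns raw bytes into a base64 string when the model is dumped:

```python
import base64
from typing import Optional

from pydantic import BaseModel, field_serializer


class Thumbnail(BaseModel):
    # Raw image bytes; serialized as base64 text so the model can be dumped to JSON.
    data: Optional[bytes] = None

    @field_serializer("data")
    def serialize_data(self, data: Optional[bytes], _info):
        if data is None:
            return None
        return base64.b64encode(data).decode("utf-8")


print(Thumbnail(data=b"\x89PNG").model_dump())   # {'data': 'iVBORw=='}
print(Thumbnail().model_dump_json())             # {"data": null}
```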


@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.datasets import Dataset
@json_schema_type @json_schema_type


@ -4,18 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Literal, Optional, Protocol, Union from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.schema_utils import json_schema_type, webmethod
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.agents import AgentConfig from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.apis.inference import SamplingParams, SystemMessage from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
@json_schema_type @json_schema_type


@ -7,7 +7,9 @@
from enum import Enum from enum import Enum
from typing import ( from typing import (
Any,
AsyncIterator, AsyncIterator,
Dict,
List, List,
Literal, Literal,
Optional, Optional,
@ -32,8 +34,9 @@ from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.apis.models import * # noqa: F403
class LogProbConfig(BaseModel): class LogProbConfig(BaseModel):


@ -7,17 +7,17 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.apis.common.training_types import * # noqa: F403
@json_schema_type @json_schema_type
@ -58,6 +58,7 @@ class TrainingConfig(BaseModel):
n_epochs: int n_epochs: int
max_steps_per_epoch: int max_steps_per_epoch: int
gradient_accumulation_steps: int gradient_accumulation_steps: int
max_validation_steps: int
data_config: DataConfig data_config: DataConfig
optimizer_config: OptimizerConfig optimizer_config: OptimizerConfig
efficiency_config: Optional[EfficiencyConfig] = None efficiency_config: Optional[EfficiencyConfig] = None


@ -18,6 +18,8 @@ class ResourceType(Enum):
dataset = "dataset" dataset = "dataset"
scoring_function = "scoring_function" scoring_function = "scoring_function"
eval_task = "eval_task" eval_task = "eval_task"
tool = "tool"
tool_group = "tool_group"
class Resource(BaseModel): class Resource(BaseModel):


@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Protocol, runtime_checkable from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.apis.scoring_functions import * # noqa: F403
# mapping of metric to value # mapping of metric to value
@ -48,7 +47,7 @@ class Scoring(Protocol):
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ... ) -> ScoreBatchResponse: ...
@ -56,5 +55,5 @@ class Scoring(Protocol):
async def score( async def score(
self, self,
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse: ... ) -> ScoreResponse: ...


@ -6,13 +6,12 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Protocol from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .tools import * # noqa: F401 F403


@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@json_schema_type
class ToolParameter(BaseModel):
name: str
parameter_type: str
description: str
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
tool_group: str
description: str
parameters: List[ToolParameter]
provider_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
@json_schema_type
class ToolDef(BaseModel):
name: str
description: str
parameters: List[ToolParameter]
metadata: Dict[str, Any]
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
@json_schema_type
class MCPToolGroupDef(BaseModel):
"""
A tool group that is defined by in a model context protocol server.
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
"""
type: Literal["model_context_protocol"] = "model_context_protocol"
endpoint: URL
@json_schema_type
class UserDefinedToolGroupDef(BaseModel):
type: Literal["user_defined"] = "user_defined"
tools: List[ToolDef]
ToolGroupDef = register_schema(
Annotated[
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
],
name="ToolGroup",
)
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
@json_schema_type
class ToolInvocationResult(BaseModel):
content: InterleavedContent
error_message: Optional[str] = None
error_code: Optional[int] = None
class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ...
@runtime_checkable
@trace_protocol
class ToolGroups(Protocol):
@webmethod(route="/toolgroups/register", method="POST")
async def register_tool_group(
self,
tool_group_id: str,
tool_group: ToolGroupDef,
provider_id: Optional[str] = None,
) -> None:
"""Register a tool group"""
...
@webmethod(route="/toolgroups/get", method="GET")
async def get_tool_group(
self,
tool_group_id: str,
) -> ToolGroup: ...
@webmethod(route="/toolgroups/list", method="GET")
async def list_tool_groups(self) -> List[ToolGroup]:
"""List tool groups with optional provider"""
...
@webmethod(route="/tools/list", method="GET")
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
"""List tools with optional tool group"""
...
@webmethod(route="/tools/get", method="GET")
async def get_tool(self, tool_name: str) -> Tool: ...
@webmethod(route="/toolgroups/unregister", method="POST")
async def unregister_tool_group(self, tool_group_id: str) -> None:
"""Unregister a tool group"""
...
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
tool_store: ToolStore
@webmethod(route="/tool-runtime/discover", method="POST")
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
...
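To make the shape of this new API concrete, here is a hedged sketch (assuming the module lands under `llama_stack.apis.tools`; the `echo` tool and its parameter are made up for illustration) of how a tool group could be described with the models declared above:

```python
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    MCPToolGroupDef,
    ToolDef,
    ToolParameter,
    UserDefinedToolGroupDef,
)

# A hypothetical user-defined group containing a single "echo" tool.
echo_tool = ToolDef(
    name="echo",
    description="Return the input text unchanged",
    parameters=[
        ToolParameter(
            name="text",
            parameter_type="string",
            description="Text to echo back",
        ),
    ],
    metadata={},
)
user_defined_group = UserDefinedToolGroupDef(tools=[echo_tool])

# The same ToolGroupDef union also accepts an MCP-backed group.
mcp_group = MCPToolGroupDef(endpoint=URL(uri="http://localhost:8000/mcp"))

# Either value could then be passed to ToolGroups.register_tool_group(...)
# together with a tool_group_id, per the protocol defined above.
```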


@ -6,11 +6,12 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.datatypes import SamplingParams
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import LlamaDownloadInfo from llama_models.sku_list import LlamaDownloadInfo
from pydantic import BaseModel, ConfigDict, Field
class PromptGuardModel(BaseModel): class PromptGuardModel(BaseModel):
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed.""" """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""


@ -3,21 +3,28 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import argparse import argparse
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import * # noqa: F403
import os import os
import shutil import shutil
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import List, Optional
import pkg_resources import pkg_resources
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import (
BuildConfig,
DistributionSpec,
Provider,
StackRunConfig,
)
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates" TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
@ -100,7 +107,7 @@ class StackBuild(Subcommand):
build_config.image_type = args.image_type build_config.image_type = args.image_type
else: else:
self.parser.error( self.parser.error(
f"Please specify a image-type (docker | conda) for {args.template}" f"Please specify a image-type (docker | conda | venv) for {args.template}"
) )
self._run_stack_build_command_from_build_config( self._run_stack_build_command_from_build_config(
build_config, template_name=args.template build_config, template_name=args.template
@ -122,7 +129,7 @@ class StackBuild(Subcommand):
) )
image_type = prompt( image_type = prompt(
"> Enter the image type you want your Llama Stack to be built as (docker or conda): ", "> Enter the image type you want your Llama Stack to be built as (docker or conda or venv): ",
validator=Validator.from_callable( validator=Validator.from_callable(
lambda x: x in ["docker", "conda", "venv"], lambda x: x in ["docker", "conda", "venv"],
error_message="Invalid image type, please enter conda or docker or venv", error_message="Invalid image type, please enter conda or docker or venv",


@ -6,21 +6,22 @@
import logging import logging
from enum import Enum from enum import Enum
from typing import List
from pathlib import Path
from typing import Dict, List
import pkg_resources import pkg_resources
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from llama_stack.distribution.utils.exec import run_with_pty from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__) log = logging.getLogger(__name__)


@ -126,7 +126,7 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--templat
EOF EOF
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile" printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile\n\n"
cat $TEMP_DIR/Dockerfile cat $TEMP_DIR/Dockerfile
printf "\n" printf "\n"


@ -6,10 +6,14 @@
import logging import logging
import textwrap import textwrap
from typing import Any from typing import Any, Dict
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import (
DistributionSpec,
LLAMA_STACK_RUN_CONFIG_VERSION,
Provider,
StackRunConfig,
)
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
get_provider_registry, get_provider_registry,
@ -17,10 +21,7 @@ from llama_stack.distribution.distribution import (
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)


@ -4,23 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict, List, Optional, Union from typing import Annotated, Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTaskInput from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankInput
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig
LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
@ -37,6 +38,8 @@ RoutableObject = Union[
Dataset, Dataset,
ScoringFn, ScoringFn,
EvalTask, EvalTask,
Tool,
ToolGroup,
] ]
@ -48,6 +51,8 @@ RoutableObjectWithProvider = Annotated[
Dataset, Dataset,
ScoringFn, ScoringFn,
EvalTask, EvalTask,
Tool,
ToolGroup,
], ],
Field(discriminator="type"), Field(discriminator="type"),
] ]
@ -59,6 +64,7 @@ RoutedProtocol = Union[
DatasetIO, DatasetIO,
Scoring, Scoring,
Eval, Eval,
ToolRuntime,
] ]


@ -47,6 +47,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
routing_table_api=Api.eval_tasks, routing_table_api=Api.eval_tasks,
router_api=Api.eval, router_api=Api.eval,
), ),
AutoRoutedApiInfo(
routing_table_api=Api.tool_groups,
router_api=Api.tool_runtime,
),
] ]


@ -5,12 +5,12 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict, List from typing import Dict, List
from llama_stack.apis.inspect import * # noqa: F403
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.inspect import HealthInfo, Inspect, ProviderInfo, RouteInfo
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
class DistributionInspectConfig(BaseModel): class DistributionInspectConfig(BaseModel):


@ -7,6 +7,7 @@
import asyncio import asyncio
import inspect import inspect
import json import json
import logging
import os import os
import queue import queue
import threading import threading
@ -16,7 +17,6 @@ from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
import httpx import httpx
import yaml import yaml
from llama_stack_client import ( from llama_stack_client import (
APIResponse, APIResponse,
@ -28,7 +28,6 @@ from llama_stack_client import (
) )
from pydantic import BaseModel, TypeAdapter from pydantic import BaseModel, TypeAdapter
from rich.console import Console from rich.console import Console
from termcolor import cprint from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.build import print_pip_install_help
@ -39,9 +38,9 @@ from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
get_stack_run_config_from_template, get_stack_run_config_from_template,
redact_sensitive_fields,
replace_env_vars, replace_env_vars,
) )
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
end_trace, end_trace,
setup_logger, setup_logger,
@ -67,6 +66,7 @@ def in_notebook():
def stream_across_asyncio_run_boundary( def stream_across_asyncio_run_boundary(
async_gen_maker, async_gen_maker,
pool_executor: ThreadPoolExecutor, pool_executor: ThreadPoolExecutor,
path: Optional[str] = None,
) -> Generator[T, None, None]: ) -> Generator[T, None, None]:
result_queue = queue.Queue() result_queue = queue.Queue()
stop_event = threading.Event() stop_event = threading.Event()
@ -74,6 +74,7 @@ def stream_across_asyncio_run_boundary(
async def consumer(): async def consumer():
# make sure we make the generator in the event loop context # make sure we make the generator in the event loop context
gen = await async_gen_maker() gen = await async_gen_maker()
await start_trace(path, {"__location__": "library_client"})
try: try:
async for item in await gen: async for item in await gen:
result_queue.put(item) result_queue.put(item)
@ -85,6 +86,7 @@ def stream_across_asyncio_run_boundary(
finally: finally:
result_queue.put(StopIteration) result_queue.put(StopIteration)
stop_event.set() stop_event.set()
await end_trace()
def run_async(): def run_async():
# Run our own loop to avoid double async generator cleanup which is done # Run our own loop to avoid double async generator cleanup which is done
@ -170,6 +172,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__( def __init__(
self, self,
config_path_or_template_name: str, config_path_or_template_name: str,
skip_logger_removal: bool = False,
custom_provider_registry: Optional[ProviderRegistry] = None, custom_provider_registry: Optional[ProviderRegistry] = None,
): ):
super().__init__() super().__init__()
@ -177,23 +180,56 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
config_path_or_template_name, custom_provider_registry config_path_or_template_name, custom_provider_registry
) )
self.pool_executor = ThreadPoolExecutor(max_workers=4) self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal
def initialize(self): def initialize(self):
if in_notebook(): if in_notebook():
import nest_asyncio import nest_asyncio
nest_asyncio.apply() nest_asyncio.apply()
if not self.skip_logger_removal:
self._remove_root_logger_handlers()
return asyncio.run(self.async_client.initialize()) return asyncio.run(self.async_client.initialize())
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
print(f"Removed handler {handler.__class__.__name__} from root logger")
def _get_path(
self,
cast_to: Any,
options: Any,
*,
stream=False,
stream_cls=None,
):
return options.url
def request(self, *args, **kwargs): def request(self, *args, **kwargs):
path = self._get_path(*args, **kwargs)
if kwargs.get("stream"): if kwargs.get("stream"):
return stream_across_asyncio_run_boundary( return stream_across_asyncio_run_boundary(
lambda: self.async_client.request(*args, **kwargs), lambda: self.async_client.request(*args, **kwargs),
self.pool_executor, self.pool_executor,
path=path,
) )
else: else:
return asyncio.run(self.async_client.request(*args, **kwargs))
async def _traced_request():
await start_trace(path, {"__location__": "library_client"})
try:
return await self.async_client.request(*args, **kwargs)
finally:
await end_trace()
return asyncio.run(_traced_request())
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
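As an aside, a hedged usage sketch (assuming the class lives at `llama_stack.distribution.library_client` and that the `ollama` template and the named model are available locally) of the synchronous library client whose tracing behaviour is reworked above:

```python
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

# skip_logger_removal is the new flag added in this change; leaving it False
# strips root-logger handlers so library logs do not flood the console.
client = LlamaStackAsLibraryClient("ollama", skip_logger_removal=False)
client.initialize()

# After initialize(), the object behaves like a regular LlamaStackClient,
# so the usual SDK surface is available, e.g.:
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.completion_message.content)
```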
@ -206,7 +242,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
# when using the library client, we should not log to console since many # when using the library client, we should not log to console since many
# of our logs are intended for server-side usage # of our logs are intended for server-side usage
os.environ["TELEMETRY_SINKS"] = "sqlite" current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(
sink for sink in current_sinks if sink != "console"
)
if config_path_or_template_name.endswith(".yaml"): if config_path_or_template_name.endswith(".yaml"):
config_path = Path(config_path_or_template_name) config_path = Path(config_path_or_template_name)
@ -247,7 +286,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
console = Console() console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
console.print(yaml.dump(self.config.model_dump(), indent=2))
# Redact sensitive information before printing
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
endpoints = get_all_api_endpoints()
endpoint_impls = {}
@@ -295,41 +337,37 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {}
body |= options.json_data or {}
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
result = await func(**body)
json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=json_content.encode("utf-8"),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
response = APIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=False,
stream_cls=None,
)
return response.parse()
finally:
await end_trace()
async def _call_streaming(
self,
@@ -341,51 +379,47 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url
body = options.params or {}
body |= options.json_data or {}
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
async def gen():
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=gen(),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
# so we need to convert it to AsyncStream
args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]]
response = AsyncAPIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=True,
stream_cls=stream_cls,
)
return await response.parse()
finally:
await end_trace()
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
if not body:


@@ -6,14 +6,10 @@
import importlib
import inspect
from typing import Any, Dict, List, Set
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import logging
from typing import Any, Dict, List, Set
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@@ -30,11 +26,34 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import (
AutoRoutedProviderSpec,
Provider,
RoutingTableProviderSpec,
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import (
Api,
DatasetsProtocolPrivate,
EvalTasksProtocolPrivate,
InlineProviderSpec,
MemoryBanksProtocolPrivate,
ModelsProtocolPrivate,
ProviderSpec,
RemoteProviderConfig,
RemoteProviderSpec,
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolsProtocolPrivate,
)
log = logging.getLogger(__name__)
@@ -60,12 +79,15 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.eval: Eval,
Api.eval_tasks: EvalTasks,
Api.post_training: PostTraining,
Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime,
}
def additional_protocols_map() -> Dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),


@@ -4,11 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from typing import Any, Dict
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
from .routing_tables import (
DatasetsRoutingTable,
@@ -17,6 +18,7 @@ from .routing_tables import (
ModelsRoutingTable,
ScoringFunctionsRoutingTable,
ShieldsRoutingTable,
ToolGroupsRoutingTable,
)
@@ -33,6 +35,7 @@ async def get_routing_table_impl(
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
"eval_tasks": EvalTasksRoutingTable,
"tool_groups": ToolGroupsRoutingTable,
}
if api.value not in api_to_tables:
@@ -51,6 +54,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
MemoryRouter,
SafetyRouter,
ScoringRouter,
ToolRuntimeRouter,
)
api_to_routers = {
@@ -60,6 +64,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
"datasetio": DatasetIORouter,
"scoring": ScoringRouter,
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")


@@ -6,15 +6,40 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.datasetio.datasetio import DatasetIO
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
AppEvalTaskConfig,
Eval,
EvalTaskConfig,
EvaluateResponse,
Job,
JobStatus,
)
from llama_stack.apis.inference import (
EmbeddingsResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory_banks.memory_banks import BankParams
from llama_stack.distribution.datatypes import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.models import ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime
from llama_stack.providers.datatypes import RoutingTable
class MemoryRouter(Memory):
@@ -329,7 +354,6 @@ class EvalRouter(Eval):
task_config=task_config,
)
@webmethod(route="/eval/evaluate_rows", method="POST")
async def evaluate_rows(
self,
task_id: str,
@@ -372,3 +396,28 @@ class EvalRouter(Eval):
task_id,
job_id,
)
class ToolRuntimeRouter(ToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
args=args,
)
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
return await self.routing_table.get_provider_impl(
tool_group.name
).discover_tools(tool_group)
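To make the new routing path concrete, a hedged sketch of driving ToolRuntimeRouter directly; the tool name and arguments are made up for illustration:

# Sketch only: assumes the routing table already has a provider registered
# for a tool called "web_search" (hypothetical name).
async def run_tool_example(tool_runtime: ToolRuntimeRouter) -> None:
    # invoke_tool() resolves the provider for the tool name via the routing
    # table, then forwards the call to that provider's invoke_tool().
    result = await tool_runtime.invoke_tool(
        tool_name="web_search",
        args={"query": "llama stack tool runtime"},
    )
    print(result)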


@@ -8,19 +8,40 @@ from typing import Any, Dict, List, Optional
from pydantic import parse_obj_as
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks
from llama_stack.apis.memory_banks import (
BankParams,
MemoryBank,
MemoryBanks,
MemoryBankType,
)
from llama_stack.apis.models import Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import (
MCPToolGroupDef,
Tool,
ToolGroup,
ToolGroupDef,
ToolGroups,
UserDefinedToolGroupDef,
)
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
)
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import Api, RoutingTable
def get_impl_api(p: Any) -> Api:
@@ -45,6 +66,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_eval_task(obj)
elif api == Api.tool_runtime:
return await p.register_tool(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@@ -57,6 +80,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_tool(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@@ -104,6 +129,8 @@ class CommonRoutingTableImpl(RoutingTable):
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self
elif api == Api.tool_runtime:
p.tool_store = self
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
@@ -125,6 +152,8 @@ class CommonRoutingTableImpl(RoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, EvalTasksRoutingTable):
return ("Eval", "eval_task")
elif isinstance(self, ToolGroupsRoutingTable):
return ("Tools", "tool")
else:
raise ValueError("Unknown routing table type")
@@ -461,3 +490,88 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
provider_resource_id=provider_eval_task_id,
)
await self.register_object(eval_task)
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
tools = await self.get_all_with_type("tool")
if tool_group_id:
tools = [tool for tool in tools if tool.tool_group == tool_group_id]
return tools
async def list_tool_groups(self) -> List[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", tool_group_id)
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
async def register_tool_group(
self,
tool_group_id: str,
tool_group: ToolGroupDef,
provider_id: Optional[str] = None,
) -> None:
tools = []
tool_defs = []
if provider_id is None:
if len(self.impls_by_provider_id.keys()) > 1:
raise ValueError(
f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
)
provider_id = list(self.impls_by_provider_id.keys())[0]
if isinstance(tool_group, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group
)
elif isinstance(tool_group, UserDefinedToolGroupDef):
tool_defs = tool_group.tools
else:
raise ValueError(f"Unknown tool group: {tool_group}")
for tool_def in tool_defs:
tools.append(
Tool(
identifier=tool_def.name,
tool_group=tool_group_id,
description=tool_def.description,
parameters=tool_def.parameters,
provider_id=provider_id,
tool_prompt_format=tool_def.tool_prompt_format,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
)
)
for tool in tools:
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists
if existing_tool:
existing_dict = existing_tool.model_dump()
new_dict = tool.model_dump()
if existing_dict != new_dict:
raise ValueError(
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
)
await self.register_object(tool)
await self.dist_registry.register(
ToolGroup(
identifier=tool_group_id,
provider_id=provider_id,
provider_resource_id=tool_group_id,
)
)
async def unregister_tool_group(self, tool_group_id: str) -> None:
tool_group = await self.get_tool_group(tool_group_id)
if tool_group is None:
raise ValueError(f"Tool group {tool_group_id} not found")
tools = await self.list_tools(tool_group_id)
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)
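A hedged sketch of how the registration flow above could be exercised; the group name is invented and the tool definitions are elided, since the exact ToolGroupDef schema is not shown in this hunk:

# Sketch only: with a single tool_runtime provider configured, provider_id can
# be omitted and the lone provider is selected automatically.
async def register_example_group(table: ToolGroupsRoutingTable) -> None:
    group = UserDefinedToolGroupDef(
        name="demo_tools",
        tools=[],  # tool definitions elided in this sketch
    )
    await table.register_tool_group(tool_group_id="demo_tools", tool_group=group)
    print(await table.list_tools("demo_tools"))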


@@ -28,25 +28,29 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import (
construct_stack,
redact_sensitive_fields,
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
start_trace,
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import (
construct_stack,
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from .endpoints import get_all_api_endpoints
@@ -235,7 +239,12 @@ def main():
"--template",
help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
)
parser.add_argument("--port", type=int, default=5000, help="Port to listen on")
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMASTACK_PORT", 5000)),
help="Port to listen on",
)
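The --port default now falls back to an environment variable; a tiny illustration of the resolution order (an explicit flag beats LLAMASTACK_PORT, which beats 5000), with a made-up value:

# Sketch: how the new default is computed when no --port flag is given.
import os

os.environ["LLAMASTACK_PORT"] = "8321"
default_port = int(os.getenv("LLAMASTACK_PORT", 5000))
print(default_port)  # 8321; passing --port on the command line still wins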
parser.add_argument(
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
)
@@ -277,7 +286,8 @@ def main():
config = StackRunConfig(**config)
print("Run configuration:")
print(yaml.dump(config.model_dump(), indent=2))
safe_config = redact_sensitive_fields(config.model_dump())
print(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)


@@ -8,32 +8,31 @@ import logging
import os
import re
from pathlib import Path
from typing import Any, Dict
from typing import Any, Dict, Optional
import pkg_resources
import yaml
from termcolor import colored
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTasks
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
@@ -113,6 +112,26 @@ class EnvVarError(Exception):
)
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _redact_dict(v)
elif isinstance(v, list):
result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
elif any(pattern in k.lower() for pattern in sensitive_patterns):
result[k] = "********"
else:
result[k] = v
return result
return _redact_dict(data)
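A quick illustration of what redact_sensitive_fields() does to a provider config (values are made up):

# Keys containing "api_key", "api_token", "password" or "secret" are masked;
# nested dicts and lists are walked recursively, everything else is unchanged.
config = {
    "providers": {
        "inference": [
            {"provider_id": "together", "config": {"api_key": "tk-123", "url": "https://api.together.xyz"}}
        ]
    }
}
print(redact_sensitive_fields(config))
# -> the api_key value is replaced with "********", the url is left as-is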
def replace_env_vars(config: Any, path: str = "") -> Any:
if isinstance(config, dict):
result = {}


@@ -90,7 +90,6 @@ $DOCKER_BINARY run $DOCKER_OPTS -it \
$env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \
$docker_image:$version_tag \
python -m llama_stack.distribution.server.server \
--yaml-config /app/config.yaml \
--port "$port"
--env LLAMASTACK_PORT=$port \
--entrypoint='["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]' \
$docker_image:$version_tag


@@ -13,11 +13,8 @@ import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import (
KVStore,
kvstore_impl,
SqliteKVStoreConfig,
)
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class DistributionRegistry(Protocol):


@@ -8,11 +8,14 @@ import os
import pytest
import pytest_asyncio
from llama_stack.distribution.store import * # noqa F403
from llama_stack.apis.inference import Model
from llama_stack.apis.memory_banks import VectorMemoryBank
from llama_stack.distribution.store.registry import (
CachedDiskDistributionRegistry,
DiskDistributionRegistry,
)
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from llama_stack.distribution.datatypes import * # noqa F403
@pytest.fixture


@@ -129,7 +129,7 @@ def application_evaluation_page():
# Display current row results using separate containers
progress_text_container.write(
f"Expand to see current processed result ({i+1}/{len(rows)})"
f"Expand to see current processed result ({i + 1} / {len(rows)})"
)
results_container.json(
score_res.to_json(),


@@ -232,7 +232,7 @@ def run_evaluation_3():
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
progress_text_container.write(
f"Expand to see current processed result ({i+1}/{len(rows)})"
f"Expand to see current processed result ({i + 1} / {len(rows)})"
)
results_container.json(eval_res, expanded=2)


@@ -17,6 +17,7 @@ from llama_stack.apis.memory_banks.memory_banks import MemoryBank
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool
@json_schema_type
@@ -29,6 +30,7 @@ class Api(Enum):
scoring = "scoring"
eval = "eval"
post_training = "post_training"
tool_runtime = "tool_runtime"
telemetry = "telemetry"
@@ -38,6 +40,7 @@ class Api(Enum):
datasets = "datasets"
scoring_functions = "scoring_functions"
eval_tasks = "eval_tasks"
tool_groups = "tool_groups"
# built-in API
inspect = "inspect"
@@ -75,6 +78,12 @@ class EvalTasksProtocolPrivate(Protocol):
async def register_eval_task(self, eval_task: EvalTask) -> None: ...
class ToolsProtocolPrivate(Protocol):
async def register_tool(self, tool: Tool) -> None: ...
async def unregister_tool(self, tool_id: str) -> None: ...
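For context, a minimal sketch of a provider satisfying the new ToolsProtocolPrivate; the in-memory storage and class name are illustrative, and typing.Dict plus the Tool type imported above are assumed to be available:

# Sketch only: a toy provider keeping registered tools in a dict; a real
# tool_runtime provider would also implement the public ToolRuntime API.
class InMemoryToolProvider:
    def __init__(self) -> None:
        self._tools: Dict[str, Tool] = {}

    async def register_tool(self, tool: Tool) -> None:
        self._tools[tool.identifier] = tool

    async def unregister_tool(self, tool_id: str) -> None:
        self._tools.pop(tool_id, None)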
@json_schema_type
class ProviderSpec(BaseModel):
api: Api


@@ -13,19 +13,64 @@ import secrets
import string
import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Tuple
from typing import AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventType,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepStartPayload,
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
Attachment,
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
InferenceStep,
MemoryRetrievalStep,
MemoryToolDefinition,
PhotogenToolDefinition,
SearchToolDefinition,
ShieldCallStep,
StepType,
ToolExecutionStep,
Turn,
WolframAlphaToolDefinition,
)
from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
from llama_stack.apis.common.content_types import (
InterleavedContent,
TextContentItem,
URL,
)
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
CompletionMessage,
Inference,
Message,
SamplingParams,
StopReason,
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
from llama_stack.apis.safety import Safety
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
@@ -539,7 +584,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_call = message.tool_calls[0]
name = tool_call.tool_name
if not isinstance(name, BuiltinTool):
if not isinstance(name, BuiltinTool) or name not in enabled_tools:
yield message
return


@@ -9,15 +9,26 @@ import logging
import shutil
import tempfile
import uuid
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional, Union
from termcolor import colored
from llama_stack.apis.inference import Inference
from llama_stack.apis.agents import (
AgentConfig,
AgentCreateResponse,
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentTurnCreateRequest,
Attachment,
Session,
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl


@@ -10,9 +10,11 @@ import uuid
from datetime import datetime
from typing import List, Optional
from llama_stack.apis.agents import * # noqa: F403
from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)


@@ -7,8 +7,6 @@
from typing import List
from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig,
@@ -16,7 +14,7 @@ from llama_stack.apis.agents import (
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import Message, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
@@ -64,7 +62,7 @@ async def llm_rag_query_generator(
model = config.model
message = UserMessage(content=content)
response = await inference_api.chat_completion(
model=model,
model_id=model,
messages=[message],
stream=False,
)


@@ -9,7 +9,9 @@ import logging
from typing import List
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
log = logging.getLogger(__name__)


@@ -8,10 +8,26 @@ from typing import AsyncIterator, List, Optional, Union
import pytest
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.agents import (
AgentConfig,
AgentTurnCreateRequest,
AgentTurnResponseTurnCompletePayload,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseStreamChunk,
CompletionMessage,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
UserMessage,
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import RunShieldResponse
from ..agents import (
AGENT_INSTANCES_BY_ID,


@@ -7,7 +7,7 @@
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.safety import Safety
from ..safety import ShieldRunnerMixin
from .builtin import BaseTool


@@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import * # noqa: F401, F403
from pydantic import BaseModel
class LocalFSDatasetIOConfig(BaseModel): ...


@@ -3,18 +3,19 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
import base64
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url


@@ -3,37 +3,38 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from tqdm import tqdm from tqdm import tqdm
from .....apis.common.job_types import Job from llama_stack.apis.agents import Agents, StepType
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTask from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
get_valid_schemas,
validate_dataset_schema,
)
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from .config import MetaReferenceEvalConfig from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "eval_tasks:" EVAL_TASKS_PREFIX = "eval_tasks:"
class ColumnName(Enum): class MetaReferenceEvalImpl(
input_query = "input_query" Eval,
expected_answer = "expected_answer" EvalTasksProtocolPrivate,
chat_completion_input = "chat_completion_input" ):
completion_input = "completion_input"
generated_answer = "generated_answer"
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
def __init__( def __init__(
self, self,
config: MetaReferenceEvalConfig, config: MetaReferenceEvalConfig,
@@ -77,29 +78,6 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
)
self.eval_tasks[task_def.identifier] = task_def
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
expected_schemas = [
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
},
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.completion_input.value: CompletionInputType(),
},
]
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def run_eval(
self,
task_id: str,
@@ -109,8 +87,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=(
@@ -162,11 +142,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
)
]
final_event = turn_response[-1].event.payload
generations.append(
{
ColumnName.generated_answer.value: final_event.turn.output_message.content
}
)
# check if there's a memory retrieval step and extract the context
memory_rag_context = None
for step in final_event.turn.steps:
if step.step_type == StepType.memory_retrieval.value:
memory_rag_context = " ".join(x.text for x in step.inserted_context)
agent_generation = {}
agent_generation[ColumnName.generated_answer.value] = (
final_event.turn.output_message.content
)
if memory_rag_context:
agent_generation[ColumnName.context.value] = memory_rag_context
generations.append(agent_generation)
return generations


@@ -6,11 +6,10 @@
from typing import Any, Dict, Optional
from llama_models.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, field_validator
from llama_stack.apis.inference import QuantizationConfig
from llama_stack.providers.utils.inference import supported_inference_models


@@ -32,11 +32,16 @@ from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from pydantic import BaseModel
from llama_stack.apis.inference import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
ResponseFormat,
ResponseFormatType,
)
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -44,12 +49,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
CompletionRequestWithRawContent,
)
from .config import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
MetaReferenceInferenceConfig,
MetaReferenceQuantizedInferenceConfig,
)
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
log = logging.getLogger(__name__)


@@ -14,7 +14,10 @@ from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
@@ -27,9 +30,9 @@ class ModelRunner:
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, req: Any):
if isinstance(req, ChatCompletionRequest):
if isinstance(req, ChatCompletionRequestWithRawContent):
return self.llama.chat_completion(req)
elif isinstance(req, CompletionRequest):
elif isinstance(req, CompletionRequestWithRawContent):
return self.llama.completion(req)
else:
raise ValueError(f"Unexpected task type {type(req)}")
@@ -100,7 +103,7 @@ class LlamaModelParallelGenerator:
def completion(
self,
request: CompletionRequest,
request: CompletionRequestWithRawContent,
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
@@ -108,7 +111,7 @@ class LlamaModelParallelGenerator:
def chat_completion(
self,
request: ChatCompletionRequest,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)


@@ -34,7 +34,10 @@ from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .generation import TokenResult
@@ -79,7 +82,7 @@ class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
task: Union[CompletionRequest, ChatCompletionRequest]
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
class TaskResponse(BaseModel):
@@ -264,9 +267,6 @@ def launch_dist_group(
init_model_cb: Callable,
**kwargs,
) -> None:
id = uuid.uuid4().hex
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
with tempfile.TemporaryDirectory() as tmpdir:
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
launch_config = LaunchConfig(
@@ -315,7 +315,7 @@ def start_model_parallel_process(
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
request_socket.send(encode_msg(ReadyRequest()))
response = request_socket.recv()
_response = request_socket.recv()
log.info("Loaded model...")
return request_socket, process
@@ -349,7 +349,10 @@ class ModelParallelProcessGroup:
self.started = False
def run_inference(
self, req: Union[CompletionRequest, ChatCompletionRequest]
self,
req: Union[
CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
],
) -> Generator:
assert not self.running, "inference already running"


@@ -7,10 +7,10 @@
import logging
import os
import uuid
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
@@ -18,9 +18,26 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,


@@ -16,11 +16,14 @@ import faiss
import numpy as np
from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import (


@@ -90,18 +90,24 @@ class TorchtuneCheckpointer:
model_file_path.mkdir(parents=True, exist_ok=True)
# copy the related files for inference
shutil.copy(
Path.joinpath(self._checkpoint_dir, "params.json"),
Path.joinpath(model_file_path, "params.json"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
Path.joinpath(model_file_path, "tokenizer.model"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
Path.joinpath(model_file_path, "orig_params.json"),
)
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "params.json"),
)
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "tokenizer.model"),
)
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "orig_params.json"),
)
if not adapter_only:
model_state_dict = state_dict[training.MODEL_KEY]


@@ -14,14 +14,16 @@ from enum import Enum
from typing import Any, Callable, Dict, List
import torch
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.common.type_system import * # noqa
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.common.type_system import ParamType, StringType
from llama_stack.apis.datasets import Datasets
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
from pydantic import BaseModel
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_1 import lora_llama3_1_8b
from torchtune.models.llama3_2 import lora_llama3_2_3b
@@ -48,8 +50,8 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3_2",
),
"Llama-3-8B-Instruct": ModelConfig(
"Llama3.1-8B-Instruct": ModelConfig(
model_definition=lora_llama3_8b,
model_definition=lora_llama3_1_8b,
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3",
),


@@ -3,11 +3,26 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Any, Dict, List, Optional
from llama_models.schema_utils import webmethod
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
DPOAlignmentConfig,
JobStatus,
LoraFinetuningConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.apis.post_training import * # noqa
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)


@@ -7,6 +7,7 @@
import logging
import os
import time
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@@ -14,27 +15,33 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
LoraFinetuningConfig,
OptimizerConfig,
TrainingConfig,
)
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
)
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
from llama_stack.apis.post_training import * # noqa
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.torchtune.common import utils from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
)
from llama_stack.providers.inline.post_training.torchtune.config import ( from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig, TorchtunePostTrainingConfig,
) )
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
from torch import nn
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training from torchtune import modules, training, utils as torchtune_utils
from torchtune.data import AlpacaToMessages, padded_collate_sft from torchtune.data import AlpacaToMessages, padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.loss import CEWithChunkedOutputLoss
@ -43,11 +50,12 @@ from torchtune.modules.peft import (
get_adapter_state_dict, get_adapter_state_dict,
get_lora_module_names, get_lora_module_names,
get_merged_lora_ckpt, get_merged_lora_ckpt,
load_dora_magnitudes,
set_trainable_params, set_trainable_params,
validate_missing_and_unexpected_for_lora, validate_missing_and_unexpected_for_lora,
) )
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -110,6 +118,10 @@ class LoraFinetuningSingleDevice:
self.checkpoint_dir = config.checkpoint_dir self.checkpoint_dir = config.checkpoint_dir
else: else:
model = resolve_model(self.model_id) model = resolve_model(self.model_id)
if model is None:
raise ValueError(
f"{self.model_id} not found. Your model id should be in the llama models SKU list"
)
self.checkpoint_dir = model_checkpoint_dir(model) self.checkpoint_dir = model_checkpoint_dir(model)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR) self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@ -125,6 +137,7 @@ class LoraFinetuningSingleDevice:
self.global_step = 0 self.global_step = 0
self._gradient_accumulation_steps = training_config.gradient_accumulation_steps self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
self.max_validation_steps = training_config.max_validation_steps
self._clip_grad_norm = 1.0 self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = ( self._enable_activation_checkpointing = (
@ -277,7 +290,6 @@ class LoraFinetuningSingleDevice:
for m in model.modules(): for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"): if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude() m.initialize_dora_magnitude()
load_dora_magnitudes(model)
if lora_weights_state_dict: if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict( lora_missing, lora_unexpected = model.load_state_dict(
lora_weights_state_dict, strict=False lora_weights_state_dict, strict=False
@ -572,7 +584,7 @@ class LoraFinetuningSingleDevice:
log.info("Starting validation...") log.info("Starting validation...")
pbar = tqdm(total=len(self._validation_dataloader)) pbar = tqdm(total=len(self._validation_dataloader))
for idx, batch in enumerate(self._validation_dataloader): for idx, batch in enumerate(self._validation_dataloader):
if idx == 10: if idx == self.max_validation_steps:
break break
torchtune_utils.batch_to_device(batch, self._device) torchtune_utils.batch_to_device(batch, self._device)
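The hunk above replaces the hard-coded cap of 10 validation batches with the new max_validation_steps field read from the training config. A minimal, self-contained sketch of the same pattern, using stand-in losses rather than the provider's real dataloader and loss computation:

    from typing import Iterable

    def run_validation(losses: Iterable[float], max_validation_steps: int) -> float:
        # Average at most `max_validation_steps` validation losses; the cap is
        # configuration-driven instead of the previous hard-coded `if idx == 10`.
        total, seen = 0.0, 0
        for idx, loss in enumerate(losses):
            if idx == max_validation_steps:
                break
            total += loss
            seen += 1
        return total / max(seen, 1)

    print(run_validation([0.9, 0.8, 0.7, 0.6], max_validation_steps=2))  # prints 0.85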

@ -7,8 +7,14 @@
import logging import logging
from typing import Any, Dict, List from typing import Any, Dict, List
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,
) )

@ -9,10 +9,24 @@ import re
from string import Template from string import Template
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.datatypes import CoreModelId
from llama_stack.apis.inference import * # noqa: F403 from llama_models.llama3.api.datatypes import Role
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
Inference,
Message,
UserMessage,
)
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate

@ -11,11 +11,16 @@ import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.apis.inference import Message
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import (
from llama_stack.apis.safety import * # noqa: F403 RunShieldResponse,
from llama_models.llama3.api.datatypes import * # noqa: F403 Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,

@ -3,16 +3,24 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import List from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.scoring import * # noqa: F403 from llama_stack.apis.datasets import Datasets
from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.scoring import (
from llama_stack.apis.common.type_system import * # noqa: F403 ScoreBatchResponse,
from llama_stack.apis.datasetio import * # noqa: F403 ScoreResponse,
from llama_stack.apis.datasets import * # noqa: F403 Scoring,
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
)
from .config import BasicScoringConfig from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
@ -21,7 +29,10 @@ from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn] FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): class BasicScoringImpl(
Scoring,
ScoringFunctionsProtocolPrivate,
):
def __init__( def __init__(
self, self,
config: BasicScoringConfig, config: BasicScoringConfig,
@ -58,30 +69,17 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def register_scoring_function(self, function_def: ScoringFn) -> None: async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet") raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,
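Several providers in this change drop their copy-pasted validate_scoring_input_dataset_schema methods in favor of the shared validate_dataset_schema / get_valid_schemas helpers. A hedged sketch of what such a helper can look like, inferred from the call sites above and from the deleted per-provider check (string-typed input_query, generated_answer and expected_answer columns); the real implementation lives in llama_stack.providers.utils.common.data_schema_validator and may differ:

    from typing import Any, Dict, List

    def validate_dataset_schema(
        dataset_schema: Dict[str, Any],
        valid_schemas: List[Dict[str, Any]],
    ) -> None:
        # Accept the dataset only if its column schema matches one of the schemas
        # the calling API (e.g. Api.scoring) declares as valid.
        if dataset_schema not in valid_schemas:
            raise ValueError(
                f"Dataset schema {dataset_schema} is not one of {valid_schemas}"
            )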

@ -9,12 +9,12 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.equality import equality from .fn_defs.equality import equality
class EqualityScoringFn(BaseScoringFn): class EqualityScoringFn(RegisteredBaseScoringFn):
""" """
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
""" """

@ -9,14 +9,14 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.regex_parser_multiple_choice_answer import ( from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer, regex_parser_multiple_choice_answer,
) )
class RegexParserScoringFn(BaseScoringFn): class RegexParserScoringFn(RegisteredBaseScoringFn):
""" """
A scoring_fn that parses answer from generated response according to context and check match with expected_answer. A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
""" """

@ -8,12 +8,12 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.subset_of import subset_of from .fn_defs.subset_of import subset_of
class SubsetOfScoringFn(BaseScoringFn): class SubsetOfScoringFn(RegisteredBaseScoringFn):
""" """
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise. A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
""" """

@ -3,32 +3,115 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
import os import os
from typing import Any, Dict, List, Optional
from autoevals.llm import Factuality from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness from autoevals.ragas import (
AnswerCorrectness,
AnswerRelevancy,
AnswerSimilarity,
ContextEntityRecall,
ContextPrecision,
ContextRecall,
ContextRelevancy,
Faithfulness,
)
from pydantic import BaseModel
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringResult,
ScoringResultRow,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
validate_row_schema,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from .config import BraintrustScoringConfig from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
from .scoring_fn.fn_defs.answer_similarity import answer_similarity_fn_def
from .scoring_fn.fn_defs.context_entity_recall import context_entity_recall_fn_def
from .scoring_fn.fn_defs.context_precision import context_precision_fn_def
from .scoring_fn.fn_defs.context_recall import context_recall_fn_def
from .scoring_fn.fn_defs.context_relevancy import context_relevancy_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def from .scoring_fn.fn_defs.factuality import factuality_fn_def
from .scoring_fn.fn_defs.faithfulness import faithfulness_fn_def
class BraintrustScoringFnEntry(BaseModel):
identifier: str
evaluator: Any
fn_def: ScoringFn
SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
BraintrustScoringFnEntry(
identifier="braintrust::factuality",
evaluator=Factuality(),
fn_def=factuality_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-correctness",
evaluator=AnswerCorrectness(),
fn_def=answer_correctness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-relevancy",
evaluator=AnswerRelevancy(),
fn_def=answer_relevancy_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-similarity",
evaluator=AnswerSimilarity(),
fn_def=answer_similarity_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::faithfulness",
evaluator=Faithfulness(),
fn_def=faithfulness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-entity-recall",
evaluator=ContextEntityRecall(),
fn_def=context_entity_recall_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-precision",
evaluator=ContextPrecision(),
fn_def=context_precision_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-recall",
evaluator=ContextRecall(),
fn_def=context_recall_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-relevancy",
evaluator=ContextRelevancy(),
fn_def=context_relevancy_fn_def,
),
]
class BraintrustScoringImpl( class BraintrustScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData Scoring,
ScoringFunctionsProtocolPrivate,
NeedsRequestProviderData,
): ):
def __init__( def __init__(
self, self,
@ -41,12 +124,12 @@ class BraintrustScoringImpl(
self.datasets_api = datasets_api self.datasets_api = datasets_api
self.braintrust_evaluators = { self.braintrust_evaluators = {
"braintrust::factuality": Factuality(), entry.identifier: entry.evaluator
"braintrust::answer-correctness": AnswerCorrectness(), for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
} }
self.supported_fn_defs_registry = { self.supported_fn_defs_registry = {
factuality_fn_def.identifier: factuality_fn_def, entry.identifier: entry.fn_def
answer_correctness_fn_def.identifier: answer_correctness_fn_def, for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
} }
async def initialize(self) -> None: ... async def initialize(self) -> None: ...
@ -67,23 +150,6 @@ class BraintrustScoringImpl(
"Registering scoring function not allowed for braintrust provider" "Registering scoring function not allowed for braintrust provider"
) )
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def set_api_key(self) -> None: async def set_api_key(self) -> None:
# api key is in the request headers # api key is in the request headers
if not self.config.openai_api_key: if not self.config.openai_api_key:
@ -99,11 +165,16 @@ class BraintrustScoringImpl(
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: List[str], scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
await self.set_api_key() await self.set_api_key()
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,
@ -123,6 +194,7 @@ class BraintrustScoringImpl(
async def score_row( async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow: ) -> ScoringResultRow:
validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
await self.set_api_key() await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None" assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"] expected_answer = input_row["expected_answer"]
@ -130,12 +202,19 @@ class BraintrustScoringImpl(
input_query = input_row["input_query"] input_query = input_row["input_query"]
evaluator = self.braintrust_evaluators[scoring_fn_identifier] evaluator = self.braintrust_evaluators[scoring_fn_identifier]
result = evaluator(generated_answer, expected_answer, input=input_query) result = evaluator(
generated_answer,
expected_answer,
input=input_query,
context=input_row["context"] if "context" in input_row else None,
)
score = result.score score = result.score
return {"score": score, "metadata": result.metadata} return {"score": score, "metadata": result.metadata}
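With the change above, score_row forwards an optional context column to the Braintrust evaluator. An illustrative input row (field names follow the code above; the values are invented, and context matters only for the retrieval-oriented scorers such as faithfulness and the context-* functions):

    input_row = {
        "input_query": "What is the capital of France?",
        "generated_answer": "Paris",
        "expected_answer": "Paris",
        # Optional; passed through as `context=` when present.
        "context": "France's capital and largest city is Paris.",
    }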
async def score( async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse: ) -> ScoreResponse:
await self.set_api_key() await self.set_api_key()
res = {} res = {}
@ -147,8 +226,17 @@ class BraintrustScoringImpl(
await self.score_row(input_row, scoring_fn_id) await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows for input_row in input_rows
] ]
aggregation_functions = [AggregationFunctionType.average] aggregation_functions = self.supported_fn_defs_registry[
agg_results = aggregate_average(score_results) scoring_fn_id
].params.aggregation_functions
# override scoring_fn params if provided
if scoring_functions[scoring_fn_id] is not None:
override_params = scoring_functions[scoring_fn_id]
if override_params.aggregation_functions:
aggregation_functions = override_params.aggregation_functions
agg_results = aggregate_metrics(score_results, aggregation_functions)
res[scoring_fn_id] = ScoringResult( res[scoring_fn_id] = ScoringResult(
score_rows=score_results, score_rows=score_results,
aggregated_results=agg_results, aggregated_results=agg_results,
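score() now looks up the aggregation functions registered on each scoring-function definition, lets callers override them through the per-function ScoringFnParams, and hands them to aggregate_metrics instead of the previous hard-wired aggregate_average. A rough sketch of the average case only; the real aggregate_metrics helper dispatches on AggregationFunctionType and may be implemented differently:

    from typing import Any, Dict, List

    def aggregate_average(score_rows: List[Dict[str, Any]]) -> Dict[str, Any]:
        # One aggregation strategy: the mean of the per-row "score" values.
        scores = [row["score"] for row in score_rows if row.get("score") is not None]
        return {"average": sum(scores) / len(scores) if scores else 0.0}

    print(aggregate_average([{"score": 1.0}, {"score": 0.0}]))  # {'average': 0.5}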

@ -3,7 +3,9 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.scoring import * # noqa: F401, F403 from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
class BraintrustScoringConfig(BaseModel): class BraintrustScoringConfig(BaseModel):

@ -5,14 +5,23 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_correctness_fn_def = ScoringFn( answer_correctness_fn_def = ScoringFn(
identifier="braintrust::answer-correctness", identifier="braintrust::answer-correctness",
description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py", description=(
params=None, "Scores the correctness of the answer based on the ground truth. "
"Uses Braintrust LLM-based scorer from autoevals library."
),
provider_id="braintrust", provider_id="braintrust",
provider_resource_id="answer-correctness", provider_resource_id="answer-correctness",
return_type=NumberType(), return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
) )

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_relevancy_fn_def = ScoringFn(
identifier="braintrust::answer-relevancy",
description=(
"Test output relevancy against the input query using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="answer-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_similarity_fn_def = ScoringFn(
identifier="braintrust::answer-similarity",
description=(
"Test output similarity against expected value using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="answer-similarity",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_entity_recall_fn_def = ScoringFn(
identifier="braintrust::context-entity-recall",
description=(
"Evaluates how well the context captures the named entities present in the "
"reference answer. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-entity-recall",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_precision_fn_def = ScoringFn(
identifier="braintrust::context-precision",
description=(
"Measures how much of the provided context is actually relevant to answering the "
"question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-precision",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_recall_fn_def = ScoringFn(
identifier="braintrust::context-recall",
description=(
"Evaluates how well the context covers the information needed to answer the "
"question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-recall",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_relevancy_fn_def = ScoringFn(
identifier="braintrust::context-relevancy",
description=(
"Assesses how relevant the provided context is to the given question. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

@ -5,14 +5,23 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
factuality_fn_def = ScoringFn( factuality_fn_def = ScoringFn(
identifier="braintrust::factuality", identifier="braintrust::factuality",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py", description=(
params=None, "Test output factuality against expected value using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust", provider_id="braintrust",
provider_resource_id="factuality", provider_resource_id="factuality",
return_type=NumberType(), return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
) )

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
faithfulness_fn_def = ScoringFn(
identifier="braintrust::faithfulness",
description=(
"Test output faithfulness to the input query using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="faithfulness",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

@ -16,7 +16,12 @@ from llama_stack.apis.scoring import (
ScoringResult, ScoringResult,
) )
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
)
from .config import LlmAsJudgeScoringConfig from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
@ -25,7 +30,10 @@ from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): class LlmAsJudgeScoringImpl(
Scoring,
ScoringFunctionsProtocolPrivate,
):
def __init__( def __init__(
self, self,
config: LlmAsJudgeScoringConfig, config: LlmAsJudgeScoringConfig,
@ -65,30 +73,17 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def register_scoring_function(self, function_def: ScoringFn) -> None: async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet") raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,

@ -12,14 +12,14 @@ from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
from .fn_defs.llm_as_judge_base import llm_as_judge_base from .fn_defs.llm_as_judge_base import llm_as_judge_base
class LlmAsJudgeScoringFn(BaseScoringFn): class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
""" """
A scoring_fn that assigns A scoring_fn that assigns
""" """

@ -17,6 +17,22 @@ from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.apis.telemetry import (
Event,
MetricEvent,
QueryCondition,
SpanEndPayload,
SpanStartPayload,
SpanStatus,
SpanWithStatus,
StructuredLogEvent,
Telemetry,
Trace,
UnstructuredLogEvent,
)
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor, ConsoleSpanProcessor,
) )
@ -27,10 +43,6 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from .config import TelemetryConfig, TelemetrySink from .config import TelemetryConfig, TelemetrySink
_GLOBAL_STORAGE = { _GLOBAL_STORAGE = {
@ -100,8 +112,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
async def shutdown(self) -> None: async def shutdown(self) -> None:
trace.get_tracer_provider().force_flush() trace.get_tracer_provider().force_flush()
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent): if isinstance(event, UnstructuredLogEvent):

@ -4,12 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.telemetry import Telemetry
from .config import SampleConfig from .config import SampleConfig
from llama_stack.apis.telemetry import * # noqa: F403
class SampleTelemetryImpl(Telemetry): class SampleTelemetryImpl(Telemetry):
def __init__(self, config: SampleConfig): def __init__(self, config: SampleConfig):
self.config = config self.config = config

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .brave_search import BraveSearchToolRuntimeImpl
from .config import BraveSearchToolConfig
class BraveSearchToolProviderDataValidator(BaseModel):
api_key: str
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
impl = BraveSearchToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
import requests
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BraveSearchToolConfig
class BraveSearchToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: BraveSearchToolConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
if tool.identifier != "brave_search":
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
async def unregister_tool(self, tool_id: str) -> None:
return
def _get_api_key(self) -> str:
if self.config.api_key:
return self.config.api_key
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
raise ValueError(
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
)
return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
raise NotImplementedError("Brave search tool group not supported")
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": api_key,
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": args["query"]}
response = requests.get(url=url, params=payload, headers=headers)
response.raise_for_status()
results = self._clean_brave_response(response.json())
content_items = "\n".join([str(result) for result in results])
return ToolInvocationResult(
content=content_items,
)
def _clean_brave_response(self, search_response):
clean_response = []
if "mixed" in search_response:
mixed_results = search_response["mixed"]
for m in mixed_results["main"][: self.config.max_results]:
r_type = m["type"]
results = search_response[r_type]["results"]
cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
clean_response.append(cleaned)
return clean_response
def _clean_result_by_type(self, r_type, results, idx=None):
type_cleaners = {
"web": (
["type", "title", "url", "description", "date", "extra_snippets"],
lambda x: x[idx],
),
"faq": (["type", "question", "answer", "title", "url"], lambda x: x),
"infobox": (
["type", "title", "url", "description", "long_desc"],
lambda x: x[idx],
),
"videos": (["type", "url", "title", "description", "date"], lambda x: x),
"locations": (
[
"type",
"title",
"url",
"description",
"coordinates",
"postal_address",
"contact",
"rating",
"distance",
"zoom_level",
],
lambda x: x,
),
"news": (["type", "title", "url", "description"], lambda x: x),
}
if r_type not in type_cleaners:
return ""
selected_keys, result_selector = type_cleaners[r_type]
results = result_selector(results)
if isinstance(results, list):
cleaned = [
{k: v for k, v in item.items() if k in selected_keys}
for item in results
]
else:
cleaned = {k: v for k, v in results.items() if k in selected_keys}
return str(cleaned)

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel, Field
class BraveSearchToolConfig(BaseModel):
api_key: Optional[str] = Field(
default=None,
description="The Brave Search API Key",
)
max_results: int = Field(
default=3,
description="The maximum number of results to return",
)
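Taken together, the two new files above add a Brave Search tool runtime and its config. A hedged usage sketch: the imports of BraveSearchToolRuntimeImpl and BraveSearchToolConfig are omitted because the new files' paths are not visible in this view, the API key is a placeholder, and invoke_tool issues a live request to api.search.brave.com:

    import asyncio

    async def main() -> None:
        config = BraveSearchToolConfig(api_key="<your-brave-api-key>", max_results=3)
        runtime = BraveSearchToolRuntimeImpl(config)
        await runtime.initialize()
        result = await runtime.invoke_tool("brave_search", {"query": "llama stack"})
        print(result.content)

    asyncio.run(main())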

@ -6,7 +6,13 @@
from typing import List from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
from llama_stack.providers.utils.kvstore import kvstore_dependencies from llama_stack.providers.utils.kvstore import kvstore_dependencies

@ -6,7 +6,13 @@
from typing import List from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
def available_providers() -> List[ProviderSpec]: def available_providers() -> List[ProviderSpec]:

@ -6,7 +6,7 @@
from typing import List from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> List[ProviderSpec]: def available_providers() -> List[ProviderSpec]:

@ -6,8 +6,13 @@
from typing import List from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
META_REFERENCE_DEPS = [ META_REFERENCE_DEPS = [
"accelerate", "accelerate",
@ -149,6 +154,16 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
), ),
), ),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq",
pip_packages=["groq"],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
),
),
remote_provider_spec( remote_provider_spec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(

@ -6,8 +6,13 @@
from typing import List from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
EMBEDDING_DEPS = [ EMBEDDING_DEPS = [
"blobfile", "blobfile",

@ -6,7 +6,7 @@
from typing import List from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> List[ProviderSpec]: def available_providers() -> List[ProviderSpec]:
