mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
Merge branch 'main' into vllm
This commit is contained in:
commit
73fede90a6
175 changed files with 7948 additions and 876 deletions
2
.github/CODEOWNERS
vendored
2
.github/CODEOWNERS
vendored
|
@ -2,4 +2,4 @@
|
|||
|
||||
# These owners will be the default owners for everything in
|
||||
# the repo. Unless a later match takes precedence,
|
||||
* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv
|
||||
* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721
|
||||
|
|
|
@ -84,6 +84,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
|
|||
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
|
||||
| Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
|
||||
| Groq | Hosted | | :heavy_check_mark: | | | |
|
||||
| Ollama | Single Node | | :heavy_check_mark: | | | |
|
||||
| TGI | Hosted and Single Node | | :heavy_check_mark: | | | |
|
||||
| [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | | |
|
||||
|
@ -127,7 +128,7 @@ You have two ways to install this repository:
|
|||
conda activate stack
|
||||
|
||||
cd llama-stack
|
||||
$CONDA_PREFIX/bin/pip install -e .
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
{
|
||||
"bedrock": [
|
||||
"hf-serverless": [
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"boto3",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
|
@ -11,6 +11,100 @@
|
|||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"huggingface_hub",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"together": [
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"together",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"vllm-gpu": [
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"vllm",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"remote-vllm": [
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
|
@ -63,7 +157,7 @@
|
|||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"hf-endpoint": [
|
||||
"tgi": [
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
|
@ -96,11 +190,11 @@
|
|||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"hf-serverless": [
|
||||
"aiohttp",
|
||||
"bedrock": [
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"boto3",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
|
@ -108,7 +202,6 @@
|
|||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"huggingface_hub",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
|
@ -207,6 +300,34 @@
|
|||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"cerebras": [
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"cerebras_cloud_sdk",
|
||||
"chardet",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"ollama": [
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
|
@ -240,7 +361,7 @@
|
|||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"tgi": [
|
||||
"hf-endpoint": [
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
|
@ -272,126 +393,5 @@
|
|||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"together": [
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"together",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"remote-vllm": [
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"vllm-gpu": [
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"vllm",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"cerebras": [
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"cerebras_cloud_sdk",
|
||||
"chardet",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
]
|
||||
}
|
||||
|
|
4636
docs/getting_started.ipynb
Normal file
4636
docs/getting_started.ipynb
Normal file
File diff suppressed because one or more lines are too long
|
@ -544,7 +544,7 @@
|
|||
" provider_type: inline::meta-reference\n",
|
||||
" inference:\n",
|
||||
" - config:\n",
|
||||
" api_key: 4985b03e627419b2964d34b8519ac6c4319f094d1ffb4f45514b4eb87e5427a2\n",
|
||||
" api_key: <...>\n",
|
||||
" url: <span style=\"color: #0000ff; text-decoration-color: #0000ff; text-decoration: underline\">https://api.together.xyz/v1</span>\n",
|
||||
" provider_id: together\n",
|
||||
" provider_type: remote::together\n",
|
||||
|
@ -663,7 +663,7 @@
|
|||
" provider_type: inline::meta-reference\n",
|
||||
" inference:\n",
|
||||
" - config:\n",
|
||||
" api_key: 4985b03e627419b2964d34b8519ac6c4319f094d1ffb4f45514b4eb87e5427a2\n",
|
||||
" api_key: <...>\n",
|
||||
" url: \u001b[4;94mhttps://api.together.xyz/v1\u001b[0m\n",
|
||||
" provider_id: together\n",
|
||||
" provider_type: remote::together\n",
|
||||
|
|
|
@ -338,8 +338,8 @@ distribution_spec:
|
|||
inference: remote::ollama
|
||||
memory: inline::faiss
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
agents: inline::meta-reference
|
||||
telemetry: inline::meta-reference
|
||||
image_type: conda
|
||||
```
|
||||
|
||||
|
|
|
@ -8,10 +8,6 @@ building_distro
|
|||
configuration
|
||||
```
|
||||
|
||||
<!-- self_hosted_distro/index -->
|
||||
<!-- remote_hosted_distro/index -->
|
||||
<!-- ondevice_distro/index -->
|
||||
|
||||
You can instantiate a Llama Stack in one of the following ways:
|
||||
- **As a Library**: this is the simplest, especially if you are using an external inference service. See [Using Llama Stack as a Library](importing_as_library)
|
||||
- **Docker**: we provide a number of pre-built Docker containers so you can start a Llama Stack server instantly. You can also build your own custom Docker container.
|
||||
|
@ -30,11 +26,15 @@ If so, we suggest:
|
|||
- {dockerhub}`distribution-ollama` ([Guide](self_hosted_distro/ollama))
|
||||
|
||||
- **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest:
|
||||
- {dockerhub}`distribution-together` ([Guide](remote_hosted_distro/index))
|
||||
- {dockerhub}`distribution-fireworks` ([Guide](remote_hosted_distro/index))
|
||||
- {dockerhub}`distribution-together` ([Guide](self_hosted_distro/together))
|
||||
- {dockerhub}`distribution-fireworks` ([Guide](self_hosted_distro/fireworks))
|
||||
|
||||
- **Do you want to run Llama Stack inference on your iOS / Android device?** If so, we suggest:
|
||||
- [iOS SDK](ondevice_distro/ios_sdk)
|
||||
- [Android](ondevice_distro/android_sdk)
|
||||
|
||||
- **Do you want a hosted Llama Stack endpoint?** If so, we suggest:
|
||||
- [Remote-Hosted Llama Stack Endpoints](remote_hosted_distro/index)
|
||||
|
||||
|
||||
You can also build your own [custom distribution](building_distro).
|
||||
|
|
|
@ -42,6 +42,7 @@ The following models are available by default:
|
|||
- `meta-llama/Llama-3.2-3B-Instruct (fireworks/llama-v3p2-3b-instruct)`
|
||||
- `meta-llama/Llama-3.2-11B-Vision-Instruct (fireworks/llama-v3p2-11b-vision-instruct)`
|
||||
- `meta-llama/Llama-3.2-90B-Vision-Instruct (fireworks/llama-v3p2-90b-vision-instruct)`
|
||||
- `meta-llama/Llama-3.3-70B-Instruct (fireworks/llama-v3p3-70b-instruct)`
|
||||
- `meta-llama/Llama-Guard-3-8B (fireworks/llama-guard-3-8b)`
|
||||
- `meta-llama/Llama-Guard-3-11B-Vision (fireworks/llama-guard-3-11b-vision)`
|
||||
|
||||
|
|
|
@ -41,6 +41,7 @@ The following models are available by default:
|
|||
- `meta-llama/Llama-3.2-3B-Instruct`
|
||||
- `meta-llama/Llama-3.2-11B-Vision-Instruct`
|
||||
- `meta-llama/Llama-3.2-90B-Vision-Instruct`
|
||||
- `meta-llama/Llama-3.3-70B-Instruct`
|
||||
- `meta-llama/Llama-Guard-3-8B`
|
||||
- `meta-llama/Llama-Guard-3-11B-Vision`
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ Configuration for this is available at `distributions/ollama/run.yaml`.
|
|||
|
||||
### 3. Use the Llama Stack client SDK
|
||||
|
||||
You can interact with the Llama Stack server using various client SDKs. We will use the Python SDK which you can install using:
|
||||
You can interact with the Llama Stack server using various client SDKs. We will use the Python SDK which you can install using the following command. Note that you must be using Python 3.10 or newer:
|
||||
```bash
|
||||
pip install llama-stack-client
|
||||
```
|
||||
|
@ -51,7 +51,8 @@ pip install llama-stack-client
|
|||
Let's use the `llama-stack-client` CLI to check the connectivity to the server.
|
||||
|
||||
```bash
|
||||
llama-stack-client --endpoint http://localhost:$LLAMA_STACK_PORT models list
|
||||
llama-stack-client configure --endpoint http://localhost:$LLAMA_STACK_PORT
|
||||
llama-stack-client models list
|
||||
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
|
||||
┃ identifier ┃ provider_id ┃ provider_resource_id ┃ metadata ┃
|
||||
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
|
||||
|
@ -61,7 +62,7 @@ llama-stack-client --endpoint http://localhost:$LLAMA_STACK_PORT models list
|
|||
|
||||
You can test basic Llama inference completion using the CLI too.
|
||||
```bash
|
||||
llama-stack-client --endpoint http://localhost:$LLAMA_STACK_PORT \
|
||||
llama-stack-client \
|
||||
inference chat-completion \
|
||||
--message "hello, what model are you?"
|
||||
```
|
||||
|
@ -153,10 +154,3 @@ if __name__ == "__main__":
|
|||
- Learn how to [Build Llama Stacks](../distributions/index.md)
|
||||
- See [References](../references/index.md) for more details about the llama CLI and Python SDK
|
||||
- For example applications and more detailed tutorials, visit our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository.
|
||||
|
||||
|
||||
## Thinking out aloud here in terms of what to write in the docs
|
||||
|
||||
- how to get a llama stack server running
|
||||
- what are all the different client sdks
|
||||
- what are the components of building agents
|
||||
|
|
|
@ -16,7 +16,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
|
|||
|
||||
##### Chatbot
|
||||
```{eval-rst}
|
||||
.. video:: https://github.com/user-attachments/assets/6ca617e8-32ca-49b2-9774-185020ff5204
|
||||
.. video:: https://github.com/user-attachments/assets/8d2ef802-5812-4a28-96e1-316038c84cbf
|
||||
:autoplay:
|
||||
:playsinline:
|
||||
:muted:
|
||||
|
|
|
@ -47,7 +47,7 @@ This first example walks you through how to evaluate a model candidate served by
|
|||
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): Benchmark designed to assess models' ability to answer short, fact-seeking questions.
|
||||
|
||||
#### 1.1 Running MMMU
|
||||
- We will use a pre-processed MMMU dataset from [llamastack/mmmu](https://huggingface.co/datasets/llamastack/mmmu). The preprocessing code is shown in in this [Github Gist](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840). The dataset is obtained by transforming the original [MMMU/MMMU](https://huggingface.co/datasets/MMMU/MMMU) dataset into correct format by `inference/chat-completion` API.
|
||||
- We will use a pre-processed MMMU dataset from [llamastack/mmmu](https://huggingface.co/datasets/llamastack/mmmu). The preprocessing code is shown in this [GitHub Gist](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840). The dataset is obtained by transforming the original [MMMU/MMMU](https://huggingface.co/datasets/MMMU/MMMU) dataset into correct format by `inference/chat-completion` API.
|
||||
|
||||
```python
|
||||
import datasets
|
||||
|
|
|
@ -358,7 +358,7 @@
|
|||
" if not stream:\n",
|
||||
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n",
|
||||
" else:\n",
|
||||
" async for log in EventLogger().log(response):\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" log.print()\n",
|
||||
"\n",
|
||||
"# In a Jupyter Notebook cell, use `await` to call the function\n",
|
||||
|
@ -366,16 +366,6 @@
|
|||
"# To run it in a python file, use this line instead\n",
|
||||
"# asyncio.run(run_main())\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "9399aecc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#fin"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
@ -67,7 +67,7 @@
|
|||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"from llama_stack.distribution.datatypes import RemoteProviderConfig\n",
|
||||
"from llama_stack.apis.safety import * # noqa: F403\n",
|
||||
"from llama_stack.apis.safety import Safety\n",
|
||||
"from llama_stack_client import LlamaStackClient\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
@ -127,7 +127,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.15"
|
||||
"version": "3.11.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -45,7 +45,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
|||
|
||||
---
|
||||
|
||||
## Install Dependencies and Set Up Environment
|
||||
## Install Dependencies and Set Up Environment
|
||||
|
||||
1. **Create a Conda Environment**:
|
||||
Create a new Conda environment with Python 3.10:
|
||||
|
@ -73,7 +73,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
|||
Open a new terminal and install `llama-stack`:
|
||||
```bash
|
||||
conda activate ollama
|
||||
pip install llama-stack==0.0.55
|
||||
pip install llama-stack==0.0.61
|
||||
```
|
||||
|
||||
---
|
||||
|
@ -96,7 +96,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
|||
3. **Set the ENV variables by exporting them to the terminal**:
|
||||
```bash
|
||||
export OLLAMA_URL="http://localhost:11434"
|
||||
export LLAMA_STACK_PORT=5051
|
||||
export LLAMA_STACK_PORT=5001
|
||||
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
|
||||
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
|
||||
```
|
||||
|
@ -104,34 +104,29 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
|||
3. **Run the Llama Stack**:
|
||||
Run the stack with command shared by the API from earlier:
|
||||
```bash
|
||||
llama stack run ollama \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
llama stack run ollama \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
--env OLLAMA_URL=$OLLAMA_URL
|
||||
```
|
||||
Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
|
||||
|
||||
The server will start and listen on `http://localhost:5051`.
|
||||
The server will start and listen on `http://localhost:5001`.
|
||||
|
||||
---
|
||||
## Test with `llama-stack-client` CLI
|
||||
After setting up the server, open a new terminal window and install the llama-stack-client package.
|
||||
After setting up the server, open a new terminal window and configure the llama-stack-client.
|
||||
|
||||
1. Install the llama-stack-client package
|
||||
1. Configure the CLI to point to the llama-stack server.
|
||||
```bash
|
||||
conda activate ollama
|
||||
pip install llama-stack-client
|
||||
```
|
||||
2. Configure the CLI to point to the llama-stack server.
|
||||
```bash
|
||||
llama-stack-client configure --endpoint http://localhost:5051
|
||||
llama-stack-client configure --endpoint http://localhost:5001
|
||||
```
|
||||
**Expected Output:**
|
||||
```bash
|
||||
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5051
|
||||
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5001
|
||||
```
|
||||
3. Test the CLI by running inference:
|
||||
2. Test the CLI by running inference:
|
||||
```bash
|
||||
llama-stack-client inference chat-completion --message "Write me a 2-sentence poem about the moon"
|
||||
```
|
||||
|
@ -153,16 +148,18 @@ After setting up the server, open a new terminal window and install the llama-st
|
|||
After setting up the server, open a new terminal window and verify it's working by sending a `POST` request using `curl`:
|
||||
|
||||
```bash
|
||||
curl http://localhost:$LLAMA_STACK_PORT/inference/chat_completion \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "Llama3.2-3B-Instruct",
|
||||
curl http://localhost:$LLAMA_STACK_PORT/alpha/inference/chat-completion \
|
||||
-H "Content-Type: application/json" \
|
||||
-d @- <<EOF
|
||||
{
|
||||
"model_id": "$INFERENCE_MODEL",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Write me a 2-sentence poem about the moon"}
|
||||
],
|
||||
"sampling_params": {"temperature": 0.7, "seed": 42, "max_tokens": 512}
|
||||
}'
|
||||
}
|
||||
EOF
|
||||
```
|
||||
|
||||
You can check the available models with the command `llama-stack-client models list`.
|
||||
|
@ -186,16 +183,12 @@ You can check the available models with the command `llama-stack-client models l
|
|||
|
||||
You can also interact with the Llama Stack server using a simple Python script. Below is an example:
|
||||
|
||||
### 1. Activate Conda Environment and Install Required Python Packages
|
||||
The `llama-stack-client` library offers a robust and efficient python methods for interacting with the Llama Stack server.
|
||||
### 1. Activate Conda Environment
|
||||
|
||||
```bash
|
||||
conda activate ollama
|
||||
pip install llama-stack-client
|
||||
```
|
||||
|
||||
Note, the client library gets installed by default if you install the server library
|
||||
|
||||
### 2. Create Python Script (`test_llama_stack.py`)
|
||||
```bash
|
||||
touch test_llama_stack.py
|
||||
|
@ -206,19 +199,28 @@ touch test_llama_stack.py
|
|||
In `test_llama_stack.py`, write the following code:
|
||||
|
||||
```python
|
||||
from llama_stack_client import LlamaStackClient
|
||||
import os
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
# Initialize the client
|
||||
client = LlamaStackClient(base_url="http://localhost:5051")
|
||||
# Get the model ID from the environment variable
|
||||
INFERENCE_MODEL = os.environ.get("INFERENCE_MODEL")
|
||||
|
||||
# Create a chat completion request
|
||||
# Check if the environment variable is set
|
||||
if INFERENCE_MODEL is None:
|
||||
raise ValueError("The environment variable 'INFERENCE_MODEL' is not set.")
|
||||
|
||||
# Initialize the client
|
||||
client = LlamaStackClient(base_url="http://localhost:5001")
|
||||
|
||||
# Create a chat completion request
|
||||
response = client.inference.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a friendly assistant."},
|
||||
{"role": "user", "content": "Write a two-sentence poem about llama."}
|
||||
],
|
||||
model_id=MODEL_NAME,
|
||||
model_id=INFERENCE_MODEL,
|
||||
)
|
||||
|
||||
# Print the response
|
||||
print(response.completion_message.content)
|
||||
```
|
||||
|
|
|
@ -18,18 +18,30 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.common.deployment_types import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
SamplingParams,
|
||||
ToolCall,
|
||||
ToolCallDelta,
|
||||
ToolChoice,
|
||||
ToolPromptFormat,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.memory import MemoryBank
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||
from llama_models.llama3.api.tool_utils import ToolUtils
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
|
||||
|
||||
from llama_stack.apis.inference import ToolResponseMessage
|
||||
|
||||
|
||||
class LogEvent:
|
||||
def __init__(
|
||||
|
|
|
@ -10,8 +10,16 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -4,11 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, field_serializer, model_validator
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -27,6 +28,12 @@ class _URLOrData(BaseModel):
|
|||
return values
|
||||
return {"url": values}
|
||||
|
||||
@field_serializer("data")
|
||||
def serialize_data(self, data: Optional[bytes], _info):
|
||||
if data is None:
|
||||
return None
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ImageContentItem(_URLOrData):
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
|||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -4,18 +4,19 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Optional, Protocol, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.agents import AgentConfig
|
||||
from llama_stack.apis.common.job_types import Job, JobStatus
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
from llama_stack.apis.scoring import ScoringResult
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -7,7 +7,9 @@
|
|||
from enum import Enum
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
|
@ -32,8 +34,9 @@ from typing_extensions import Annotated
|
|||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
|
||||
class LogProbConfig(BaseModel):
|
||||
|
|
|
@ -7,17 +7,17 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.common.training_types import * # noqa: F403
|
||||
from llama_stack.apis.common.training_types import Checkpoint
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -58,6 +58,7 @@ class TrainingConfig(BaseModel):
|
|||
n_epochs: int
|
||||
max_steps_per_epoch: int
|
||||
gradient_accumulation_steps: int
|
||||
max_validation_steps: int
|
||||
data_config: DataConfig
|
||||
optimizer_config: OptimizerConfig
|
||||
efficiency_config: Optional[EfficiencyConfig] = None
|
||||
|
|
|
@ -18,6 +18,8 @@ class ResourceType(Enum):
|
|||
dataset = "dataset"
|
||||
scoring_function = "scoring_function"
|
||||
eval_task = "eval_task"
|
||||
tool = "tool"
|
||||
tool_group = "tool_group"
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
|
|
|
@ -4,13 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
|
||||
|
||||
# mapping of metric to value
|
||||
|
@ -48,7 +47,7 @@ class Scoring(Protocol):
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse: ...
|
||||
|
||||
|
@ -56,5 +55,5 @@ class Scoring(Protocol):
|
|||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||
) -> ScoreResponse: ...
|
||||
|
|
|
@ -6,13 +6,12 @@
|
|||
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message
|
||||
|
||||
|
||||
|
|
7
llama_stack/apis/tools/__init__.py
Normal file
7
llama_stack/apis/tools/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .tools import * # noqa: F401 F403
|
141
llama_stack/apis/tools/tools.py
Normal file
141
llama_stack/apis/tools/tools.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
class ToolParameter(BaseModel):
    """Describes one named parameter in a tool's parameter list."""

    # Name of the parameter.
    name: str
    # Type of the parameter, carried as a plain string.
    parameter_type: str
    # Free-text description of the parameter.
    description: str
|
||||
|
||||
|
||||
@json_schema_type
class Tool(Resource):
    """A Resource whose ``type`` is fixed to the ``tool`` resource type."""

    # Fixed tag identifying this resource as a tool.
    type: Literal[ResourceType.tool.value] = ResourceType.tool.value
    # Name/identifier of the tool group this tool is associated with.
    tool_group: str
    description: str
    parameters: List[ToolParameter]
    provider_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    # Defaults to the JSON tool prompt format when not supplied.
    tool_prompt_format: Optional[ToolPromptFormat] = Field(
        default=ToolPromptFormat.json,
    )
|
||||
|
||||
|
||||
@json_schema_type
class ToolDef(BaseModel):
    """Definition of a tool: its name, description, parameters and metadata."""

    name: str
    description: str
    parameters: List[ToolParameter]
    # Required (no default), unlike the optional ``metadata`` on Tool.
    metadata: Dict[str, Any]
    # Defaults to the JSON tool prompt format when not supplied.
    tool_prompt_format: Optional[ToolPromptFormat] = Field(
        default=ToolPromptFormat.json,
    )
|
||||
|
||||
|
||||
@json_schema_type
class MCPToolGroupDef(BaseModel):
    """
    A tool group that is defined by a Model Context Protocol (MCP) server.
    Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
    """

    # Fixed tag that distinguishes this variant from other tool group definitions.
    type: Literal["model_context_protocol"] = "model_context_protocol"
    # Location of the server that defines the tool group.
    endpoint: URL
|
||||
|
||||
|
||||
@json_schema_type
class UserDefinedToolGroupDef(BaseModel):
    """A tool group given directly as an explicit list of tool definitions."""

    # Fixed tag that distinguishes this variant from other tool group definitions.
    type: Literal["user_defined"] = "user_defined"
    # The tool definitions that make up the group.
    tools: List[ToolDef]
|
||||
|
||||
|
||||
# Union of the supported tool group definitions; the "type" field of each
# member is the discriminator that selects the concrete model.
# NOTE(review): the schema is registered under the name "ToolGroup", which is
# also the name of the ToolGroup resource class in this module -- confirm the
# shared name is intended.
ToolGroupDef = register_schema(
    Annotated[
        Union[MCPToolGroupDef, UserDefinedToolGroupDef],
        Field(discriminator="type"),
    ],
    name="ToolGroup",
)
|
||||
|
||||
|
||||
class ToolGroup(Resource):
    """A Resource whose ``type`` is fixed to the ``tool_group`` resource type."""

    # Fixed tag identifying this resource as a tool group.
    type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
|
||||
|
||||
|
||||
@json_schema_type
class ToolInvocationResult(BaseModel):
    """Outcome of a tool invocation: content plus optional error details."""

    # Content produced by the invocation.
    content: InterleavedContent
    # Optional error details; both default to None.
    error_message: Optional[str] = None
    error_code: Optional[int] = None
|
||||
|
||||
|
||||
class ToolStore(Protocol):
    """Structural interface for anything that can look up a Tool by name."""

    def get_tool(self, tool_name: str) -> Tool:
        """Return the Tool identified by ``tool_name``."""
        ...
|
||||
|
||||
|
||||
@runtime_checkable
@trace_protocol
class ToolGroups(Protocol):
    """Protocol for registering, inspecting and unregistering tool groups and
    for listing and fetching the tools that belong to them."""

    @webmethod(route="/toolgroups/register", method="POST")
    async def register_tool_group(
        self,
        tool_group_id: str,
        tool_group: ToolGroupDef,
        provider_id: Optional[str] = None,
    ) -> None:
        """Register a tool group under ``tool_group_id``, optionally naming a provider"""
        ...

    @webmethod(route="/toolgroups/get", method="GET")
    async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
        """Get the tool group identified by ``tool_group_id``"""
        ...

    @webmethod(route="/toolgroups/list", method="GET")
    async def list_tool_groups(self) -> List[ToolGroup]:
        """List tool groups"""
        # The method takes no filter argument; the earlier wording
        # "with optional provider" did not match the signature.
        ...

    @webmethod(route="/tools/list", method="GET")
    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
        """List tools with optional tool group"""
        ...

    @webmethod(route="/tools/get", method="GET")
    async def get_tool(self, tool_name: str) -> Tool:
        """Get the tool identified by ``tool_name``"""
        ...

    @webmethod(route="/toolgroups/unregister", method="POST")
    async def unregister_tool_group(self, tool_group_id: str) -> None:
        """Unregister a tool group"""
        ...
|
||||
|
||||
|
||||
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
    """Protocol for components that discover the tools of a tool group and
    invoke tools by name."""

    # Store through which registered tools can be looked up by name.
    tool_store: ToolStore

    @webmethod(route="/tool-runtime/discover", method="POST")
    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
        """Return the tool definitions for the given tool group definition"""
        ...

    @webmethod(route="/tool-runtime/invoke", method="POST")
    async def invoke_tool(
        self,
        tool_name: str,
        args: Dict[str, Any],
    ) -> ToolInvocationResult:
        """Run a tool with the given arguments"""
        ...
|
|
@ -6,11 +6,12 @@
|
|||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_models.datatypes import * # noqa: F403
|
||||
from llama_models.datatypes import CheckpointQuantizationFormat
|
||||
from llama_models.llama3.api.datatypes import SamplingParams
|
||||
from llama_models.sku_list import LlamaDownloadInfo
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class PromptGuardModel(BaseModel):
|
||||
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
|
||||
|
|
|
@ -3,21 +3,28 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
import os
|
||||
import shutil
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildConfig,
|
||||
DistributionSpec,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
||||
|
@ -100,7 +107,7 @@ class StackBuild(Subcommand):
|
|||
build_config.image_type = args.image_type
|
||||
else:
|
||||
self.parser.error(
|
||||
f"Please specify a image-type (docker | conda) for {args.template}"
|
||||
f"Please specify a image-type (docker | conda | venv) for {args.template}"
|
||||
)
|
||||
self._run_stack_build_command_from_build_config(
|
||||
build_config, template_name=args.template
|
||||
|
@ -122,7 +129,7 @@ class StackBuild(Subcommand):
|
|||
)
|
||||
|
||||
image_type = prompt(
|
||||
"> Enter the image type you want your Llama Stack to be built as (docker or conda): ",
|
||||
"> Enter the image type you want your Llama Stack to be built as (docker or conda or venv): ",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in ["docker", "conda", "venv"],
|
||||
error_message="Invalid image type, please enter conda or docker or venv",
|
||||
|
|
|
@ -6,21 +6,22 @@
|
|||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import pkg_resources
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from pathlib import Path
|
||||
from llama_stack.distribution.datatypes import BuildConfig, Provider
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -126,7 +126,7 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--templat
|
|||
|
||||
EOF
|
||||
|
||||
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
|
||||
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile\n\n"
|
||||
cat $TEMP_DIR/Dockerfile
|
||||
printf "\n"
|
||||
|
||||
|
|
|
@ -6,10 +6,14 @@
|
|||
import logging
|
||||
import textwrap
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
DistributionSpec,
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
|
@ -17,10 +21,7 @@ from llama_stack.distribution.distribution import (
|
|||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -4,23 +4,24 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Annotated, Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset, DatasetInput
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.eval_tasks import EvalTaskInput
|
||||
from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankInput
|
||||
from llama_stack.apis.models import Model, ModelInput
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
from llama_stack.apis.shields import Shield, ShieldInput
|
||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
||||
|
@ -37,6 +38,8 @@ RoutableObject = Union[
|
|||
Dataset,
|
||||
ScoringFn,
|
||||
EvalTask,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
]
|
||||
|
||||
|
||||
|
@ -48,6 +51,8 @@ RoutableObjectWithProvider = Annotated[
|
|||
Dataset,
|
||||
ScoringFn,
|
||||
EvalTask,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
@ -59,6 +64,7 @@ RoutedProtocol = Union[
|
|||
DatasetIO,
|
||||
Scoring,
|
||||
Eval,
|
||||
ToolRuntime,
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -47,6 +47,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
|||
routing_table_api=Api.eval_tasks,
|
||||
router_api=Api.eval,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.tool_groups,
|
||||
router_api=Api.tool_runtime,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -5,12 +5,12 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List
|
||||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inspect import HealthInfo, Inspect, ProviderInfo, RouteInfo
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
|
@ -16,7 +17,6 @@ from pathlib import Path
|
|||
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
|
||||
|
||||
import httpx
|
||||
|
||||
import yaml
|
||||
from llama_stack_client import (
|
||||
APIResponse,
|
||||
|
@ -28,7 +28,6 @@ from llama_stack_client import (
|
|||
)
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from rich.console import Console
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
|
@ -39,9 +38,9 @@ from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
|||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
|
@ -67,6 +66,7 @@ def in_notebook():
|
|||
def stream_across_asyncio_run_boundary(
|
||||
async_gen_maker,
|
||||
pool_executor: ThreadPoolExecutor,
|
||||
path: Optional[str] = None,
|
||||
) -> Generator[T, None, None]:
|
||||
result_queue = queue.Queue()
|
||||
stop_event = threading.Event()
|
||||
|
@ -74,6 +74,7 @@ def stream_across_asyncio_run_boundary(
|
|||
async def consumer():
|
||||
# make sure we make the generator in the event loop context
|
||||
gen = await async_gen_maker()
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
try:
|
||||
async for item in await gen:
|
||||
result_queue.put(item)
|
||||
|
@ -85,6 +86,7 @@ def stream_across_asyncio_run_boundary(
|
|||
finally:
|
||||
result_queue.put(StopIteration)
|
||||
stop_event.set()
|
||||
await end_trace()
|
||||
|
||||
def run_async():
|
||||
# Run our own loop to avoid double async generator cleanup which is done
|
||||
|
@ -170,6 +172,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
def __init__(
|
||||
self,
|
||||
config_path_or_template_name: str,
|
||||
skip_logger_removal: bool = False,
|
||||
custom_provider_registry: Optional[ProviderRegistry] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -177,23 +180,56 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
config_path_or_template_name, custom_provider_registry
|
||||
)
|
||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||
self.skip_logger_removal = skip_logger_removal
|
||||
|
||||
def initialize(self):
    """Synchronously initialize the wrapped async client.

    Applies nest_asyncio first when running inside a notebook, strips the
    root logger's handlers unless ``skip_logger_removal`` is set, and then
    runs the async client's ``initialize()`` to completion.

    Returns:
        Whatever ``self.async_client.initialize()`` returns.
    """
    if in_notebook():
        # Imported lazily so the dependency is only needed in notebooks.
        import nest_asyncio

        nest_asyncio.apply()

    if not self.skip_logger_removal:
        self._remove_root_logger_handlers()

    init_coro = self.async_client.initialize()
    return asyncio.run(init_coro)
|
||||
|
||||
def _remove_root_logger_handlers(self):
|
||||
"""
|
||||
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
||||
"""
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
print(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||
|
||||
def _get_path(
|
||||
self,
|
||||
cast_to: Any,
|
||||
options: Any,
|
||||
*,
|
||||
stream=False,
|
||||
stream_cls=None,
|
||||
):
|
||||
return options.url
|
||||
|
||||
def request(self, *args, **kwargs):
    """Synchronously dispatch a request to the wrapped async client.

    Streaming requests (``stream`` keyword truthy) are bridged through
    ``stream_across_asyncio_run_boundary``, which receives the request path.
    Non-streaming requests run inside their own event loop, wrapped in a
    trace that is started before the call and always ended afterwards.

    Returns:
        A generator of streamed items for streaming requests, otherwise the
        result of the async client's ``request``.
    """
    path = self._get_path(*args, **kwargs)
    if kwargs.get("stream"):
        return stream_across_asyncio_run_boundary(
            lambda: self.async_client.request(*args, **kwargs),
            self.pool_executor,
            path=path,
        )

    # Start and end the trace inside the same coroutine (and so the same
    # asyncio.run) as the request itself. The previous plain
    # `return asyncio.run(self.async_client.request(...))` is dropped: it
    # returned before this traced path could run.
    async def _traced_request():
        await start_trace(path, {"__location__": "library_client"})
        try:
            return await self.async_client.request(*args, **kwargs)
        finally:
            await end_trace()

    return asyncio.run(_traced_request())
|
||||
|
||||
|
||||
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||
|
@ -206,7 +242,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
# when using the library client, we should not log to console since many
|
||||
# of our logs are intended for server-side usage
|
||||
os.environ["TELEMETRY_SINKS"] = "sqlite"
|
||||
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
|
||||
os.environ["TELEMETRY_SINKS"] = ",".join(
|
||||
sink for sink in current_sinks if sink != "console"
|
||||
)
|
||||
|
||||
if config_path_or_template_name.endswith(".yaml"):
|
||||
config_path = Path(config_path_or_template_name)
|
||||
|
@ -247,7 +286,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
console = Console()
|
||||
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
|
||||
console.print(yaml.dump(self.config.model_dump(), indent=2))
|
||||
|
||||
# Redact sensitive information before printing
|
||||
safe_config = redact_sensitive_fields(self.config.model_dump())
|
||||
console.print(yaml.dump(safe_config, indent=2))
|
||||
|
||||
endpoints = get_all_api_endpoints()
|
||||
endpoint_impls = {}
|
||||
|
@ -295,41 +337,37 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
try:
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
body = self._convert_body(path, body)
|
||||
result = await func(**body)
|
||||
body = self._convert_body(path, body)
|
||||
result = await func(**body)
|
||||
|
||||
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=json_content.encode("utf-8"),
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
request=httpx.Request(
|
||||
method=options.method,
|
||||
url=options.url,
|
||||
params=options.params,
|
||||
headers=options.headers,
|
||||
json=options.json_data,
|
||||
),
|
||||
)
|
||||
response = APIResponse(
|
||||
raw=mock_response,
|
||||
client=self,
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream=False,
|
||||
stream_cls=None,
|
||||
)
|
||||
return response.parse()
|
||||
finally:
|
||||
await end_trace()
|
||||
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=json_content.encode("utf-8"),
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
request=httpx.Request(
|
||||
method=options.method,
|
||||
url=options.url,
|
||||
params=options.params,
|
||||
headers=options.headers,
|
||||
json=options.json_data,
|
||||
),
|
||||
)
|
||||
response = APIResponse(
|
||||
raw=mock_response,
|
||||
client=self,
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream=False,
|
||||
stream_cls=None,
|
||||
)
|
||||
return response.parse()
|
||||
|
||||
async def _call_streaming(
|
||||
self,
|
||||
|
@ -341,51 +379,47 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
try:
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
body = self._convert_body(path, body)
|
||||
body = self._convert_body(path, body)
|
||||
|
||||
async def gen():
|
||||
async for chunk in await func(**body):
|
||||
data = json.dumps(convert_pydantic_to_json_value(chunk))
|
||||
sse_event = f"data: {data}\n\n"
|
||||
yield sse_event.encode("utf-8")
|
||||
async def gen():
|
||||
async for chunk in await func(**body):
|
||||
data = json.dumps(convert_pydantic_to_json_value(chunk))
|
||||
sse_event = f"data: {data}\n\n"
|
||||
yield sse_event.encode("utf-8")
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=gen(),
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
request=httpx.Request(
|
||||
method=options.method,
|
||||
url=options.url,
|
||||
params=options.params,
|
||||
headers=options.headers,
|
||||
json=options.json_data,
|
||||
),
|
||||
)
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=gen(),
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
request=httpx.Request(
|
||||
method=options.method,
|
||||
url=options.url,
|
||||
params=options.params,
|
||||
headers=options.headers,
|
||||
json=options.json_data,
|
||||
),
|
||||
)
|
||||
|
||||
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
|
||||
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
|
||||
# so we need to convert it to AsyncStream
|
||||
args = get_args(stream_cls)
|
||||
stream_cls = AsyncStream[args[0]]
|
||||
response = AsyncAPIResponse(
|
||||
raw=mock_response,
|
||||
client=self,
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream=True,
|
||||
stream_cls=stream_cls,
|
||||
)
|
||||
return await response.parse()
|
||||
finally:
|
||||
await end_trace()
|
||||
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
|
||||
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
|
||||
# so we need to convert it to AsyncStream
|
||||
args = get_args(stream_cls)
|
||||
stream_cls = AsyncStream[args[0]]
|
||||
response = AsyncAPIResponse(
|
||||
raw=mock_response,
|
||||
client=self,
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream=True,
|
||||
stream_cls=stream_cls,
|
||||
)
|
||||
return await response.parse()
|
||||
|
||||
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
|
||||
if not body:
|
||||
|
|
|
@ -6,14 +6,10 @@
|
|||
import importlib
|
||||
import inspect
|
||||
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
import logging
|
||||
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -30,11 +26,34 @@ from llama_stack.apis.scoring import Scoring
|
|||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.distribution.client import get_client_impl
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AutoRoutedProviderSpec,
|
||||
Provider,
|
||||
RoutingTableProviderSpec,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
DatasetsProtocolPrivate,
|
||||
EvalTasksProtocolPrivate,
|
||||
InlineProviderSpec,
|
||||
MemoryBanksProtocolPrivate,
|
||||
ModelsProtocolPrivate,
|
||||
ProviderSpec,
|
||||
RemoteProviderConfig,
|
||||
RemoteProviderSpec,
|
||||
ScoringFunctionsProtocolPrivate,
|
||||
ShieldsProtocolPrivate,
|
||||
ToolsProtocolPrivate,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -60,12 +79,15 @@ def api_protocol_map() -> Dict[Api, Any]:
|
|||
Api.eval: Eval,
|
||||
Api.eval_tasks: EvalTasks,
|
||||
Api.post_training: PostTraining,
|
||||
Api.tool_groups: ToolGroups,
|
||||
Api.tool_runtime: ToolRuntime,
|
||||
}
|
||||
|
||||
|
||||
def additional_protocols_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
|
||||
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
|
||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||
|
|
|
@ -4,11 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutedProtocol
|
||||
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
from .routing_tables import (
|
||||
DatasetsRoutingTable,
|
||||
|
@ -17,6 +18,7 @@ from .routing_tables import (
|
|||
ModelsRoutingTable,
|
||||
ScoringFunctionsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
ToolGroupsRoutingTable,
|
||||
)
|
||||
|
||||
|
||||
|
@ -33,6 +35,7 @@ async def get_routing_table_impl(
|
|||
"datasets": DatasetsRoutingTable,
|
||||
"scoring_functions": ScoringFunctionsRoutingTable,
|
||||
"eval_tasks": EvalTasksRoutingTable,
|
||||
"tool_groups": ToolGroupsRoutingTable,
|
||||
}
|
||||
|
||||
if api.value not in api_to_tables:
|
||||
|
@ -51,6 +54,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
|
|||
MemoryRouter,
|
||||
SafetyRouter,
|
||||
ScoringRouter,
|
||||
ToolRuntimeRouter,
|
||||
)
|
||||
|
||||
api_to_routers = {
|
||||
|
@ -60,6 +64,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
|
|||
"datasetio": DatasetIORouter,
|
||||
"scoring": ScoringRouter,
|
||||
"eval": EvalRouter,
|
||||
"tool_runtime": ToolRuntimeRouter,
|
||||
}
|
||||
if api.value not in api_to_routers:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
|
|
@ -6,15 +6,40 @@
|
|||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.datasetio.datasetio import DatasetIO
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.eval import (
|
||||
AppEvalTaskConfig,
|
||||
Eval,
|
||||
EvalTaskConfig,
|
||||
EvaluateResponse,
|
||||
Job,
|
||||
JobStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
EmbeddingsResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks.memory_banks import BankParams
|
||||
from llama_stack.distribution.datatypes import RoutingTable
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.eval import * # noqa: F403
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
ScoreResponse,
|
||||
Scoring,
|
||||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
|
||||
class MemoryRouter(Memory):
|
||||
|
@ -329,7 +354,6 @@ class EvalRouter(Eval):
|
|||
task_config=task_config,
|
||||
)
|
||||
|
||||
@webmethod(route="/eval/evaluate_rows", method="POST")
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
task_id: str,
|
||||
|
@ -372,3 +396,28 @@ class EvalRouter(Eval):
|
|||
task_id,
|
||||
job_id,
|
||||
)
|
||||
|
||||
|
||||
class ToolRuntimeRouter(ToolRuntime):
    """ToolRuntime that forwards each call to the provider implementation
    returned by its routing table."""

    def __init__(self, routing_table: RoutingTable) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        """No-op."""
        pass

    async def shutdown(self) -> None:
        """No-op."""
        pass

    async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """Invoke ``tool_name`` on the provider the routing table maps it to."""
        provider = self.routing_table.get_provider_impl(tool_name)
        return await provider.invoke_tool(tool_name=tool_name, args=args)

    async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
        """Ask the provider routed for ``tool_group`` to discover its tools."""
        # NOTE(review): the provider is looked up via `tool_group.name`, but
        # neither MCPToolGroupDef nor UserDefinedToolGroupDef appears to
        # declare a `name` field -- confirm this attribute exists at runtime.
        # NOTE(review): the ToolRuntime protocol annotates this method as
        # returning List[ToolDef], not List[Tool] -- confirm which is intended.
        provider = self.routing_table.get_provider_impl(tool_group.name)
        return await provider.discover_tools(tool_group)
|
||||
|
|
|
@ -8,19 +8,40 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.datasets import Dataset, Datasets
|
||||
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks
|
||||
from llama_stack.apis.memory_banks import (
|
||||
BankParams,
|
||||
MemoryBank,
|
||||
MemoryBanks,
|
||||
MemoryBankType,
|
||||
)
|
||||
from llama_stack.apis.models import Model, Models, ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
ScoringFn,
|
||||
ScoringFnParams,
|
||||
ScoringFunctions,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield, Shields
|
||||
from llama_stack.apis.tools import (
|
||||
MCPToolGroupDef,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ToolGroupDef,
|
||||
ToolGroups,
|
||||
UserDefinedToolGroupDef,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import (
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
|
||||
def get_impl_api(p: Any) -> Api:
|
||||
|
@ -45,6 +66,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
return await p.register_scoring_function(obj)
|
||||
elif api == Api.eval:
|
||||
return await p.register_eval_task(obj)
|
||||
elif api == Api.tool_runtime:
|
||||
return await p.register_tool(obj)
|
||||
else:
|
||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
@ -57,6 +80,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
|||
return await p.unregister_model(obj.identifier)
|
||||
elif api == Api.datasetio:
|
||||
return await p.unregister_dataset(obj.identifier)
|
||||
elif api == Api.tool_runtime:
|
||||
return await p.unregister_tool(obj.identifier)
|
||||
else:
|
||||
raise ValueError(f"Unregister not supported for {api}")
|
||||
|
||||
|
@ -104,6 +129,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
await add_objects(scoring_functions, pid, ScoringFn)
|
||||
elif api == Api.eval:
|
||||
p.eval_task_store = self
|
||||
elif api == Api.tool_runtime:
|
||||
p.tool_store = self
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.impls_by_provider_id.values():
|
||||
|
@ -125,6 +152,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
return ("Scoring", "scoring_function")
|
||||
elif isinstance(self, EvalTasksRoutingTable):
|
||||
return ("Eval", "eval_task")
|
||||
elif isinstance(self, ToolGroupsRoutingTable):
|
||||
return ("Tools", "tool")
|
||||
else:
|
||||
raise ValueError("Unknown routing table type")
|
||||
|
||||
|
@ -461,3 +490,88 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
|||
provider_resource_id=provider_eval_task_id,
|
||||
)
|
||||
await self.register_object(eval_task)
|
||||
|
||||
|
||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    """Routing table that stores ``tool`` and ``tool_group`` objects."""

    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
        """Return all registered tools, optionally only those in ``tool_group_id``."""
        all_tools = await self.get_all_with_type("tool")
        if not tool_group_id:
            return all_tools
        return [tool for tool in all_tools if tool.tool_group == tool_group_id]

    async def list_tool_groups(self) -> List[ToolGroup]:
        """Return every registered tool group."""
        return await self.get_all_with_type("tool_group")

    async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
        """Look up a tool group by its identifier."""
        return await self.get_object_by_identifier("tool_group", tool_group_id)

    async def get_tool(self, tool_name: str) -> Tool:
        """Look up a tool by its name (the tool's identifier)."""
        return await self.get_object_by_identifier("tool", tool_name)

    async def register_tool_group(
        self,
        tool_group_id: str,
        tool_group: ToolGroupDef,
        provider_id: Optional[str] = None,
    ) -> None:
        """Register every tool of ``tool_group`` and then the group itself.

        When ``provider_id`` is omitted, the single configured provider is
        used. MCP groups have their tools discovered through the provider;
        user-defined groups carry their tool definitions inline.

        Raises:
            ValueError: if no ``provider_id`` is given and several providers
                exist, if ``tool_group`` is of an unknown kind, or if a tool
                with the same identifier but different contents is already
                registered.
        """
        if provider_id is None:
            if len(self.impls_by_provider_id.keys()) > 1:
                raise ValueError(
                    f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
                )
            provider_id = list(self.impls_by_provider_id.keys())[0]

        if isinstance(tool_group, MCPToolGroupDef):
            provider_impl = self.impls_by_provider_id[provider_id]
            tool_defs = await provider_impl.discover_tools(tool_group)
        elif isinstance(tool_group, UserDefinedToolGroupDef):
            tool_defs = tool_group.tools
        else:
            raise ValueError(f"Unknown tool group: {tool_group}")

        # Build all Tool objects up front, then validate/register them one by one.
        tools = [
            Tool(
                identifier=tool_def.name,
                tool_group=tool_group_id,
                description=tool_def.description,
                parameters=tool_def.parameters,
                provider_id=provider_id,
                tool_prompt_format=tool_def.tool_prompt_format,
                provider_resource_id=tool_def.name,
                metadata=tool_def.metadata,
            )
            for tool_def in tool_defs
        ]

        for tool in tools:
            existing_tool = await self.get_tool(tool.identifier)
            # An identical re-registration is allowed; a conflicting one is not.
            if existing_tool and existing_tool.model_dump() != tool.model_dump():
                raise ValueError(
                    f"Object {tool.identifier} already exists in registry. Please use a different identifier."
                )
            await self.register_object(tool)

        group = ToolGroup(
            identifier=tool_group_id,
            provider_id=provider_id,
            provider_resource_id=tool_group_id,
        )
        await self.dist_registry.register(group)

    async def unregister_tool_group(self, tool_group_id: str) -> None:
        """Unregister all tools of ``tool_group_id`` and then the group.

        Raises:
            ValueError: if the tool group is not found.
        """
        tool_group = await self.get_tool_group(tool_group_id)
        if tool_group is None:
            raise ValueError(f"Tool group {tool_group_id} not found")
        for tool in await self.list_tools(tool_group_id):
            await self.unregister_object(tool)
        await self.unregister_object(tool_group)
|
||||
|
|
|
@ -28,25 +28,29 @@ from pydantic import BaseModel, ValidationError
|
|||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
TelemetryAdapter,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
TelemetryAdapter,
|
||||
)
|
||||
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
@ -235,7 +239,12 @@ def main():
|
|||
"--template",
|
||||
help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=5000, help="Port to listen on")
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMASTACK_PORT", 5000)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
|
||||
)
|
||||
|
@ -277,7 +286,8 @@ def main():
|
|||
config = StackRunConfig(**config)
|
||||
|
||||
print("Run configuration:")
|
||||
print(yaml.dump(config.model_dump(), indent=2))
|
||||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
print(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
|
|
|
@ -8,32 +8,31 @@ import logging
|
|||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.eval import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.batch_inference import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
from llama_stack.apis.post_training import * # noqa: F403
|
||||
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batch_inference import BatchInference
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.eval_tasks import EvalTasks
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
|
@ -113,6 +112,26 @@ class EnvVarError(Exception):
|
|||
)
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a redacted copy of ``data`` that is safe to print.

    A key is considered sensitive when its lowercased name contains one of
    ``api_key``, ``api_token``, ``password`` or ``secret``. Scalar values
    stored under a sensitive key are replaced by ``"********"``. Dicts are
    always walked recursively (their own keys decide what is masked), and
    lists are walked to any nesting depth. Scalar items of a list stored
    under a sensitive key are masked as well, so ``api_keys: [a, b]`` does
    not leak.

    Args:
        data: The configuration mapping to sanitize. It is not modified.

    Returns:
        A new dict with the same structure and sensitive values masked.
    """
    sensitive_patterns = ("api_key", "api_token", "password", "secret")
    mask = "********"

    def _is_sensitive(key: Any) -> bool:
        # Keys are usually strings, but a mapping may carry other hashables;
        # str() keeps the check from raising AttributeError on those.
        lowered = str(key).lower()
        return any(pattern in lowered for pattern in sensitive_patterns)

    def _redact_value(value: Any, mask_scalars: bool) -> Any:
        if isinstance(value, dict):
            # Nested dicts decide sensitivity from their own keys.
            return _redact_dict(value)
        if isinstance(value, list):
            # Recurse so that lists nested inside lists are covered too.
            return [_redact_value(item, mask_scalars) for item in value]
        return mask if mask_scalars else value

    def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
        return {k: _redact_value(v, _is_sensitive(k)) for k, v in d.items()}

    return _redact_dict(data)
|
||||
|
||||
|
||||
def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||
if isinstance(config, dict):
|
||||
result = {}
|
||||
|
|
|
@ -90,7 +90,6 @@ $DOCKER_BINARY run $DOCKER_OPTS -it \
|
|||
$env_vars \
|
||||
-v "$yaml_config:/app/config.yaml" \
|
||||
$mounts \
|
||||
$docker_image:$version_tag \
|
||||
python -m llama_stack.distribution.server.server \
|
||||
--yaml-config /app/config.yaml \
|
||||
--port "$port"
|
||||
--env LLAMASTACK_PORT=$port \
|
||||
--entrypoint='["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]' \
|
||||
$docker_image:$version_tag
|
||||
|
|
|
@ -13,11 +13,8 @@ import pydantic
|
|||
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
from llama_stack.providers.utils.kvstore import (
|
||||
KVStore,
|
||||
kvstore_impl,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class DistributionRegistry(Protocol):
|
||||
|
|
|
@ -8,11 +8,14 @@ import os
|
|||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from llama_stack.distribution.store import * # noqa F403
|
||||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.memory_banks import VectorMemoryBank
|
||||
|
||||
from llama_stack.distribution.store.registry import (
|
||||
CachedDiskDistributionRegistry,
|
||||
DiskDistributionRegistry,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||
from llama_stack.distribution.datatypes import * # noqa F403
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -129,7 +129,7 @@ def application_evaluation_page():
|
|||
|
||||
# Display current row results using separate containers
|
||||
progress_text_container.write(
|
||||
f"Expand to see current processed result ({i+1}/{len(rows)})"
|
||||
f"Expand to see current processed result ({i + 1} / {len(rows)})"
|
||||
)
|
||||
results_container.json(
|
||||
score_res.to_json(),
|
||||
|
|
|
@ -232,7 +232,7 @@ def run_evaluation_3():
|
|||
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
|
||||
|
||||
progress_text_container.write(
|
||||
f"Expand to see current processed result ({i+1}/{len(rows)})"
|
||||
f"Expand to see current processed result ({i + 1} / {len(rows)})"
|
||||
)
|
||||
results_container.json(eval_res, expanded=2)
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.memory_banks.memory_banks import MemoryBank
|
|||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.tools import Tool
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -29,6 +30,7 @@ class Api(Enum):
|
|||
scoring = "scoring"
|
||||
eval = "eval"
|
||||
post_training = "post_training"
|
||||
tool_runtime = "tool_runtime"
|
||||
|
||||
telemetry = "telemetry"
|
||||
|
||||
|
@ -38,6 +40,7 @@ class Api(Enum):
|
|||
datasets = "datasets"
|
||||
scoring_functions = "scoring_functions"
|
||||
eval_tasks = "eval_tasks"
|
||||
tool_groups = "tool_groups"
|
||||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
@ -75,6 +78,12 @@ class EvalTasksProtocolPrivate(Protocol):
|
|||
async def register_eval_task(self, eval_task: EvalTask) -> None: ...
|
||||
|
||||
|
||||
class ToolsProtocolPrivate(Protocol):
    """Structural interface a provider exposes for tool registration."""

    async def register_tool(self, tool: Tool) -> None:
        """Register ``tool`` with the provider."""

    async def unregister_tool(self, tool_id: str) -> None:
        """Unregister the tool identified by ``tool_id`` from the provider."""
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderSpec(BaseModel):
|
||||
api: Api
|
||||
|
|
|
@ -13,19 +13,64 @@ import secrets
|
|||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Tuple
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentTool,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResponseEvent,
|
||||
AgentTurnResponseEventType,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStreamChunk,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
Attachment,
|
||||
CodeInterpreterToolDefinition,
|
||||
FunctionCallToolDefinition,
|
||||
InferenceStep,
|
||||
MemoryRetrievalStep,
|
||||
MemoryToolDefinition,
|
||||
PhotogenToolDefinition,
|
||||
SearchToolDefinition,
|
||||
ShieldCallStep,
|
||||
StepType,
|
||||
ToolExecutionStep,
|
||||
Turn,
|
||||
WolframAlphaToolDefinition,
|
||||
)
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponseEventType,
|
||||
CompletionMessage,
|
||||
Inference,
|
||||
Message,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
|
@ -539,7 +584,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_call = message.tool_calls[0]
|
||||
|
||||
name = tool_call.tool_name
|
||||
if not isinstance(name, BuiltinTool):
|
||||
if not isinstance(name, BuiltinTool) or name not in enabled_tools:
|
||||
yield message
|
||||
return
|
||||
|
||||
|
|
|
@ -9,15 +9,26 @@ import logging
|
|||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
AgentSessionCreateResponse,
|
||||
AgentStepResponse,
|
||||
AgentTurnCreateRequest,
|
||||
Attachment,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
|
||||
|
|
|
@ -10,9 +10,11 @@ import uuid
|
|||
from datetime import datetime
|
||||
|
||||
from typing import List, Optional
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Turn
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
|
@ -7,8 +7,6 @@
|
|||
from typing import List
|
||||
|
||||
from jinja2 import Template
|
||||
from llama_models.llama3.api import * # noqa: F403
|
||||
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
|
@ -16,7 +14,7 @@ from llama_stack.apis.agents import (
|
|||
MemoryQueryGenerator,
|
||||
MemoryQueryGeneratorConfig,
|
||||
)
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message, UserMessage
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
@ -64,7 +62,7 @@ async def llm_rag_query_generator(
|
|||
model = config.model
|
||||
message = UserMessage(content=content)
|
||||
response = await inference_api.chat_completion(
|
||||
model=model,
|
||||
model_id=model,
|
||||
messages=[message],
|
||||
stream=False,
|
||||
)
|
||||
|
|
|
@ -9,7 +9,9 @@ import logging
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message
|
||||
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -8,10 +8,26 @@ from typing import AsyncIterator, List, Optional, Union
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.memory import MemoryBank
|
||||
from llama_stack.apis.safety import RunShieldResponse
|
||||
|
||||
from ..agents import (
|
||||
AGENT_INSTANCES_BY_ID,
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
from ..safety import ShieldRunnerMixin
|
||||
from .builtin import BaseTool
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.datasetio import * # noqa: F401, F403
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LocalFSDatasetIOConfig(BaseModel):
    """Configuration model that currently declares no fields."""
|
||||
|
|
|
@ -3,18 +3,19 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
import base64
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pandas
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
|
||||
|
|
|
@ -3,37 +3,38 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.agents import Agents, StepType
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval_tasks import EvalTask
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import Inference, UserMessage
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
ColumnName,
|
||||
get_valid_schemas,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
|
||||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
EVAL_TASKS_PREFIX = "eval_tasks:"
|
||||
|
||||
|
||||
class ColumnName(Enum):
|
||||
input_query = "input_query"
|
||||
expected_answer = "expected_answer"
|
||||
chat_completion_input = "chat_completion_input"
|
||||
completion_input = "completion_input"
|
||||
generated_answer = "generated_answer"
|
||||
|
||||
|
||||
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
||||
class MetaReferenceEvalImpl(
|
||||
Eval,
|
||||
EvalTasksProtocolPrivate,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceEvalConfig,
|
||||
|
@ -77,29 +78,6 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
)
|
||||
self.eval_tasks[task_def.identifier] = task_def
|
||||
|
||||
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
expected_schemas = [
|
||||
{
|
||||
ColumnName.input_query.value: StringType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
|
||||
},
|
||||
{
|
||||
ColumnName.input_query.value: StringType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.completion_input.value: CompletionInputType(),
|
||||
},
|
||||
]
|
||||
|
||||
if dataset_def.dataset_schema not in expected_schemas:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
|
||||
)
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
task_id: str,
|
||||
|
@ -109,8 +87,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
dataset_id = task_def.dataset_id
|
||||
candidate = task_config.eval_candidate
|
||||
scoring_functions = task_def.scoring_functions
|
||||
|
||||
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(
|
||||
dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
|
||||
)
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=(
|
||||
|
@ -162,11 +142,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
)
|
||||
]
|
||||
final_event = turn_response[-1].event.payload
|
||||
generations.append(
|
||||
{
|
||||
ColumnName.generated_answer.value: final_event.turn.output_message.content
|
||||
}
|
||||
|
||||
# check if there's a memory retrieval step and extract the context
|
||||
memory_rag_context = None
|
||||
for step in final_event.turn.steps:
|
||||
if step.step_type == StepType.memory_retrieval.value:
|
||||
memory_rag_context = " ".join(x.text for x in step.inserted_context)
|
||||
|
||||
agent_generation = {}
|
||||
agent_generation[ColumnName.generated_answer.value] = (
|
||||
final_event.turn.output_message.content
|
||||
)
|
||||
if memory_rag_context:
|
||||
agent_generation[ColumnName.context.value] = memory_rag_context
|
||||
|
||||
generations.append(agent_generation)
|
||||
|
||||
return generations
|
||||
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F401, F403
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from llama_stack.apis.inference import QuantizationConfig
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
||||
|
|
|
@ -32,11 +32,16 @@ from llama_models.llama3.reference_impl.multimodal.model import (
|
|||
CrossAttentionTransformer,
|
||||
)
|
||||
from llama_models.sku_list import resolve_model
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
@ -44,12 +49,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
MetaReferenceInferenceConfig,
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
)
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -14,7 +14,10 @@ from llama_models.llama3.api.datatypes import Model
|
|||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama, model_checkpoint_dir
|
||||
|
@ -27,9 +30,9 @@ class ModelRunner:
|
|||
|
||||
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
|
||||
def __call__(self, req: Any):
|
||||
if isinstance(req, ChatCompletionRequest):
|
||||
if isinstance(req, ChatCompletionRequestWithRawContent):
|
||||
return self.llama.chat_completion(req)
|
||||
elif isinstance(req, CompletionRequest):
|
||||
elif isinstance(req, CompletionRequestWithRawContent):
|
||||
return self.llama.completion(req)
|
||||
else:
|
||||
raise ValueError(f"Unexpected task type {type(req)}")
|
||||
|
@ -100,7 +103,7 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
request: CompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request)
|
||||
gen = self.group.run_inference(req_obj)
|
||||
|
@ -108,7 +111,7 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
request: ChatCompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request)
|
||||
gen = self.group.run_inference(req_obj)
|
||||
|
|
|
@ -34,7 +34,10 @@ from pydantic import BaseModel, Field
|
|||
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .generation import TokenResult
|
||||
|
||||
|
@ -79,7 +82,7 @@ class TaskRequest(BaseModel):
|
|||
type: Literal[ProcessingMessageName.task_request] = (
|
||||
ProcessingMessageName.task_request
|
||||
)
|
||||
task: Union[CompletionRequest, ChatCompletionRequest]
|
||||
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
|
@ -264,9 +267,6 @@ def launch_dist_group(
|
|||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
id = uuid.uuid4().hex
|
||||
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
|
||||
launch_config = LaunchConfig(
|
||||
|
@ -315,7 +315,7 @@ def start_model_parallel_process(
|
|||
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
|
||||
|
||||
request_socket.send(encode_msg(ReadyRequest()))
|
||||
response = request_socket.recv()
|
||||
_response = request_socket.recv()
|
||||
log.info("Loaded model...")
|
||||
|
||||
return request_socket, process
|
||||
|
@ -349,7 +349,10 @@ class ModelParallelProcessGroup:
|
|||
self.started = False
|
||||
|
||||
def run_inference(
|
||||
self, req: Union[CompletionRequest, ChatCompletionRequest]
|
||||
self,
|
||||
req: Union[
|
||||
CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
|
||||
],
|
||||
) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
||||
|
|
|
@ -7,10 +7,10 @@
|
|||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Optional
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
|
@ -18,9 +18,26 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
|||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
|
|
|
@ -16,11 +16,14 @@ import faiss
|
|||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
|
||||
from llama_stack.apis.memory import (
|
||||
Chunk,
|
||||
Memory,
|
||||
MemoryBankDocument,
|
||||
QueryDocumentsResponse,
|
||||
)
|
||||
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
|
|
|
@ -90,18 +90,24 @@ class TorchtuneCheckpointer:
|
|||
model_file_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# copy the related files for inference
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "params.json"),
|
||||
Path.joinpath(model_file_path, "params.json"),
|
||||
)
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
|
||||
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||
)
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
|
||||
Path.joinpath(model_file_path, "orig_params.json"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "params.json"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "orig_params.json"),
|
||||
)
|
||||
|
||||
if not adapter_only:
|
||||
model_state_dict = state_dict[training.MODEL_KEY]
|
||||
|
|
|
@ -14,14 +14,16 @@ from enum import Enum
|
|||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.common.type_system import * # noqa
|
||||
from llama_models.datatypes import Model
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.common.type_system import ParamType, StringType
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
||||
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
||||
from pydantic import BaseModel
|
||||
|
||||
from torchtune.models.llama3 import llama3_tokenizer
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.models.llama3_1 import lora_llama3_1_8b
|
||||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||
|
||||
|
||||
|
@ -48,8 +50,8 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
|||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3_2",
|
||||
),
|
||||
"Llama-3-8B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_8b,
|
||||
"Llama3.1-8B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_1_8b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3",
|
||||
),
|
||||
|
|
|
@ -3,11 +3,26 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_models.schema_utils import webmethod
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
LoraFinetuningConfig,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.apis.post_training import * # noqa
|
||||
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||
LoraFinetuningSingleDevice,
|
||||
)
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
@ -14,27 +15,33 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
import torch
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.common.training_types import PostTrainingMetric
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
||||
TorchtuneCheckpointer,
|
||||
)
|
||||
from torch import nn
|
||||
from torchtune import utils as torchtune_utils
|
||||
from torchtune.training.metric_logging import DiskLogger
|
||||
from tqdm import tqdm
|
||||
from llama_stack.apis.post_training import * # noqa
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
||||
TorchtuneCheckpointer,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torchtune import modules, training
|
||||
from torchtune import modules, training, utils as torchtune_utils
|
||||
from torchtune.data import AlpacaToMessages, padded_collate_sft
|
||||
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
|
@ -43,11 +50,12 @@ from torchtune.modules.peft import (
|
|||
get_adapter_state_dict,
|
||||
get_lora_module_names,
|
||||
get_merged_lora_ckpt,
|
||||
load_dora_magnitudes,
|
||||
set_trainable_params,
|
||||
validate_missing_and_unexpected_for_lora,
|
||||
)
|
||||
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
|
||||
from torchtune.training.metric_logging import DiskLogger
|
||||
from tqdm import tqdm
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -110,6 +118,10 @@ class LoraFinetuningSingleDevice:
|
|||
self.checkpoint_dir = config.checkpoint_dir
|
||||
else:
|
||||
model = resolve_model(self.model_id)
|
||||
if model is None:
|
||||
raise ValueError(
|
||||
f"{self.model_id} not found. Your model id should be in the llama models SKU list"
|
||||
)
|
||||
self.checkpoint_dir = model_checkpoint_dir(model)
|
||||
|
||||
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
|
||||
|
@ -125,6 +137,7 @@ class LoraFinetuningSingleDevice:
|
|||
self.global_step = 0
|
||||
|
||||
self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
|
||||
self.max_validation_steps = training_config.max_validation_steps
|
||||
|
||||
self._clip_grad_norm = 1.0
|
||||
self._enable_activation_checkpointing = (
|
||||
|
@ -277,7 +290,6 @@ class LoraFinetuningSingleDevice:
|
|||
for m in model.modules():
|
||||
if hasattr(m, "initialize_dora_magnitude"):
|
||||
m.initialize_dora_magnitude()
|
||||
load_dora_magnitudes(model)
|
||||
if lora_weights_state_dict:
|
||||
lora_missing, lora_unexpected = model.load_state_dict(
|
||||
lora_weights_state_dict, strict=False
|
||||
|
@ -572,7 +584,7 @@ class LoraFinetuningSingleDevice:
|
|||
log.info("Starting validation...")
|
||||
pbar = tqdm(total=len(self._validation_dataloader))
|
||||
for idx, batch in enumerate(self._validation_dataloader):
|
||||
if idx == 10:
|
||||
if idx == self.max_validation_steps:
|
||||
break
|
||||
torchtune_utils.batch_to_device(batch, self._device)
|
||||
|
||||
|
|
|
@ -7,8 +7,14 @@
|
|||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
|
|
@ -9,10 +9,24 @@ import re
|
|||
from string import Template
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.datatypes import CoreModelId
|
||||
from llama_models.llama3.api.datatypes import Role
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponseEventType,
|
||||
Inference,
|
||||
Message,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
|
|
@ -11,11 +11,16 @@ import torch
|
|||
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
|
|
|
@ -3,16 +3,24 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
ScoreResponse,
|
||||
Scoring,
|
||||
ScoringResult,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
get_valid_schemas,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
|
@ -21,7 +29,10 @@ from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
|||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
|
||||
|
||||
|
||||
class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||
class BasicScoringImpl(
|
||||
Scoring,
|
||||
ScoringFunctionsProtocolPrivate,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: BasicScoringConfig,
|
||||
|
@ -58,30 +69,17 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def register_scoring_function(self, function_def: ScoringFn) -> None:
|
||||
raise NotImplementedError("Register scoring function not implemented yet")
|
||||
|
||||
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
|
||||
)
|
||||
|
||||
for required_column in ["generated_answer", "expected_answer", "input_query"]:
|
||||
if required_column not in dataset_def.dataset_schema:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column."
|
||||
)
|
||||
if dataset_def.dataset_schema[required_column].type != "string":
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
||||
)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(
|
||||
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
|
||||
)
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
|
|
|
@ -9,12 +9,12 @@ from typing import Any, Dict, Optional
|
|||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.equality import equality
|
||||
|
||||
|
||||
class EqualityScoringFn(BaseScoringFn):
|
||||
class EqualityScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
|
||||
"""
|
||||
|
|
|
@ -9,14 +9,14 @@ from typing import Any, Dict, Optional
|
|||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.regex_parser_multiple_choice_answer import (
|
||||
regex_parser_multiple_choice_answer,
|
||||
)
|
||||
|
||||
|
||||
class RegexParserScoringFn(BaseScoringFn):
|
||||
class RegexParserScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
|
||||
"""
|
||||
|
|
|
@ -8,12 +8,12 @@ from typing import Any, Dict, Optional
|
|||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.subset_of import subset_of
|
||||
|
||||
|
||||
class SubsetOfScoringFn(BaseScoringFn):
|
||||
class SubsetOfScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
|
||||
"""
|
||||
|
|
|
@ -3,32 +3,115 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from autoevals.llm import Factuality
|
||||
from autoevals.ragas import AnswerCorrectness
|
||||
from autoevals.ragas import (
|
||||
AnswerCorrectness,
|
||||
AnswerRelevancy,
|
||||
AnswerSimilarity,
|
||||
ContextEntityRecall,
|
||||
ContextPrecision,
|
||||
ContextRecall,
|
||||
ContextRelevancy,
|
||||
Faithfulness,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
ScoreResponse,
|
||||
Scoring,
|
||||
ScoringResult,
|
||||
ScoringResultRow,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
get_valid_schemas,
|
||||
validate_dataset_schema,
|
||||
validate_row_schema,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
|
||||
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
|
||||
from .config import BraintrustScoringConfig
|
||||
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
|
||||
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
|
||||
from .scoring_fn.fn_defs.answer_similarity import answer_similarity_fn_def
|
||||
from .scoring_fn.fn_defs.context_entity_recall import context_entity_recall_fn_def
|
||||
from .scoring_fn.fn_defs.context_precision import context_precision_fn_def
|
||||
from .scoring_fn.fn_defs.context_recall import context_recall_fn_def
|
||||
from .scoring_fn.fn_defs.context_relevancy import context_relevancy_fn_def
|
||||
from .scoring_fn.fn_defs.factuality import factuality_fn_def
|
||||
from .scoring_fn.fn_defs.faithfulness import faithfulness_fn_def
|
||||
|
||||
|
||||
class BraintrustScoringFnEntry(BaseModel):
|
||||
identifier: str
|
||||
evaluator: Any
|
||||
fn_def: ScoringFn
|
||||
|
||||
|
||||
SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::factuality",
|
||||
evaluator=Factuality(),
|
||||
fn_def=factuality_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::answer-correctness",
|
||||
evaluator=AnswerCorrectness(),
|
||||
fn_def=answer_correctness_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::answer-relevancy",
|
||||
evaluator=AnswerRelevancy(),
|
||||
fn_def=answer_relevancy_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::answer-similarity",
|
||||
evaluator=AnswerSimilarity(),
|
||||
fn_def=answer_similarity_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::faithfulness",
|
||||
evaluator=Faithfulness(),
|
||||
fn_def=faithfulness_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::context-entity-recall",
|
||||
evaluator=ContextEntityRecall(),
|
||||
fn_def=context_entity_recall_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::context-precision",
|
||||
evaluator=ContextPrecision(),
|
||||
fn_def=context_precision_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::context-recall",
|
||||
evaluator=ContextRecall(),
|
||||
fn_def=context_recall_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::context-relevancy",
|
||||
evaluator=ContextRelevancy(),
|
||||
fn_def=context_relevancy_fn_def,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class BraintrustScoringImpl(
|
||||
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
|
||||
Scoring,
|
||||
ScoringFunctionsProtocolPrivate,
|
||||
NeedsRequestProviderData,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -41,12 +124,12 @@ class BraintrustScoringImpl(
|
|||
self.datasets_api = datasets_api
|
||||
|
||||
self.braintrust_evaluators = {
|
||||
"braintrust::factuality": Factuality(),
|
||||
"braintrust::answer-correctness": AnswerCorrectness(),
|
||||
entry.identifier: entry.evaluator
|
||||
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
|
||||
}
|
||||
self.supported_fn_defs_registry = {
|
||||
factuality_fn_def.identifier: factuality_fn_def,
|
||||
answer_correctness_fn_def.identifier: answer_correctness_fn_def,
|
||||
entry.identifier: entry.fn_def
|
||||
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
|
||||
}
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
@ -67,23 +150,6 @@ class BraintrustScoringImpl(
|
|||
"Registering scoring function not allowed for braintrust provider"
|
||||
)
|
||||
|
||||
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
|
||||
)
|
||||
|
||||
for required_column in ["generated_answer", "expected_answer", "input_query"]:
|
||||
if required_column not in dataset_def.dataset_schema:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column."
|
||||
)
|
||||
if dataset_def.dataset_schema[required_column].type != "string":
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
||||
)
|
||||
|
||||
async def set_api_key(self) -> None:
|
||||
# api key is in the request headers
|
||||
if not self.config.openai_api_key:
|
||||
|
@ -99,11 +165,16 @@ class BraintrustScoringImpl(
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.set_api_key()
|
||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(
|
||||
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
|
||||
)
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
|
@ -123,6 +194,7 @@ class BraintrustScoringImpl(
|
|||
async def score_row(
|
||||
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
||||
) -> ScoringResultRow:
|
||||
validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
|
||||
await self.set_api_key()
|
||||
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
|
||||
expected_answer = input_row["expected_answer"]
|
||||
|
@ -130,12 +202,19 @@ class BraintrustScoringImpl(
|
|||
input_query = input_row["input_query"]
|
||||
evaluator = self.braintrust_evaluators[scoring_fn_identifier]
|
||||
|
||||
result = evaluator(generated_answer, expected_answer, input=input_query)
|
||||
result = evaluator(
|
||||
generated_answer,
|
||||
expected_answer,
|
||||
input=input_query,
|
||||
context=input_row["context"] if "context" in input_row else None,
|
||||
)
|
||||
score = result.score
|
||||
return {"score": score, "metadata": result.metadata}
|
||||
|
||||
async def score(
|
||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||
) -> ScoreResponse:
|
||||
await self.set_api_key()
|
||||
res = {}
|
||||
|
@ -147,8 +226,17 @@ class BraintrustScoringImpl(
|
|||
await self.score_row(input_row, scoring_fn_id)
|
||||
for input_row in input_rows
|
||||
]
|
||||
aggregation_functions = [AggregationFunctionType.average]
|
||||
agg_results = aggregate_average(score_results)
|
||||
aggregation_functions = self.supported_fn_defs_registry[
|
||||
scoring_fn_id
|
||||
].params.aggregation_functions
|
||||
|
||||
# override scoring_fn params if provided
|
||||
if scoring_functions[scoring_fn_id] is not None:
|
||||
override_params = scoring_functions[scoring_fn_id]
|
||||
if override_params.aggregation_functions:
|
||||
aggregation_functions = override_params.aggregation_functions
|
||||
|
||||
agg_results = aggregate_metrics(score_results, aggregation_functions)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
aggregated_results=agg_results,
|
||||
|
|
|
@ -3,7 +3,9 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BraintrustScoringConfig(BaseModel):
|
||||
|
|
|
@ -5,14 +5,23 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
||||
answer_correctness_fn_def = ScoringFn(
|
||||
identifier="braintrust::answer-correctness",
|
||||
description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
|
||||
params=None,
|
||||
description=(
|
||||
"Scores the correctness of the answer based on the ground truth. "
|
||||
"Uses Braintrust LLM-based scorer from autoevals library."
|
||||
),
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="answer-correctness",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
)
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
# Definition of the Braintrust "answer-relevancy" scoring function.
# It returns a number; the default aggregation over rows is the average.
answer_relevancy_fn_def = ScoringFn(
    identifier="braintrust::answer-relevancy",
    provider_id="braintrust",
    provider_resource_id="answer-relevancy",
    return_type=NumberType(),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average],
    ),
    description=(
        "Test output relevancy against the input query using Braintrust LLM scorer. "
        "See: github.com/braintrustdata/autoevals"
    ),
)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
# Definition of the Braintrust "answer-similarity" scoring function.
# It returns a number; the default aggregation over rows is the average.
answer_similarity_fn_def = ScoringFn(
    identifier="braintrust::answer-similarity",
    provider_id="braintrust",
    provider_resource_id="answer-similarity",
    return_type=NumberType(),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average],
    ),
    description=(
        "Test output similarity against expected value using Braintrust LLM scorer. "
        "See: github.com/braintrustdata/autoevals"
    ),
)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
# Definition of the Braintrust "context-entity-recall" scoring function.
# It returns a number; the default aggregation over rows is the average.
context_entity_recall_fn_def = ScoringFn(
    identifier="braintrust::context-entity-recall",
    provider_id="braintrust",
    provider_resource_id="context-entity-recall",
    return_type=NumberType(),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average],
    ),
    description=(
        "Evaluates how well the context captures the named entities present in the "
        "reference answer. See: github.com/braintrustdata/autoevals"
    ),
)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
# Definition of the Braintrust "context-precision" scoring function.
# It returns a number; the default aggregation over rows is the average.
context_precision_fn_def = ScoringFn(
    identifier="braintrust::context-precision",
    provider_id="braintrust",
    provider_resource_id="context-precision",
    return_type=NumberType(),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average],
    ),
    description=(
        "Measures how much of the provided context is actually relevant to answering the "
        "question. See: github.com/braintrustdata/autoevals"
    ),
)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
# Definition of the Braintrust "context-recall" scoring function.
# It returns a number; the default aggregation over rows is the average.
context_recall_fn_def = ScoringFn(
    identifier="braintrust::context-recall",
    provider_id="braintrust",
    provider_resource_id="context-recall",
    return_type=NumberType(),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average],
    ),
    description=(
        "Evaluates how well the context covers the information needed to answer the "
        "question. See: github.com/braintrustdata/autoevals"
    ),
)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
# Definition of the Braintrust "context-relevancy" scoring function.
# It returns a number; the default aggregation over rows is the average.
context_relevancy_fn_def = ScoringFn(
    identifier="braintrust::context-relevancy",
    provider_id="braintrust",
    provider_resource_id="context-relevancy",
    return_type=NumberType(),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average],
    ),
    description=(
        "Assesses how relevant the provided context is to the given question. "
        "See: github.com/braintrustdata/autoevals"
    ),
)
|
|
@ -5,14 +5,23 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
||||
factuality_fn_def = ScoringFn(
|
||||
identifier="braintrust::factuality",
|
||||
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
|
||||
params=None,
|
||||
description=(
|
||||
"Test output factuality against expected value using Braintrust LLM scorer. "
|
||||
"See: github.com/braintrustdata/autoevals"
|
||||
),
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="factuality",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
)
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
# Definition of the Braintrust "faithfulness" scoring function.
# It returns a number; the default aggregation over rows is the average.
# The description states what the scorer compares the answer against: the
# autoevals Faithfulness scorer checks the generated answer for consistency
# with the supplied context, not with the input query.
faithfulness_fn_def = ScoringFn(
    identifier="braintrust::faithfulness",
    description=(
        "Test output faithfulness to the provided context using Braintrust LLM scorer. "
        "See: github.com/braintrustdata/autoevals"
    ),
    provider_id="braintrust",
    provider_resource_id="faithfulness",
    return_type=NumberType(),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average],
    ),
)
|
|
@ -16,7 +16,12 @@ from llama_stack.apis.scoring import (
|
|||
ScoringResult,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
get_valid_schemas,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
|
||||
from .config import LlmAsJudgeScoringConfig
|
||||
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
|
||||
|
@ -25,7 +30,10 @@ from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
|
|||
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
|
||||
|
||||
|
||||
class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||
class LlmAsJudgeScoringImpl(
|
||||
Scoring,
|
||||
ScoringFunctionsProtocolPrivate,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlmAsJudgeScoringConfig,
|
||||
|
@ -65,30 +73,17 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def register_scoring_function(self, function_def: ScoringFn) -> None:
    """Reject registration: this provider does not support it yet.

    Raises:
        NotImplementedError: always.
    """
    message = "Register scoring function not implemented yet"
    raise NotImplementedError(message)
|
||||
|
||||
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
    """Check that a dataset's schema has the string columns scoring needs.

    The dataset is looked up through ``self.datasets_api``. Its schema must be
    non-empty and contain ``generated_answer``, ``expected_answer`` and
    ``input_query``, each with ``type == "string"``.

    Raises:
        ValueError: if the schema is missing or empty, or if a required column
            is absent or not of type ``"string"``. Columns are checked in the
            order listed above and the first failure is reported.
    """
    dataset = await self.datasets_api.get_dataset(dataset_id=dataset_id)
    schema = dataset.dataset_schema

    if not schema or len(schema) == 0:
        raise ValueError(
            f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
        )

    for column in ("generated_answer", "expected_answer", "input_query"):
        if column not in schema:
            raise ValueError(
                f"Dataset {dataset_id} does not have a '{column}' column."
            )
        if schema[column].type != "string":
            raise ValueError(
                f"Dataset {dataset_id} does not have a '{column}' column of type 'string'."
            )
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(
|
||||
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
|
||||
)
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
|
|
|
@ -12,14 +12,14 @@ from llama_stack.apis.inference.inference import Inference
|
|||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
|
||||
|
||||
from .fn_defs.llm_as_judge_base import llm_as_judge_base
|
||||
|
||||
|
||||
class LlmAsJudgeScoringFn(BaseScoringFn):
|
||||
class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that assigns
|
||||
"""
|
||||
|
|
|
@ -17,6 +17,22 @@ from opentelemetry.sdk.trace import TracerProvider
|
|||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
Event,
|
||||
MetricEvent,
|
||||
QueryCondition,
|
||||
SpanEndPayload,
|
||||
SpanStartPayload,
|
||||
SpanStatus,
|
||||
SpanWithStatus,
|
||||
StructuredLogEvent,
|
||||
Telemetry,
|
||||
Trace,
|
||||
UnstructuredLogEvent,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
|
||||
ConsoleSpanProcessor,
|
||||
)
|
||||
|
@ -27,10 +43,6 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
|
|||
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
|
||||
_GLOBAL_STORAGE = {
|
||||
|
@ -100,8 +112,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
trace.get_tracer_provider().force_flush()
|
||||
trace.get_tracer_provider().shutdown()
|
||||
metrics.get_meter_provider().shutdown()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
|
|
|
@ -4,12 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
|
||||
class SampleTelemetryImpl(Telemetry):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .brave_search import BraveSearchToolRuntimeImpl
|
||||
from .config import BraveSearchToolConfig
|
||||
|
||||
|
||||
# Pydantic model describing provider data for the Brave Search tool.
# NOTE(review): presumably used to validate the per-request provider-data
# header payload — confirm against where this validator is registered.
class BraveSearchToolProviderDataValidator(BaseModel):
    # Brave Search API key; required (no default).
    api_key: str
|
||||
|
||||
|
||||
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
    """Create the Brave Search tool runtime, initialize it, and return it.

    Args:
        config: Provider configuration handed to the runtime constructor.
        _deps: Accepted but not used by this function.

    Returns:
        The initialized ``BraveSearchToolRuntimeImpl`` instance.
    """
    runtime = BraveSearchToolRuntimeImpl(config)
    await runtime.initialize()
    return runtime
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .config import BraveSearchToolConfig
|
||||
|
||||
|
||||
class BraveSearchToolRuntimeImpl(
    ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
    """Tool runtime that serves the ``brave_search`` tool via the Brave web search API.

    The API key is taken from the provider config when set; otherwise it is
    read from the per-request provider data.
    """

    def __init__(self, config: BraveSearchToolConfig):
        self.config = config

    async def initialize(self):
        """No start-up work is needed."""

    async def register_tool(self, tool: Tool):
        """Accept only the ``brave_search`` tool.

        Raises:
            ValueError: if ``tool.identifier`` is anything other than
                ``"brave_search"``.
        """
        if tool.identifier != "brave_search":
            raise ValueError(f"Tool identifier {tool.identifier} is not supported")

    async def unregister_tool(self, tool_id: str) -> None:
        """Do nothing; no per-tool state is kept."""
        return

    def _get_api_key(self) -> str:
        """Return the API key from config, falling back to request provider data.

        Raises:
            ValueError: if neither the config nor the request provider data
                carries an API key.
        """
        if self.config.api_key:
            return self.config.api_key

        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.api_key:
            raise ValueError(
                'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
            )
        return provider_data.api_key

    async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
        """Tool-group discovery is not supported by this provider.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Brave search tool group not supported")

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        """Run a Brave web search for ``args["query"]`` and return the cleaned results.

        Args:
            tool_name: Not used; this runtime only serves one tool.
            args: Tool arguments; must contain the search text under ``"query"``.

        Returns:
            A ``ToolInvocationResult`` whose content is the cleaned results,
            one per line.

        Raises:
            KeyError: if ``args`` has no ``"query"`` entry.
            ValueError: if no API key is available.
            requests.RequestException: on HTTP errors or when the request
                times out.
        """
        api_key = self._get_api_key()
        url = "https://api.search.brave.com/res/v1/web/search"
        headers = {
            "X-Subscription-Token": api_key,
            "Accept-Encoding": "gzip",
            "Accept": "application/json",
        }
        payload = {"q": args["query"]}
        # Without a timeout, requests waits forever on a stalled connection.
        # NOTE(review): requests is synchronous, so this call still blocks the
        # event loop for the duration of the HTTP round trip.
        response = requests.get(url=url, params=payload, headers=headers, timeout=30)
        response.raise_for_status()
        results = self._clean_brave_response(response.json())
        content_items = "\n".join(str(result) for result in results)
        return ToolInvocationResult(
            content=content_items,
        )

    def _clean_brave_response(self, search_response) -> List[str]:
        """Turn a raw search response into a list of cleaned result strings.

        Walks the first ``config.max_results`` entries of
        ``search_response["mixed"]["main"]``, each of which names a result
        section (``type``) and optionally an ``index`` into it. Entries whose
        section is missing or empty, or whose type is unsupported, are skipped
        instead of raising or contributing empty lines.
        """
        clean_response = []
        mixed_results = search_response.get("mixed") or {}
        main_entries = mixed_results.get("main") or []
        for m in main_entries[: self.config.max_results]:
            r_type = m.get("type")
            section = search_response.get(r_type) or {}
            results = section.get("results")
            if not results:
                # The referenced section is absent from the response.
                continue
            cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
            if cleaned:
                clean_response.append(cleaned)

        return clean_response

    def _clean_result_by_type(self, r_type, results, idx=None) -> str:
        """Keep only the whitelisted fields of one result section.

        Args:
            r_type: Result section name (e.g. ``"web"``, ``"faq"``, ``"news"``).
            results: The ``results`` value of that section.
            idx: For ``"web"`` and ``"infobox"``, selects the single item of
                ``results`` to keep; other types keep the whole ``results``.

        Returns:
            ``str()`` of the filtered dict (or list of dicts), or ``""`` when
            the type is unsupported or ``idx`` does not select an item.
        """
        # type -> (fields to keep, whether idx selects a single item)
        type_cleaners = {
            "web": (
                ["type", "title", "url", "description", "date", "extra_snippets"],
                True,
            ),
            "faq": (["type", "question", "answer", "title", "url"], False),
            "infobox": (
                ["type", "title", "url", "description", "long_desc"],
                True,
            ),
            "videos": (["type", "url", "title", "description", "date"], False),
            "locations": (
                [
                    "type",
                    "title",
                    "url",
                    "description",
                    "coordinates",
                    "postal_address",
                    "contact",
                    "rating",
                    "distance",
                    "zoom_level",
                ],
                False,
            ),
            "news": (["type", "title", "url", "description"], False),
        }

        if r_type not in type_cleaners:
            return ""

        selected_keys, pick_single = type_cleaners[r_type]
        if pick_single:
            try:
                results = results[idx]
            except (IndexError, KeyError, TypeError):
                # Missing or out-of-range index: there is no item to report.
                return ""

        if isinstance(results, list):
            cleaned = [
                {k: v for k, v in item.items() if k in selected_keys}
                for item in results
            ]
        else:
            cleaned = {k: v for k, v in results.items() if k in selected_keys}

        return str(cleaned)
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Configuration model for the Brave Search tool.
class BraveSearchToolConfig(BaseModel):
    # Brave Search API key; optional, defaults to None when not configured.
    api_key: Optional[str] = Field(
        default=None,
        description="The Brave Search API Key",
    )
    # Maximum number of results to return; defaults to 3.
    max_results: int = Field(
        default=3,
        description="The maximum number of results to return",
    )
|
|
@ -6,7 +6,13 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_dependencies
|
||||
|
||||
|
||||
|
|
|
@ -6,7 +6,13 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
|
|
|
@ -6,8 +6,13 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
META_REFERENCE_DEPS = [
|
||||
"accelerate",
|
||||
|
@ -149,6 +154,16 @@ def available_providers() -> List[ProviderSpec]:
|
|||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="groq",
|
||||
pip_packages=["groq"],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
@ -6,8 +6,13 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
EMBEDDING_DEPS = [
|
||||
"blobfile",
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue