Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-10 04:08:31 +00:00)

Merge branch 'meta-llama:main' into main

Commit 3cf7b92063

189 changed files with 52233 additions and 28035 deletions
.cursor/rules/general.mdc (new file, 9 lines)

@@ -0,0 +1,9 @@
---
description: General rules always applicable across the project
globs:
alwaysApply: true
---
# Style

- Comments must add value to the code. Don't write filler comments explaining what you are doing next; they just add noise.
- Add a comment to clarify surprising behavior that would not otherwise be obvious. Good variable naming and clear code organization are more important.
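To make the comment rule above concrete, here is a small illustrative Python sketch (hypothetical code, not from the repository) contrasting a filler comment with a comment that clarifies surprising behavior:

```python
def should_retry(status_code: int, body: bytes, retries: int, max_retries: int = 3) -> bool:
    """Illustrative only: decide whether to retry a hypothetical provider call."""
    # Filler comment (adds noise): check whether we are under the retry limit.
    if retries >= max_retries:
        return False
    # Valuable comment (clarifies surprising behavior): this hypothetical provider
    # returns HTTP 200 with an empty body while a model is still loading, so an
    # empty 200 response is treated as retryable rather than as success.
    return status_code == 200 and not body


if __name__ == "__main__":
    print(should_retry(200, b"", retries=0))        # True: empty 200 body is retried
    print(should_retry(200, b"result", retries=0))  # False: a real response
```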
.github/dependabot.yml (new file, vendored, 8 lines)

@@ -0,0 +1,8 @@
# GitHub Dependabot configuration
version: 2
updates:
  # Enable version updates for GitHub Actions
  - package-ecosystem: "github-actions"
    directory: "/" # Will use the default workflow location of `.github/workflows`
    schedule:
      interval: "daily"
@@ -310,7 +310,7 @@ jobs:
      - name: "PR - Upload Test Summary"
        id: pr_test_summary_upload
        if: github.event_name == 'pull_request_target'
-       uses: actions/upload-artifact@v3
+       uses: actions/upload-artifact@v4
        with:
          name: test-summary
          path: test-summary.md

@@ -320,7 +320,7 @@ jobs:
      - name: "PR - Update comment"
        id: pr_update_comment
        if: github.event_name == 'pull_request_target'
-       uses: thollander/actions-comment-pull-request@v2
+       uses: thollander/actions-comment-pull-request@v3
        with:
          filePath: test-summary.md
.github/workflows/unit-tests.yml (new file, vendored, 36 lines)

@@ -0,0 +1,36 @@
name: Unit Tests

on:
  pull_request:
    branches: [ main ]
  workflow_dispatch:

jobs:
  unit-tests:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - uses: astral-sh/setup-uv@v5
        with:
          python-version: '3.10'
          enable-cache: false

      - name: Run unit tests
        run: |
          uv run -p 3.10 --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report.xml

      - name: Upload test results
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: test-results
          path: |
            .pytest_cache/
            pytest-report.xml
          retention-days: 7
.github/workflows/update-readthedocs.yml (vendored, 2 lines changed)

@@ -12,12 +12,14 @@ on:
      - main
    paths:
      - 'docs/**'
      - 'pyproject.toml'
      - '.github/workflows/update-readthedocs.yml'
  pull_request:
    branches:
      - main
    paths:
      - 'docs/**'
      - 'pyproject.toml'
      - '.github/workflows/update-readthedocs.yml'

jobs:
.gitignore (vendored, 2 lines changed)

@@ -20,3 +20,5 @@ _build
docs/src
pyrightconfig.json
venv/
pytest-report.xml
.coverage
.gitmodules (vendored, 0 lines changed)

@@ -15,10 +15,6 @@ repos:
  - id: end-of-file-fixer
    exclude: '^(.*\.svg)$'

  # Temporarily disabling this
  # - id: no-commit-to-branch
  #   args: ['--branch=main']

- repo: https://github.com/Lucas-C/pre-commit-hooks
  rev: v1.5.4
  hooks:

@@ -68,12 +64,6 @@ repos:
      - pydantic
    pass_filenames: false

# - repo: https://github.com/jsh9/pydoclint
#   rev: d88180a8632bb1602a4d81344085cf320f288c5a
#   hooks:
#     - id: pydoclint
#       args: [--config=pyproject.toml]

# - repo: https://github.com/tcort/markdown-link-check
#   rev: v3.11.2
#   hooks:
CHANGELOG.md (new file, 304 lines)

@@ -0,0 +1,304 @@
# Changelog

# v0.1.6
Published on: 2025-03-08T04:35:08Z

## 0.1.6 Release Notes

### Build and Test Agents
* Inference: Fixed support for the inline vLLM provider
* (**New**) Agent: Build & Monitor Agent Workflows with Llama Stack + Anthropic's Best Practice [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb)
* (**New**) Agent: Revamped agent [documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html) with more details and examples
* Agent: Unified tools and the Python SDK Agents API
* Agent: AsyncAgent Python SDK wrapper supporting async client tool calls
* Agent: Support Python functions without the @client_tool decorator as client tools
* Agent: Deprecated the allow_resume_turn flag and removed the need to specify tool_prompt_format
* VectorIO: MilvusDB support added

### Agent Evals and Model Customization
* (**New**) Agent: Llama Stack RAG Lifecycle [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb)
* Eval: Documentation for eval, scoring, and adding new benchmarks
* Eval: Distribution template to run benchmarks on llama and non-llama models
* Eval: Ability to register new custom LLM-as-judge scoring functions
* (**New**) Looking for contributors for open benchmarks. See the [documentation](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) for details.

### Deploy and Monitoring of Agents
* Better support for different log levels across all components for better monitoring

### Better Engineering
* Enhanced the OpenAPI spec to include Error types across all APIs
* Moved all tests to /tests and created unit tests to run on each PR
* Removed all dependencies on the llama-models repo

---

# v0.1.5.1
Published on: 2025-02-28T22:37:44Z

## 0.1.5.1 Release Notes
* Fixes for the security risk in https://github.com/meta-llama/llama-stack/pull/1327 and https://github.com/meta-llama/llama-stack/pull/1328

**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.1.5...v0.1.5.1

---

# v0.1.5
Published on: 2025-02-28T18:14:01Z

## 0.1.5 Release Notes
### Build Agents
* Inference: Support more non-llama models (OpenAI, Anthropic, Gemini)
* Inference: Can use the provider's model name in addition to the HF alias
* Inference: Fixed issues with calling tools that weren't specified in the prompt
* RAG: Improved system prompt for RAG and no more need for hard-coded rag-tool calling
* Embeddings: Added support for Nemo retriever embedding models
* Tools: Added support for MCP tools in the Ollama distribution
* Distributions: Added new Groq distribution

### Customize Models
* Save post-trained checkpoints in SafeTensor format to allow the Ollama inference provider to use the post-trained model

### Monitor agents
* More comprehensive logging of agent steps, including client tools
* Telemetry inputs/outputs are now structured and queryable
* Ability to retrieve an agent's session, turn, and step by their IDs

### Better Engineering
* Moved executorch Swift code out of this repo into the llama-stack-client-swift repo, similar to Kotlin
* Moved most logging to use a logger instead of prints
* Completed text /chat-completion and /completion tests

---

# v0.1.4
Published on: 2025-02-25T00:02:43Z

## v0.1.4 Release Notes
Here are the key changes coming as part of this release:

### Build and Test Agents
* Inference: Added support for non-llama models
* Inference: Added option to list all downloaded models and remove models
* Agent: Introduced a new API, agents.resume_turn, to include client-side tool execution in the same turn
* Agent: AgentConfig introduces a new variable "tool_config" that allows for better tool configuration and system prompt overrides
* Agent: Added logging for agent step start and completion times
* Agent: Added support for logging tool execution metadata
* Embedding: Updated /inference/embeddings to support asymmetric models, truncation and variable-sized outputs
* Embedding: Updated embedding models for Ollama, Together, and Fireworks with available defaults
* VectorIO: Improved performance of sqlite-vec using chunked writes
### Agent Evals and Model Customization
* Deprecated the /eval-tasks API. Use /eval/benchmark instead
* Added CPU training support for TorchTune
### Deploy and Monitoring of Agents
* Consistent view of client and server tool calls in telemetry
### Better Engineering
* Made tests more data-driven for consistent evaluation
* Fixed documentation links and improved API reference generation
* Various small fixes for build scripts and system reliability

---

# v0.1.3
Published on: 2025-02-14T20:24:32Z

## v0.1.3 Release

Here are some key changes that are coming as part of this release.

### Build and Test Agents
Streamlined the initial development experience
- Added support for llama stack run --image-type venv
- Enhanced vector store options with new sqlite-vec provider and improved Qdrant integration
- vLLM improvements for tool calling and logprobs
- Better handling of sporadic code_interpreter tool calls

### Agent Evals
Better benchmarking and agent performance assessment
- Renamed eval API /eval-task to /benchmarks
- Improved documentation and notebooks for RAG and evals

### Deploy and Monitoring of Agents
Improved production readiness
- Added usage metrics collection for chat completions
- CLI improvements for provider information
- Improved error handling and system reliability
- Better model endpoint handling and accessibility
- Improved signal handling on the distro server

### Better Engineering
Infrastructure and code quality improvements
- Faster text-based chat completion tests
- Improved testing for non-streaming agent APIs
- Standardized import formatting with the ruff linter
- Added conventional commits standard
- Fixed documentation parsing issues

---

# v0.1.2
Published on: 2025-02-07T22:06:49Z

# TL;DR
- Several stabilizations to development flows after the switch to `uv`
- Migrated CI workflows to the new OSS repo - [llama-stack-ops](https://github.com/meta-llama/llama-stack-ops)
- Added automated rebuilds for ReadTheDocs
- Llama Stack server supports HTTPS
- Added system prompt overrides support
- Several bug fixes and improvements to documentation (check out the Kubernetes deployment guide by @terrytangyuan)

---

# v0.1.1
Published on: 2025-02-02T02:29:24Z

A bunch of small and big improvements everywhere, including support for Windows, the switch to `uv`, and many provider improvements.

---

# v0.1.0
Published on: 2025-01-24T17:47:47Z

We are excited to announce a stable API release of Llama Stack, which enables developers to build RAG applications and agents using tools and safety shields, monitor those agents with telemetry, and evaluate them with scoring functions.

## Context
GenAI application developers need more than just an LLM - they need to integrate tools, connect with their data sources, establish guardrails, and ground the LLM responses effectively. Currently, developers must piece together various tools and APIs, complicating the development lifecycle and increasing costs. The result is that developers are spending more time on these integrations rather than focusing on the application logic itself. The bespoke coupling of components also makes it challenging to adopt state-of-the-art solutions in the rapidly evolving GenAI space. This is particularly difficult for open models like Llama, as best practices are not widely established in the open.

Llama Stack was created to provide developers with a comprehensive and coherent interface that simplifies AI application development and codifies best practices across the Llama ecosystem. Since our launch in September 2024, we have seen a huge uptick in interest in Llama Stack APIs by both AI developers and from partners building AI services with Llama models. Partners like Nvidia, Fireworks, and Ollama have collaborated with us to develop implementations across various APIs, including inference, memory, and safety.

With Llama Stack, you can easily build a RAG agent which can also search the web, do complex math, and call custom tools. You can use telemetry to inspect those traces, and convert telemetry into evals datasets. And with Llama Stack's plugin architecture and prepackaged distributions, you can choose to run your agent anywhere - in the cloud with our partners, deploy your own environment using virtualenv, conda, or Docker, operate locally with Ollama, or even run on mobile devices with our SDKs. Llama Stack offers unprecedented flexibility while also simplifying the developer experience.

## Release
After iterating on the APIs for the last 3 months, today we're launching a stable release (V1) of the Llama Stack APIs and the corresponding llama-stack server and client packages (v0.1.0). We now have automated tests for providers. These tests make sure that all provider implementations are verified. Developers can now easily and reliably select distributions or providers based on their specific requirements.

There are example standalone apps in llama-stack-apps.

## Key Features of this release

- **Unified API Layer**
  - Inference: Run LLM models
  - RAG: Store and retrieve knowledge for RAG
  - Agents: Build multi-step agentic workflows
  - Tools: Register tools that can be called by the agent
  - Safety: Apply content filtering and safety policies
  - Evaluation: Test model and agent quality
  - Telemetry: Collect and analyze usage data and complex agentic traces
  - Post Training (Coming Soon): Fine-tune models for specific use cases

- **Rich Provider Ecosystem**
  - Local Development: Meta's Reference, Ollama
  - Cloud: Fireworks, Together, Nvidia, AWS Bedrock, Groq, Cerebras
  - On-premises: Nvidia NIM, vLLM, TGI, Dell-TGI
  - On-device: iOS and Android support

- **Built for Production**
  - Pre-packaged distributions for common deployment scenarios
  - Backwards compatibility across model versions
  - Comprehensive evaluation capabilities
  - Full observability and monitoring

- **Multiple developer interfaces**
  - CLI: Command line interface
  - Python SDK
  - Swift iOS SDK
  - Kotlin Android SDK

- **Sample llama stack applications**
  - Python
  - iOS
  - Android

---

# v0.1.0rc12
Published on: 2025-01-22T22:24:01Z

---

# v0.0.63
Published on: 2024-12-18T07:17:43Z

A small but important bug-fix release to update the URL datatype for the client SDKs. The issue especially affected multimodal agentic turns.

**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.0.62...v0.0.63

---

# v0.0.62
Published on: 2024-12-18T02:39:43Z

---

# v0.0.61
Published on: 2024-12-10T20:50:33Z

---

# v0.0.55
Published on: 2024-11-23T17:14:07Z

---

# v0.0.54
Published on: 2024-11-22T00:36:09Z

---

# v0.0.53
Published on: 2024-11-20T22:18:00Z

🚀 Initial Release Notes for Llama Stack!

### Added
- Resource-oriented design for models, shields, memory banks, datasets and eval tasks
- Persistence for registered objects with distribution
- Ability to persist memory banks created for FAISS
- PostgreSQL KVStore implementation
- Environment variable placeholder support in run.yaml files
- Comprehensive Zero-to-Hero notebooks and quickstart guides
- Support for quantized models in Ollama
- Vision model support for Together, Fireworks, Meta-Reference, Ollama, and vLLM
- Bedrock distribution with safety shields support
- Evals API with task registration and scoring functions
- MMLU and SimpleQA benchmark scoring functions
- Huggingface dataset provider integration for benchmarks
- Support for custom dataset registration from local paths
- Benchmark evaluation CLI tools with visualization tables
- RAG evaluation scoring functions and metrics
- Local persistence for datasets and eval tasks

### Changed
- Split safety into distinct providers (llama-guard, prompt-guard, code-scanner)
- Changed provider naming convention (`impls` → `inline`, `adapters` → `remote`)
- Updated API signatures for dataset and eval task registration
- Restructured folder organization for providers
- Enhanced Docker build configuration
- Added version prefixing for REST API routes
- Enhanced evaluation task registration workflow
- Improved benchmark evaluation output formatting
- Restructured evals folder organization for better modularity

### Removed
- `llama stack configure` command

---
@@ -64,10 +64,10 @@ You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting
You can install the dependencies by running:

```bash
-$ cd llama-stack
-$ uv sync --extra dev
-$ uv pip install -e .
-$ source .venv/bin/activate
+cd llama-stack
+uv sync --extra dev
+uv pip install -e .
+source .venv/bin/activate
```

Note that you can create a dotenv file `.env` that includes necessary environment variables:

@@ -80,7 +80,7 @@ LLAMA_STACK_CONFIG=

And then use this dotenv file when running client SDK tests via the following:
```bash
-$ uv run --env-file .env -- pytest -v tests/api/inference/test_text_inference.py
+uv run --env-file .env -- pytest -v tests/api/inference/test_text_inference.py
```

## Pre-commit Hooks

@@ -88,7 +88,7 @@ $ uv run --env-file .env -- pytest -v tests/api/inference/test_text_inference.py
We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running:

```bash
-$ uv run pre-commit install
+uv run pre-commit install
```

After that, pre-commit hooks will run automatically before each commit.

@@ -96,7 +96,7 @@ After that, pre-commit hooks will run automatically before each commit.
Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running:

```bash
-$ uv run pre-commit run --all-files
+uv run pre-commit run --all-files
```

> [!CAUTION]

@@ -107,8 +107,8 @@ $ uv run pre-commit run --all-files
To add a new dependency to the project, you can use the `uv` command. For example, to add `foo` to the project, you can run:

```bash
-$ uv add foo
-$ uv sync
+uv add foo
+uv sync
```

## Coding Style

@@ -127,11 +127,11 @@ Building a stack image (conda / docker) will use the production version of the `
Example:
```bash
-$ cd work/
-$ git clone https://github.com/meta-llama/llama-stack.git
-$ git clone https://github.com/meta-llama/llama-stack-client-python.git
-$ cd llama-stack
-$ LLAMA_STACK_DIR=$(pwd) LLAMA_STACK_CLIENT_DIR=../llama-stack-client-python llama stack build --template <...>
+cd work/
+git clone https://github.com/meta-llama/llama-stack.git
+git clone https://github.com/meta-llama/llama-stack-client-python.git
+cd llama-stack
+LLAMA_STACK_DIR=$(pwd) LLAMA_STACK_CLIENT_DIR=../llama-stack-client-python llama stack build --template <...>
```

@@ -144,14 +144,14 @@ If you have made changes to a provider's configuration in any form (introducing
If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.

```bash
-$ cd llama-stack/docs
-$ uv sync --extra docs
+cd llama-stack/docs
+uv sync --extra docs

# This rebuilds the documentation pages.
-$ uv run make html
+uv run make html

# This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
-$ uv run sphinx-autobuild source build/html --write-all
+uv run sphinx-autobuild source build/html --write-all
```

### Update API Documentation

@@ -159,8 +159,7 @@ $ uv run sphinx-autobuild source build/html --write-all
If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:

```bash
-$ uv sync --extra dev
-$ uv run ./docs/openapi_generator/run_openapi_generator.sh
+uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
```

The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
@@ -1,6 +1,8 @@
include pyproject.toml
include distributions/dependencies.json
include llama_stack/models/llama/llama3/tokenizer.model
include llama_stack/distribution/*.sh
include llama_stack/cli/scripts/*.sh
include llama_stack/templates/*/*.yaml
include llama_stack/providers/tests/test_cases/inference/*.json
include llama_stack/models/llama/*/*.md
@@ -427,6 +427,7 @@
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",

@@ -448,7 +449,6 @@
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"uvicorn"
docs/_static/llama-stack-spec.html (vendored, 480 lines changed)
File diff suppressed because it is too large.

docs/_static/llama-stack-spec.yaml (vendored, 326 lines changed)
@@ -31,25 +31,32 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- DatasetIO
description: ''
description: >-
Get a paginated list of rows from a dataset.
parameters:
- name: dataset_id
in: query
description: >-
The ID of the dataset to get the rows from.
required: true
schema:
type: string
- name: rows_in_page
in: query
description: The number of rows to get per page.
required: true
schema:
type: integer
- name: page_token
in: query
description: The token to get the next page of rows.
required: false
schema:
type: string
- name: filter_condition
in: query
description: >-
(Optional) A condition to filter the rows by.
required: false
schema:
type: string

@@ -231,10 +238,33 @@ paths:
$ref: '#/components/schemas/CompletionRequest'
required: true
/v1/agents:
get:
responses:
'200':
description: A ListAgentsResponse.
content:
application/json:
schema:
$ref: '#/components/schemas/ListAgentsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: List all agents.
parameters: []
post:
responses:
'200':
description: OK
description: >-
An AgentCreateResponse with the agent ID.
content:
application/json:
schema:

@@ -251,7 +281,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: ''
description: >-
Create an agent with the given configuration.
parameters: []
requestBody:
content:

@@ -263,7 +294,7 @@ paths:
post:
responses:
'200':
description: OK
description: An AgentSessionCreateResponse.
content:
application/json:
schema:

@@ -280,10 +311,12 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: ''
description: Create a new session for an agent.
parameters:
- name: agent_id
in: path
description: >-
The ID of the agent to create the session for.
required: true
schema:
type: string

@@ -298,8 +331,8 @@ paths:
responses:
'200':
description: >-
A single turn in an interaction with an Agentic System. **OR** streamed
agent turn completion response.
If stream=False, returns a Turn object. If stream=True, returns an SSE
event stream of AgentTurnResponseStreamChunk
content:
application/json:
schema:

@@ -319,15 +352,19 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: ''
description: Create a new turn for an agent.
parameters:
- name: agent_id
in: path
description: >-
The ID of the agent to create the turn for.
required: true
schema:
type: string
- name: session_id
in: path
description: >-
The ID of the session to create the turn for.
required: true
schema:
type: string

@@ -395,6 +432,34 @@ paths:
$ref: '#/components/schemas/CreateUploadSessionRequest'
required: true
/v1/agents/{agent_id}:
get:
responses:
'200':
description: An Agent of the agent.
content:
application/json:
schema:
$ref: '#/components/schemas/Agent'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: Describe an agent by its ID.
parameters:
- name: agent_id
in: path
description: ID of the agent.
required: true
schema:
type: string
delete:
responses:
'200':

@@ -411,10 +476,11 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: ''
description: Delete an agent by its ID.
parameters:
- name: agent_id
in: path
description: The ID of the agent to delete.
required: true
schema:
type: string

@@ -439,20 +505,25 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: ''
description: Retrieve an agent session by its ID.
parameters:
- name: session_id
in: path
description: The ID of the session to get.
required: true
schema:
type: string
- name: agent_id
in: path
description: >-
The ID of the agent to get the session for.
required: true
schema:
type: string
- name: turn_ids
in: query
description: >-
(Optional) List of turn IDs to filter the session by.
required: false
schema:
type: array

@@ -474,15 +545,18 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: ''
description: Delete an agent session by its ID.
parameters:
- name: session_id
in: path
description: The ID of the session to delete.
required: true
schema:
type: string
- name: agent_id
in: path
description: >-
The ID of the agent to delete the session for.
required: true
schema:
type: string

@@ -596,7 +670,8 @@ paths:
post:
responses:
'200':
description: OK
description: >-
EvaluateResponse object containing generations and scores
content:
application/json:
schema:

@@ -613,10 +688,12 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Eval
description: ''
description: Evaluate a list of rows on a benchmark.
parameters:
- name: benchmark_id
in: path
description: >-
The ID of the benchmark to run the evaluation on.
required: true
schema:
type: string

@@ -630,7 +707,7 @@ paths:
get:
responses:
'200':
description: OK
description: An AgentStepResponse.
content:
application/json:
schema:

@@ -647,25 +724,30 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: ''
description: Retrieve an agent step by its ID.
parameters:
- name: agent_id
in: path
description: The ID of the agent to get the step for.
required: true
schema:
type: string
- name: session_id
in: path
description: >-
The ID of the session to get the step for.
required: true
schema:
type: string
- name: turn_id
in: path
description: The ID of the turn to get the step for.
required: true
schema:
type: string
- name: step_id
in: path
description: The ID of the step to get.
required: true
schema:
type: string

@@ -673,7 +755,7 @@ paths:
get:
responses:
'200':
description: OK
description: A Turn.
content:
application/json:
schema:

@@ -690,20 +772,24 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: ''
description: Retrieve an agent turn by its ID.
parameters:
- name: agent_id
in: path
description: The ID of the agent to get the turn for.
required: true
schema:
type: string
- name: session_id
in: path
description: >-
The ID of the session to get the turn for.
required: true
schema:
type: string
- name: turn_id
in: path
description: The ID of the turn to get.
required: true
schema:
type: string

@@ -1391,7 +1477,7 @@ paths:
get:
responses:
'200':
description: OK
description: The status of the evaluation job.
content:
application/json:
schema:

@@ -1410,15 +1496,18 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Eval
description: ''
description: Get the status of a job.
parameters:
- name: benchmark_id
in: path
description: >-
The ID of the benchmark to run the evaluation on.
required: true
schema:
type: string
- name: job_id
in: path
description: The ID of the job to get the status of.
required: true
schema:
type: string

@@ -1438,15 +1527,18 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Eval
description: ''
description: Cancel a job.
parameters:
- name: benchmark_id
in: path
description: >-
The ID of the benchmark to run the evaluation on.
required: true
schema:
type: string
- name: job_id
in: path
description: The ID of the job to cancel.
required: true
schema:
type: string

@@ -1454,7 +1546,7 @@ paths:
get:
responses:
'200':
description: OK
description: The result of the job.
content:
application/json:
schema:

@@ -1471,15 +1563,48 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Eval
description: ''
description: Get the result of a job.
parameters:
- name: benchmark_id
in: path
description: >-
The ID of the benchmark to run the evaluation on.
required: true
schema:
type: string
- name: job_id
in: path
description: The ID of the job to get the result of.
required: true
schema:
type: string
/v1/agents/{agent_id}/sessions:
get:
responses:
'200':
description: A ListAgentSessionsResponse.
content:
application/json:
schema:
$ref: '#/components/schemas/ListAgentSessionsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: List all session(s) of a given agent.
parameters:
- name: agent_id
in: path
description: >-
The ID of the agent to list sessions for.
required: true
schema:
type: string
@@ -2192,7 +2317,8 @@ paths:
post:
responses:
'200':
description: OK
description: >-
The job that was created to run the evaluation.
content:
application/json:
schema:

@@ -2209,10 +2335,12 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Eval
description: ''
description: Run an evaluation on a benchmark.
parameters:
- name: benchmark_id
in: path
description: >-
The ID of the benchmark to run the evaluation on.
required: true
schema:
type: string

@@ -2280,7 +2408,8 @@ paths:
post:
responses:
'200':
description: OK
description: >-
ScoreResponse object containing rows and aggregated results
content:
application/json:
schema:

@@ -2297,7 +2426,7 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Scoring
description: ''
description: Score a list of rows.
parameters: []
requestBody:
content:

@@ -3567,6 +3696,7 @@ components:
properties:
agent_config:
$ref: '#/components/schemas/AgentConfig'
description: The configuration for the agent.
additionalProperties: false
required:
- agent_config

@@ -3585,6 +3715,7 @@ components:
properties:
session_name:
type: string
description: The name of the session to create.
additionalProperties: false
required:
- session_name

@@ -3607,8 +3738,12 @@ components:
oneOf:
- $ref: '#/components/schemas/UserMessage'
- $ref: '#/components/schemas/ToolResponseMessage'
description: List of messages to start the turn with.
stream:
type: boolean
description: >-
(Optional) If True, generate an SSE event stream of the response. Defaults
to False.
documents:
type: array
items:

@@ -3622,19 +3757,30 @@ components:
items:
$ref: '#/components/schemas/InterleavedContentItem'
- $ref: '#/components/schemas/URL'
description: The content of the document.
mime_type:
type: string
description: The MIME type of the document.
additionalProperties: false
required:
- content
- mime_type
title: Document
description: A document to be used by an agent.
description: >-
(Optional) List of documents to create the turn with.
toolgroups:
type: array
items:
$ref: '#/components/schemas/AgentTool'
description: >-
(Optional) List of toolgroups to create the turn with, will be used in
addition to the agent's config toolgroups for the request.
tool_config:
$ref: '#/components/schemas/ToolConfig'
description: >-
(Optional) The tool configuration to create the turn with, will be used
to override the agent's tool_config.
additionalProperties: false
required:
- messages

@@ -3644,20 +3790,25 @@ components:
properties:
turn_id:
type: string
description: The ID of the turn.
step_id:
type: string
description: The ID of the step.
started_at:
type: string
format: date-time
description: The time the step started.
completed_at:
type: string
format: date-time
description: The time the step completed.
step_type:
type: string
const: inference
default: inference
model_response:
$ref: '#/components/schemas/CompletionMessage'
description: The response from the LLM.
additionalProperties: false
required:
- turn_id

@@ -3665,27 +3816,36 @@ components:
- step_type
- model_response
title: InferenceStep
description: An inference step in an agent turn.
MemoryRetrievalStep:
type: object
properties:
turn_id:
type: string
description: The ID of the turn.
step_id:
type: string
description: The ID of the step.
started_at:
type: string
format: date-time
description: The time the step started.
completed_at:
type: string
format: date-time
description: The time the step completed.
step_type:
type: string
const: memory_retrieval
default: memory_retrieval
vector_db_ids:
type: string
description: >-
The IDs of the vector databases to retrieve context from.
inserted_context:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The context retrieved from the vector databases.
additionalProperties: false
required:
- turn_id

@@ -3694,6 +3854,8 @@ components:
- vector_db_ids
- inserted_context
title: MemoryRetrievalStep
description: >-
A memory retrieval step in an agent turn.
SafetyViolation:
type: object
properties:

@@ -3721,39 +3883,49 @@ components:
properties:
turn_id:
type: string
description: The ID of the turn.
step_id:
type: string
description: The ID of the step.
started_at:
type: string
format: date-time
description: The time the step started.
completed_at:
type: string
format: date-time
description: The time the step completed.
step_type:
type: string
const: shield_call
default: shield_call
violation:
$ref: '#/components/schemas/SafetyViolation'
description: The violation from the shield call.
additionalProperties: false
required:
- turn_id
- step_id
- step_type
title: ShieldCallStep
description: A shield call step in an agent turn.
ToolExecutionStep:
type: object
properties:
turn_id:
type: string
description: The ID of the turn.
step_id:
type: string
description: The ID of the step.
started_at:
type: string
format: date-time
description: The time the step started.
completed_at:
type: string
format: date-time
description: The time the step completed.
step_type:
type: string
const: tool_execution

@@ -3762,10 +3934,12 @@ components:
type: array
items:
$ref: '#/components/schemas/ToolCall'
description: The tool calls to execute.
tool_responses:
type: array
items:
$ref: '#/components/schemas/ToolResponse'
description: The tool responses from the tool calls.
additionalProperties: false
required:
- turn_id

@@ -3774,6 +3948,7 @@ components:
- tool_calls
- tool_responses
title: ToolExecutionStep
description: A tool execution step in an agent turn.
ToolResponse:
type: object
properties:

@@ -3850,13 +4025,16 @@ components:
items:
$ref: '#/components/schemas/InterleavedContentItem'
- $ref: '#/components/schemas/URL'
description: The content of the attachment.
mime_type:
type: string
description: The MIME type of the attachment.
additionalProperties: false
required:
- content
- mime_type
title: Attachment
description: An attachment to an agent turn.
started_at:
type: string
format: date-time

@@ -3922,6 +4100,7 @@ components:
- shield_call
- memory_retrieval
title: StepType
description: Type of the step in an agent turn.
step_id:
type: string
step_details:

@@ -3959,6 +4138,7 @@ components:
- shield_call
- memory_retrieval
title: StepType
description: Type of the step in an agent turn.
step_id:
type: string
delta:

@@ -3985,6 +4165,7 @@ components:
- shield_call
- memory_retrieval
title: StepType
description: Type of the step in an agent turn.
step_id:
type: string
metadata:

@@ -4212,11 +4393,14 @@ components:
default: agent
config:
$ref: '#/components/schemas/AgentConfig'
description: >-
The configuration for the agent candidate.
additionalProperties: false
required:
- type
- config
title: AgentCandidate
description: An agent candidate for evaluation.
AggregationFunctionType:
type: string
enum:

@@ -4245,17 +4429,26 @@ components:
properties:
eval_candidate:
$ref: '#/components/schemas/EvalCandidate'
description: The candidate to evaluate.
scoring_params:
type: object
additionalProperties:
$ref: '#/components/schemas/ScoringFnParams'
description: >-
Map between scoring function id and parameters for each scoring function
you want to run
num_examples:
type: integer
description: >-
(Optional) The number of examples to evaluate. If not provided, all examples
in the dataset will be evaluated
additionalProperties: false
required:
- eval_candidate
- scoring_params
title: BenchmarkConfig
description: >-
A benchmark configuration for evaluation.
EvalCandidate:
oneOf:
- $ref: '#/components/schemas/ModelCandidate'

@@ -4298,16 +4491,22 @@ components:
default: model
model:
type: string
description: The model ID to evaluate.
sampling_params:
$ref: '#/components/schemas/SamplingParams'
description: The sampling parameters for the model.
system_message:
$ref: '#/components/schemas/SystemMessage'
description: >-
(Optional) The system message providing instructions or context to the
model.
additionalProperties: false
required:
- type
- model
- sampling_params
title: ModelCandidate
description: A model candidate for evaluation.
RegexParserScoringFnParams:
type: object
properties:

@@ -4353,12 +4552,16 @@ components:
- type: string
- type: array
- type: object
description: The rows to evaluate.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the evaluation.
benchmark_config:
$ref: '#/components/schemas/BenchmarkConfig'
description: The configuration for the benchmark.
additionalProperties: false
required:
- input_rows

@@ -4380,15 +4583,18 @@ components:
- type: string
- type: array
- type: object
description: The generations from the evaluation.
scores:
type: object
additionalProperties:
$ref: '#/components/schemas/ScoringResult'
description: The scores from the evaluation.
additionalProperties: false
required:
- generations
- scores
title: EvaluateResponse
description: The response from an evaluation.
ScoringResult:
type: object
properties:

@@ -4404,6 +4610,8 @@ components:
- type: string
- type: array
- type: object
description: >-
The scoring result for each row. Each row is a map of column name to value.
aggregated_results:
type: object
additionalProperties:

@@ -4414,11 +4622,29 @@ components:
- type: string
- type: array
- type: object
description: Map of metric name to aggregated value
additionalProperties: false
required:
- score_rows
- aggregated_results
title: ScoringResult
description: A scoring result for a single row.
Agent:
type: object
properties:
agent_id:
type: string
agent_config:
$ref: '#/components/schemas/AgentConfig'
created_at:
type: string
format: date-time
additionalProperties: false
required:
- agent_id
- agent_config
- created_at
title: Agent
Session:
type: object
properties:

@@ -4731,15 +4957,19 @@ components:
- type: string
- type: array
- type: object
description: The rows in the current page.
total_count:
type: integer
description: The total number of rows in the dataset.
next_page_token:
type: string
description: The token to get the next page of rows.
additionalProperties: false
required:
- rows
- total_count
title: PaginatedRowsResult
description: A paginated list of rows from a dataset.
ScoringFn:
type: object
properties:

@@ -5251,6 +5481,28 @@ components:
required:
- content
title: ToolInvocationResult
ListAgentSessionsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Session'
additionalProperties: false
required:
- data
title: ListAgentSessionsResponse
ListAgentsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Agent'
additionalProperties: false
required:
- data
title: ListAgentsResponse
BucketResponse:
type: object
properties:

@@ -6153,11 +6405,16 @@ components:
type: object
properties:
tool_responses:
type: array
items:
$ref: '#/components/schemas/ToolResponseMessage'
oneOf:
- type: array
items:
$ref: '#/components/schemas/ToolResponse'
- type: array
items:
$ref: '#/components/schemas/ToolResponseMessage'
description: >-
The tool call responses to resume the turn with.
The tool call responses to resume the turn with. NOTE: ToolResponseMessage
will be deprecated. Use ToolResponse.
stream:
type: boolean
description: Whether to stream the response.

@@ -6170,6 +6427,7 @@ components:
properties:
benchmark_config:
$ref: '#/components/schemas/BenchmarkConfig'
description: The configuration for the benchmark.
additionalProperties: false
required:
- benchmark_config

@@ -6251,12 +6509,15 @@ components:
- type: string
- type: array
- type: object
description: The rows to score.
scoring_functions:
type: object
additionalProperties:
oneOf:
- $ref: '#/components/schemas/ScoringFnParams'
- type: 'null'
description: >-
The scoring functions to use for the scoring.
additionalProperties: false
required:
- input_rows

@@ -6269,10 +6530,13 @@ components:
type: object
additionalProperties:
$ref: '#/components/schemas/ScoringResult'
description: >-
A map of scoring function name to ScoringResult.
additionalProperties: false
required:
- results
title: ScoreResponse
description: The response from scoring.
ScoreBatchRequest:
type: object
properties:

@@ -6543,6 +6807,8 @@ tags:
- name: DatasetIO
- name: Datasets
- name: Eval
x-displayName: >-
Llama Stack Evaluation API for running evaluations on model and agent candidates.
- name: Files (Coming Soon)
- name: Inference
description: >-
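The spec changes above introduce agent listing and retrieval endpoints (GET /v1/agents, GET /v1/agents/{agent_id}, GET /v1/agents/{agent_id}/sessions) together with the Agent, ListAgentsResponse, and ListAgentSessionsResponse schemas. Here is a minimal sketch of how these endpoints might be exercised against a running Llama Stack server; the base URL (localhost:8321) and the use of the httpx client are assumptions and not part of this change:

```python
import httpx

BASE_URL = "http://localhost:8321"  # assumed address of a locally running server

with httpx.Client(base_url=BASE_URL) as client:
    # GET /v1/agents -> ListAgentsResponse: {"data": [Agent, ...]}
    resp = client.get("/v1/agents")
    resp.raise_for_status()
    for agent in resp.json()["data"]:
        # Agent objects carry agent_id, agent_config, and created_at (per the schema above).
        print(agent["agent_id"], agent["created_at"])

        # GET /v1/agents/{agent_id}/sessions -> ListAgentSessionsResponse: {"data": [Session, ...]}
        sessions_resp = client.get(f"/v1/agents/{agent['agent_id']}/sessions")
        sessions_resp.raise_for_status()
        print(f"  {len(sessions_resp.json()['data'])} session(s)")
```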
@@ -141,7 +141,7 @@
},
{
"cell_type": "code",
-"execution_count": 1,
+"execution_count": 18,
"id": "E1UFuJC570Tk",
"metadata": {
"colab": {

@@ -326,54 +326,108 @@
" type: sqlite\n",
"models:\n",
"- metadata: {}\n",
" model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-3.1-8B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-3.1-70B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-3.1-405B-Instruct-FP8\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-3.2-3B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-3.2-3B-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-3.2-3B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-3.2-3B-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-3.2-11B-Vision-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-3.2-90B-Vision-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-3.3-70B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo\n",
"- metadata: {}\n",
" model_id: meta-llama/Meta-Llama-Guard-3-8B\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-Guard-3-8B\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-Guard-3-8B\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-Guard-3-8B\n",
"- metadata: {}\n",
" model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision-Turbo\n",
|
||||
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
|
||||
" model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
|
@ -473,6 +527,9 @@
|
|||
" - config: <span style=\"font-weight: bold\">{}</span>\n",
|
||||
" provider_id: model-context-protocol\n",
|
||||
" provider_type: remote::model-context-protocol\n",
|
||||
" - config: <span style=\"font-weight: bold\">{}</span>\n",
|
||||
" provider_id: wolfram-alpha\n",
|
||||
" provider_type: remote::wolfram-alpha\n",
|
||||
" vector_io:\n",
|
||||
" - config:\n",
|
||||
" kvstore:\n",
|
||||
|
@ -504,6 +561,10 @@
|
|||
" mcp_endpoint: null\n",
|
||||
" provider_id: code-interpreter\n",
|
||||
" toolgroup_id: builtin::code_interpreter\n",
|
||||
"- args: null\n",
|
||||
" mcp_endpoint: null\n",
|
||||
" provider_id: wolfram-alpha\n",
|
||||
" toolgroup_id: builtin::wolfram_alpha\n",
|
||||
"vector_dbs: <span style=\"font-weight: bold\">[]</span>\n",
|
||||
"version: <span style=\"color: #008000; text-decoration-color: #008000\">'2'</span>\n",
|
||||
"\n",
|
||||
|
@ -530,54 +591,108 @@
|
|||
" type: sqlite\n",
|
||||
"models:\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-FP8\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct-Turbo\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision-Turbo\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
" provider_id: together\n",
|
||||
" provider_model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision-Turbo\n",
|
||||
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision\n",
|
||||
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
|
||||
" - llm\n",
|
||||
|
@ -677,6 +792,9 @@
|
|||
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" provider_id: model-context-protocol\n",
|
||||
" provider_type: remote::model-context-protocol\n",
|
||||
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" provider_id: wolfram-alpha\n",
|
||||
" provider_type: remote::wolfram-alpha\n",
|
||||
" vector_io:\n",
|
||||
" - config:\n",
|
||||
" kvstore:\n",
|
||||
|
@ -708,6 +826,10 @@
|
|||
" mcp_endpoint: null\n",
|
||||
" provider_id: code-interpreter\n",
|
||||
" toolgroup_id: builtin::code_interpreter\n",
|
||||
"- args: null\n",
|
||||
" mcp_endpoint: null\n",
|
||||
" provider_id: wolfram-alpha\n",
|
||||
" toolgroup_id: builtin::wolfram_alpha\n",
|
||||
"vector_dbs: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||
"version: \u001b[32m'2'\u001b[0m\n",
|
||||
"\n"
|
||||
|
@ -1145,7 +1267,6 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"# NBVAL_SKIP\n",
|
||||
"from pydantic import BaseModel\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
@ -1157,7 +1278,7 @@
|
|||
"\n",
|
||||
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
|
||||
"response = client.inference.completion(\n",
|
||||
" model_id=model_id,\n",
|
||||
" model_id=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
|
||||
" content=user_input,\n",
|
||||
" stream=False,\n",
|
||||
" sampling_params={\n",
|
||||
|
@ -1513,18 +1634,14 @@
|
|||
"source": [
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"agent_config = AgentConfig(\n",
|
||||
"agent = Agent(\n",
|
||||
" client, \n",
|
||||
" model=model_id,\n",
|
||||
" instructions=\"You are a helpful assistant\",\n",
|
||||
" toolgroups=[\"builtin::websearch\"],\n",
|
||||
" input_shields=[],\n",
|
||||
" output_shields=[],\n",
|
||||
" enable_session_persistence=False,\n",
|
||||
" instructions=\"You are a helpful assistant. Use websearch tool to help answer questions.\",\n",
|
||||
" tools=[\"builtin::websearch\"],\n",
|
||||
")\n",
|
||||
"agent = Agent(client, agent_config)\n",
|
||||
"user_prompts = [\n",
|
||||
" \"Hello\",\n",
|
||||
" \"Which teams played in the NBA western conference finals of 2024\",\n",
|
||||
|
@ -1693,7 +1810,6 @@
|
|||
"import uuid\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from termcolor import cprint\n",
|
||||
"from llama_stack_client.types import Document\n",
|
||||
"\n",
|
||||
|
@ -1719,11 +1835,11 @@
|
|||
" vector_db_id=vector_db_id,\n",
|
||||
" chunk_size_in_tokens=512,\n",
|
||||
")\n",
|
||||
"agent_config = AgentConfig(\n",
|
||||
"rag_agent = Agent(\n",
|
||||
" client, \n",
|
||||
" model=model_id,\n",
|
||||
" instructions=\"You are a helpful assistant\",\n",
|
||||
" enable_session_persistence=False,\n",
|
||||
" toolgroups = [\n",
|
||||
" tools = [\n",
|
||||
" {\n",
|
||||
" \"name\": \"builtin::rag/knowledge_search\",\n",
|
||||
" \"args\" : {\n",
|
||||
|
@ -1732,7 +1848,6 @@
|
|||
" }\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
"rag_agent = Agent(client, agent_config)\n",
|
||||
"session_id = rag_agent.create_session(\"test-session\")\n",
|
||||
"user_prompts = [\n",
|
||||
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\",\n",
|
||||
|
@ -1856,23 +1971,19 @@
|
|||
"source": [
|
||||
"from llama_stack_client.types.agents.turn_create_params import Document\n",
|
||||
"\n",
|
||||
"agent_config = AgentConfig(\n",
|
||||
"codex_agent = Agent(\n",
|
||||
" client, \n",
|
||||
" model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
|
||||
" instructions=\"You are a helpful assistant\",\n",
|
||||
" tools=[\n",
|
||||
" \"builtin::code_interpreter\",\n",
|
||||
" \"builtin::websearch\"\n",
|
||||
" ],\n",
|
||||
" sampling_params = {\n",
|
||||
" \"max_tokens\" : 4096,\n",
|
||||
" \"temperature\": 0.0\n",
|
||||
" },\n",
|
||||
" model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
|
||||
" instructions=\"You are a helpful assistant\",\n",
|
||||
" toolgroups=[\n",
|
||||
" \"builtin::code_interpreter\",\n",
|
||||
" \"builtin::websearch\"\n",
|
||||
" ],\n",
|
||||
" tool_choice=\"auto\",\n",
|
||||
" input_shields=[],\n",
|
||||
" output_shields=[],\n",
|
||||
" enable_session_persistence=False,\n",
|
||||
")\n",
|
||||
"codex_agent = Agent(client, agent_config)\n",
|
||||
"session_id = codex_agent.create_session(\"test-session\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
@ -2782,18 +2893,14 @@
|
|||
"# NBVAL_SKIP\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"agent_config = AgentConfig(\n",
|
||||
"agent = Agent(\n",
|
||||
" client, \n",
|
||||
" model=model_id,\n",
|
||||
" instructions=\"You are a helpful assistant\",\n",
|
||||
" toolgroups=[\"mcp::filesystem\"],\n",
|
||||
" input_shields=[],\n",
|
||||
" output_shields=[],\n",
|
||||
" enable_session_persistence=False,\n",
|
||||
" tools=[\"mcp::filesystem\"],\n",
|
||||
")\n",
|
||||
"agent = Agent(client, agent_config)\n",
|
||||
"user_prompts = [\n",
|
||||
" \"Hello\",\n",
|
||||
" \"list all the files /content\",\n",
|
||||
|
@ -2888,17 +2995,13 @@
|
|||
"source": [
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"\n",
|
||||
"agent_config = AgentConfig(\n",
|
||||
"agent = Agent(\n",
|
||||
" client, \n",
|
||||
" model=\"meta-llama/Llama-3.3-70B-Instruct\",\n",
|
||||
" instructions=\"You are a helpful assistant. Use search tool to answer the questions. \",\n",
|
||||
" toolgroups=[\"builtin::websearch\"],\n",
|
||||
" input_shields=[],\n",
|
||||
" output_shields=[],\n",
|
||||
" enable_session_persistence=False,\n",
|
||||
" tools=[\"builtin::websearch\"],\n",
|
||||
")\n",
|
||||
"agent = Agent(client, agent_config)\n",
|
||||
"user_prompts = [\n",
|
||||
" \"Which teams played in the NBA western conference finals of 2024. Search the web for the answer.\",\n",
|
||||
" \"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.\",\n",
|
||||
|
@ -4098,7 +4201,7 @@
|
|||
"source": [
|
||||
"## 4. Image Understanding with Llama 3.2\n",
|
||||
"\n",
|
||||
"Below is a complete example of using Together's Llama Stack 0.1 server at https://llama-stack.together.ai to ask Llama 3.2 questions about an image."
|
||||
"Below is a complete example of to ask Llama 3.2 questions about an image."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -4106,14 +4209,12 @@
|
|||
"id": "82e381ec",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 4.1 Setup and helpers\n",
|
||||
"\n",
|
||||
"Below we install the Llama Stack client 0.1, download the example image, define two image helpers, and set Llama Stack Together server URL and Llama 3.2 model name.\n"
|
||||
"### 4.1 Setup and helpers\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 1,
|
||||
"id": "44e05e16",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
|
@ -4123,7 +4224,7 @@
|
|||
"text": [
|
||||
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
|
||||
" Dload Upload Total Spent Left Speed\n",
|
||||
"100 275k 100 275k 0 0 780k 0 --:--:-- --:--:-- --:--:-- 780k\n"
|
||||
"100 275k 100 275k 0 0 905k 0 --:--:-- --:--:-- --:--:-- 906k\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -4133,32 +4234,13 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "469750f7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# NBVAL_SKIP\n",
|
||||
"from PIL import Image\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"def display_image(path):\n",
|
||||
" img = Image.open(path)\n",
|
||||
" plt.imshow(img)\n",
|
||||
" plt.axis('off')\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"display_image(\"Llama_Repo.jpeg\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 20,
|
||||
"id": "a2c1e1c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import base64\n",
|
||||
"vision_model_id = \"meta-llama/Llama-3.2-11B-Vision-Instruct\"\n",
|
||||
"\n",
|
||||
"def encode_image(image_path):\n",
|
||||
" with open(image_path, \"rb\") as image_file:\n",
|
||||
|
@ -4167,19 +4249,6 @@
|
|||
" return base64_url"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c565f99e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from llama_stack_client import LlamaStackClient\n",
|
||||
"\n",
|
||||
"LLAMA_STACK_API_TOGETHER_URL=\"https://llama-stack.together.ai\"\n",
|
||||
"LLAMA32_11B_INSTRUCT = \"meta-llama/Llama-3.2-11B-Vision-Instruct\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7737cd41",
|
||||
|
@ -4192,55 +4261,44 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 21,
|
||||
"id": "d7914894",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"There are three llamas in the image. The llama in the middle is purple, the llama on the left is white, and the llama on the right is also white, but it is wearing a blue party hat. Therefore, there are two different colors of llama in the image: purple and white.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
|
||||
"\n",
|
||||
"async def run_main(image_path: str, prompt):\n",
|
||||
" client = LlamaStackClient(\n",
|
||||
" base_url=LLAMA_STACK_API_TOGETHER_URL,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" message = {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"image\",\n",
|
||||
" \"image\": {\n",
|
||||
" \"url\": {\n",
|
||||
" \"uri\": encode_image(image_path)\n",
|
||||
" }\n",
|
||||
"response = client.inference.chat_completion(\n",
|
||||
" messages=[\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"image\",\n",
|
||||
" \"image\": {\n",
|
||||
" \"url\": {\n",
|
||||
" \"uri\": encode_image(\"Llama_Repo.jpeg\")\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"type\": \"text\",\n",
|
||||
" \"text\": \"How many different colors are those llamas? What are those colors?\",\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"type\": \"text\",\n",
|
||||
" \"text\": prompt,\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" model_id=vision_model_id,\n",
|
||||
" stream=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
" response = client.inference.chat_completion(\n",
|
||||
" messages=[message],\n",
|
||||
" model_id=LLAMA32_11B_INSTRUCT,\n",
|
||||
" stream=False,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" print(response.completion_message.content.lower().strip())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4ee09b97",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await run_main(\"Llama_Repo.jpeg\",\n",
|
||||
" \"How many different colors are those llamas?\\\n",
|
||||
" What are those colors?\")"
|
||||
"print(response.completion_message.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -4255,68 +4313,60 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 19,
|
||||
"id": "f9a83275",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33minference> \u001b[0m\u001b[33mThere\u001b[0m\u001b[33m are\u001b[0m\u001b[33m three\u001b[0m\u001b[33m different\u001b[0m\u001b[33m colors\u001b[0m\u001b[33m of\u001b[0m\u001b[33m ll\u001b[0m\u001b[33mamas\u001b[0m\u001b[33m in\u001b[0m\u001b[33m the\u001b[0m\u001b[33m image\u001b[0m\u001b[33m.\u001b[0m\u001b[33m The\u001b[0m\u001b[33m first\u001b[0m\u001b[33m llama\u001b[0m\u001b[33m on\u001b[0m\u001b[33m the\u001b[0m\u001b[33m left\u001b[0m\u001b[33m is\u001b[0m\u001b[33m white\u001b[0m\u001b[33m,\u001b[0m\u001b[33m the\u001b[0m\u001b[33m second\u001b[0m\u001b[33m llama\u001b[0m\u001b[33m in\u001b[0m\u001b[33m the\u001b[0m\u001b[33m middle\u001b[0m\u001b[33m is\u001b[0m\u001b[33m purple\u001b[0m\u001b[33m,\u001b[0m\u001b[33m and\u001b[0m\u001b[33m the\u001b[0m\u001b[33m third\u001b[0m\u001b[33m llama\u001b[0m\u001b[33m on\u001b[0m\u001b[33m the\u001b[0m\u001b[33m right\u001b[0m\u001b[33m is\u001b[0m\u001b[33m white\u001b[0m\u001b[33m with\u001b[0m\u001b[33m a\u001b[0m\u001b[33m blue\u001b[0m\u001b[33m party\u001b[0m\u001b[33m hat\u001b[0m\u001b[33m.\u001b[0m\u001b[97m\u001b[0m\n",
|
||||
"\u001b[30m\u001b[0m"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"agent = Agent(\n",
|
||||
" client, \n",
|
||||
" model=vision_model_id,\n",
|
||||
" instructions=\"You are a helpful assistant\",\n",
|
||||
")\n",
|
||||
"session_id = agent.create_session(\"test-session\")\n",
|
||||
"\n",
|
||||
"async def run_main(image_path, prompt):\n",
|
||||
" base64_image = encode_image(image_path)\n",
|
||||
"\n",
|
||||
" client = LlamaStackClient(\n",
|
||||
" base_url=LLAMA_STACK_API_TOGETHER_URL,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" agent_config = AgentConfig(\n",
|
||||
" model=LLAMA32_11B_INSTRUCT,\n",
|
||||
" instructions=\"You are a helpful assistant\",\n",
|
||||
" enable_session_persistence=False,\n",
|
||||
" toolgroups=[],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" agent = Agent(client, agent_config)\n",
|
||||
" session_id = agent.create_session(\"test-session\")\n",
|
||||
"\n",
|
||||
" response = agent.create_turn(\n",
|
||||
" messages=[{\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"image\",\n",
|
||||
" \"image\": {\n",
|
||||
" \"url\": {\n",
|
||||
" \"uri\": encode_image(image_path)\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"type\": \"text\",\n",
|
||||
" \"text\": prompt,\n",
|
||||
"response = agent.create_turn(\n",
|
||||
" messages=[{\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"image\",\n",
|
||||
" \"image\": {\n",
|
||||
" \"url\": {\n",
|
||||
" \"uri\": encode_image(\"Llama_Repo.jpeg\")\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" }],\n",
|
||||
" session_id=session_id,\n",
|
||||
" )\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"type\": \"text\",\n",
|
||||
" \"text\": \"How many different colors are those llamas? What are those colors?\",\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" }],\n",
|
||||
" session_id=session_id,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" log.print()"
|
||||
"for log in EventLogger().log(response):\n",
|
||||
" log.print()\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "15d0098b",
|
||||
"id": "f3352379",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await run_main(\"Llama_Repo.jpeg\",\n",
|
||||
" \"How many different colors are those llamas?\\\n",
|
||||
" What are those colors?\")"
|
||||
]
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
3535
docs/notebooks/Llama_Stack_Agent_Workflows.ipynb
Normal file
3535
docs/notebooks/Llama_Stack_Agent_Workflows.ipynb
Normal file
File diff suppressed because it is too large
Load diff
|
@ -826,10 +826,9 @@
|
|||
"_ = client.datasets.register(\n",
|
||||
" dataset_id=simpleqa_dataset_id,\n",
|
||||
" provider_id=\"huggingface\",\n",
|
||||
" url={\"uri\": \"https://huggingface.co/datasets/llamastack/evals\"},\n",
|
||||
" url={\"uri\": \"https://huggingface.co/datasets/llamastack/simpleqa\"},\n",
|
||||
" metadata={\n",
|
||||
" \"path\": \"llamastack/evals\",\n",
|
||||
" \"name\": \"evals__simpleqa\",\n",
|
||||
" \"path\": \"llamastack/simpleqa\",\n",
|
||||
" \"split\": \"train\",\n",
|
||||
" },\n",
|
||||
" dataset_schema={\n",
|
||||
|
|
1427
docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb
Normal file
1427
docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1,9 +1 @@
|
|||
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/distribution/server/endpoints.py` using the `generate.py` utility.
|
||||
|
||||
Please install the following packages before running the script:
|
||||
|
||||
```
|
||||
pip install fire PyYAML
|
||||
```
|
||||
|
||||
Then simply run `sh run_openapi_generator.sh`
|
||||
|
|
|
@ -14,18 +14,16 @@ Agents are configured using the `AgentConfig` class, which includes:
|
|||
- **Safety Shields**: Guardrails to ensure responsible AI behavior
|
||||
|
||||
```python
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
|
||||
# Configure an agent
|
||||
agent_config = AgentConfig(
|
||||
model="meta-llama/Llama-3-70b-chat",
|
||||
instructions="You are a helpful assistant that can use tools to answer questions.",
|
||||
toolgroups=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
|
||||
)
|
||||
|
||||
# Create the agent
|
||||
agent = Agent(llama_stack_client, agent_config)
|
||||
agent = Agent(
|
||||
llama_stack_client,
|
||||
model="meta-llama/Llama-3-70b-chat",
|
||||
instructions="You are a helpful assistant that can use tools to answer questions.",
|
||||
tools=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
|
||||
)
|
||||
```
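
Once the agent is created, a typical flow is to open a session and run turns against it (sessions are covered in more detail below). Here is a minimal usage sketch; the session name and prompt are illustrative placeholders:

```python
from llama_stack_client.lib.agents.event_logger import EventLogger

# Illustrative only: the session name and prompt are placeholders.
session_id = agent.create_session("quickstart-session")
response = agent.create_turn(
    messages=[{"role": "user", "content": "Hello! What can you help me with?"}],
    session_id=session_id,
)

# Stream and pretty-print the agent's events (inference, tool calls, etc.)
for log in EventLogger().log(response):
    log.print()
```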
|
||||
|
||||
### 2. Sessions
|
||||
|
|
|
@ -70,18 +70,18 @@ Each step in this process can be monitored and controlled through configurations
|
|||
from llama_stack_client import LlamaStackClient
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
from rich.pretty import pprint
|
||||
|
||||
# Replace host and port
|
||||
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
|
||||
|
||||
agent_config = AgentConfig(
|
||||
agent = Agent(
|
||||
client,
|
||||
# Check with `llama-stack-client models list`
|
||||
model="Llama3.2-3B-Instruct",
|
||||
instructions="You are a helpful assistant",
|
||||
# Enable both RAG and tool usage
|
||||
toolgroups=[
|
||||
tools=[
|
||||
{
|
||||
"name": "builtin::rag/knowledge_search",
|
||||
"args": {"vector_db_ids": ["my_docs"]},
|
||||
|
@ -98,8 +98,6 @@ agent_config = AgentConfig(
|
|||
"max_tokens": 2048,
|
||||
},
|
||||
)
|
||||
|
||||
agent = Agent(client, agent_config)
|
||||
session_id = agent.create_session("monitored_session")
|
||||
|
||||
# Stream the agent's execution steps
|
||||
|
|
|
@ -1,169 +1,127 @@
|
|||
# Evals
|
||||
# Evaluations
|
||||
|
||||
[](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing)
|
||||
The Llama Stack provides a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
|
||||
- `/datasetio` + `/datasets` API
|
||||
- `/scoring` + `/scoring_functions` API
|
||||
- `/eval` + `/benchmarks` API
|
||||
|
||||
Llama Stack provides the building blocks needed to run benchmark and application evaluations. This guide will walk you through how to use these components to run open benchmark evaluations. Visit our [Evaluation Concepts](../concepts/evaluation_concepts.md) guide for more details on how evaluations work in Llama Stack, and our [Evaluation Reference](../references/evals_reference/index.md) guide for a comprehensive reference on the APIs.
|
||||
|
||||
### 1. Open Benchmark Model Evaluation
|
||||
|
||||
This first example walks you through how to evaluate a model candidate served by Llama Stack on open benchmarks. We will use the following benchmark:
|
||||
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI): Benchmark designed to evaluate multimodal models.
|
||||
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): Benchmark designed to access models to answer short, fact-seeking questions.
|
||||
This guide walks you through the process of evaluating an LLM application built using Llama Stack. The [Evaluation Reference](../references/evals_reference/index.md) guide covers the set of APIs and the developer experience flow for running evaluations on both benchmark and application use cases. You can also check out our Colab notebook with working evaluation examples [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).
|
||||
|
||||
#### 1.1 Running MMMU
|
||||
- We will use a pre-processed MMMU dataset from [llamastack/mmmu](https://huggingface.co/datasets/llamastack/mmmu). The preprocessing code is shown in in this [Github Gist](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840). The dataset is obtained by transforming the original [MMMU/MMMU](https://huggingface.co/datasets/MMMU/MMMU) dataset into correct format by `inference/chat-completion` API.
|
||||
|
||||
## Application Evaluation
|
||||
|
||||
[](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb)
|
||||
|
||||
Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets.
|
||||
|
||||
In this example, we will show you how to:
|
||||
1. Build an Agent with Llama Stack
|
||||
2. Query the agent's sessions, turns, and steps
|
||||
3. Evaluate the results.
|
||||
|
||||
##### Building a Search Agent
|
||||
```python
|
||||
import datasets
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
|
||||
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
|
||||
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
|
||||
eval_rows = ds.to_pandas().to_dict(orient="records")
|
||||
```
|
||||
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
|
||||
|
||||
- Next, we will run evaluation on an model candidate, we will need to:
|
||||
- Define a system prompt
|
||||
- Define an EvalCandidate
|
||||
- Run evaluate on the dataset
|
||||
|
||||
```python
|
||||
SYSTEM_PROMPT_TEMPLATE = """
|
||||
You are an expert in Agriculture whose job is to answer questions from the user using images.
|
||||
First, reason about the correct answer.
|
||||
Then write the answer in the following format where X is exactly one of A,B,C,D:
|
||||
Answer: X
|
||||
Make sure X is one of A,B,C,D.
|
||||
If you are uncertain of the correct answer, guess the most likely one.
|
||||
"""
|
||||
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": SYSTEM_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
client.benchmarks.register(
|
||||
benchmark_id="meta-reference::mmmu",
|
||||
dataset_id=f"mmmu-{subset}-{split}",
|
||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||
agent = Agent(
|
||||
client,
|
||||
model="meta-llama/Llama-3.3-70B-Instruct",
|
||||
instructions="You are a helpful assistant. Use search tool to answer the questions. ",
|
||||
tools=["builtin::websearch"],
|
||||
)
|
||||
user_prompts = [
|
||||
"Which teams played in the NBA Western Conference Finals of 2024. Search the web for the answer.",
|
||||
"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.",
|
||||
"What is the British-American kickboxer Andrew Tate's kickboxing name? Search the web for the answer.",
|
||||
]
|
||||
|
||||
response = client.eval.evaluate_rows(
|
||||
benchmark_id="meta-reference::mmmu",
|
||||
input_rows=eval_rows,
|
||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||
benchmark_config={
|
||||
"type": "benchmark",
|
||||
"eval_candidate": {
|
||||
"type": "model",
|
||||
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
"sampling_params": {
|
||||
"strategy": {
|
||||
"type": "greedy",
|
||||
},
|
||||
"max_tokens": 4096,
|
||||
"repeat_penalty": 1.0,
|
||||
},
|
||||
"system_message": system_message,
|
||||
},
|
||||
},
|
||||
)
|
||||
```
|
||||
session_id = agent.create_session("test-session")
|
||||
|
||||
#### 1.2. Running SimpleQA
|
||||
- We will use a pre-processed SimpleQA dataset from [llamastack/evals](https://huggingface.co/datasets/llamastack/evals/viewer/evals__simpleqa) which is obtained by transforming the input query into correct format accepted by `inference/chat-completion` API.
|
||||
- Since we will be using this same dataset in our next example for Agentic evaluation, we will register it using the `/datasets` API, and interact with it through `/datasetio` API.
|
||||
for prompt in user_prompts:
|
||||
response = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
```python
|
||||
simpleqa_dataset_id = "huggingface::simpleqa"
|
||||
|
||||
_ = client.datasets.register(
|
||||
dataset_id=simpleqa_dataset_id,
|
||||
provider_id="huggingface",
|
||||
url={"uri": "https://huggingface.co/datasets/llamastack/evals"},
|
||||
metadata={
|
||||
"path": "llamastack/evals",
|
||||
"name": "evals__simpleqa",
|
||||
"split": "train",
|
||||
},
|
||||
dataset_schema={
|
||||
"input_query": {"type": "string"},
|
||||
"expected_answer": {"type": "string"},
|
||||
"chat_completion_input": {"type": "chat_completion_input"},
|
||||
},
|
||||
)
|
||||
|
||||
eval_rows = client.datasetio.get_rows_paginated(
|
||||
dataset_id=simpleqa_dataset_id,
|
||||
rows_in_page=5,
|
||||
)
|
||||
```
|
||||
|
||||
```python
|
||||
client.benchmarks.register(
|
||||
benchmark_id="meta-reference::simpleqa",
|
||||
dataset_id=simpleqa_dataset_id,
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
)
|
||||
|
||||
response = client.eval.evaluate_rows(
|
||||
benchmark_id="meta-reference::simpleqa",
|
||||
input_rows=eval_rows.rows,
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
benchmark_config={
|
||||
"type": "benchmark",
|
||||
"eval_candidate": {
|
||||
"type": "model",
|
||||
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
"sampling_params": {
|
||||
"strategy": {
|
||||
"type": "greedy",
|
||||
},
|
||||
"max_tokens": 4096,
|
||||
"repeat_penalty": 1.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
for log in EventLogger().log(response):
|
||||
log.print()
|
||||
```
|
||||
|
||||
|
||||
### 2. Agentic Evaluation
|
||||
- In this example, we will demonstrate how to evaluate a agent candidate served by Llama Stack via `/agent` API.
|
||||
- We will continue to use the SimpleQA dataset we used in previous example.
|
||||
- Instead of running evaluation on model, we will run the evaluation on a Search Agent with access to search tool. We will define our agent evaluation candidate through `AgentConfig`.
|
||||
##### Query Agent Execution Steps
|
||||
|
||||
Now, let's look deeper into the agent's execution steps and see how well our agent performs.
|
||||
```python
|
||||
# query the agents session
|
||||
from rich.pretty import pprint
|
||||
|
||||
session_response = client.agents.session.retrieve(
|
||||
session_id=session_id,
|
||||
agent_id=agent.agent_id,
|
||||
)
|
||||
|
||||
pprint(session_response)
|
||||
```
|
||||
|
||||
As a sanity check, we will first check whether each user prompt is followed by a tool call to `brave_search`.
|
||||
```python
|
||||
num_tool_call = 0
|
||||
for turn in session_response.turns:
|
||||
for step in turn.steps:
|
||||
if (
|
||||
step.step_type == "tool_execution"
|
||||
and step.tool_calls[0].tool_name == "brave_search"
|
||||
):
|
||||
num_tool_call += 1
|
||||
|
||||
print(
|
||||
f"{num_tool_call}/{len(session_response.turns)} user prompts are followed by a tool call to `brave_search`"
|
||||
)
|
||||
```
|
||||
|
||||
##### Evaluate Agent Responses
|
||||
Now, we want to evaluate the agent's responses to the user prompts.
|
||||
|
||||
1. First, we will process the agent's execution history into a list of rows that can be used for evaluation.
|
||||
2. Next, we will label the rows with the expected answer.
|
||||
3. Finally, we will use the `/scoring` API to score the agent's responses.
|
||||
|
||||
```python
|
||||
agent_config = {
|
||||
"model": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"instructions": "You are a helpful assistant",
|
||||
"sampling_params": {
|
||||
"strategy": {
|
||||
"type": "greedy",
|
||||
},
|
||||
},
|
||||
"tools": [
|
||||
eval_rows = []
|
||||
|
||||
expected_answers = [
|
||||
"Dallas Mavericks and the Minnesota Timberwolves",
|
||||
"Season 4, Episode 12",
|
||||
"King Cobra",
|
||||
]
|
||||
|
||||
for i, turn in enumerate(session_response.turns):
|
||||
eval_rows.append(
|
||||
{
|
||||
"type": "brave_search",
|
||||
"engine": "tavily",
|
||||
"api_key": userdata.get("TAVILY_SEARCH_API_KEY"),
|
||||
"input_query": turn.input_messages[0].content,
|
||||
"generated_answer": turn.output_message.content,
|
||||
"expected_answer": expected_answers[i],
|
||||
}
|
||||
],
|
||||
"tool_choice": "auto",
|
||||
"input_shields": [],
|
||||
"output_shields": [],
|
||||
"enable_session_persistence": False,
|
||||
}
|
||||
)
|
||||
|
||||
response = client.eval.evaluate_rows(
|
||||
benchmark_id="meta-reference::simpleqa",
|
||||
input_rows=eval_rows.rows,
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
benchmark_config={
|
||||
"type": "benchmark",
|
||||
"eval_candidate": {
|
||||
"type": "agent",
|
||||
"config": agent_config,
|
||||
},
|
||||
},
|
||||
pprint(eval_rows)
|
||||
|
||||
scoring_params = {
|
||||
"basic::subset_of": None,
|
||||
}
|
||||
scoring_response = client.scoring.score(
|
||||
input_rows=eval_rows, scoring_functions=scoring_params
|
||||
)
|
||||
pprint(scoring_response)
|
||||
```
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
## Testing & Evaluation
|
||||
|
||||
Llama Stack provides built-in tools for evaluating your applications:
|
||||
|
||||
1. **Benchmarking**: Test against standard datasets
|
||||
2. **Application Evaluation**: Score your application's outputs
|
||||
3. **Custom Metrics**: Define your own evaluation criteria
|
||||
|
||||
Here's how to set up basic evaluation:
|
||||
|
||||
```python
|
||||
# Create an evaluation task
|
||||
response = client.benchmarks.register(
|
||||
benchmark_id="my_eval",
|
||||
dataset_id="my_dataset",
|
||||
scoring_functions=["accuracy", "relevance"],
|
||||
)
|
||||
|
||||
# Run evaluation
|
||||
job = client.eval.run_eval(
|
||||
benchmark_id="my_eval",
|
||||
benchmark_config={
|
||||
"type": "app",
|
||||
"eval_candidate": {"type": "agent", "config": agent_config},
|
||||
},
|
||||
)
|
||||
|
||||
# Get results
|
||||
result = client.eval.job_result(benchmark_id="my_eval", job_id=job.job_id)
|
||||
```
|
|
@ -20,6 +20,11 @@ We may add more storage types like Graph IO in the future.
|
|||
Here's how to set up a vector database for RAG:
|
||||
|
||||
```python
|
||||
# Create http client
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
client = LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}")
|
||||
|
||||
# Register a vector db
|
||||
vector_db_id = "my_documents"
|
||||
response = client.vector_dbs.register(
|
||||
|
@ -81,15 +86,14 @@ results = client.tool_runtime.rag_tool.query(
|
|||
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
|
||||
|
||||
```python
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
|
||||
# Configure agent with memory
|
||||
agent_config = AgentConfig(
|
||||
# Create agent with memory
|
||||
agent = Agent(
|
||||
client,
|
||||
model="meta-llama/Llama-3.3-70B-Instruct",
|
||||
instructions="You are a helpful assistant",
|
||||
enable_session_persistence=False,
|
||||
toolgroups=[
|
||||
tools=[
|
||||
{
|
||||
"name": "builtin::rag/knowledge_search",
|
||||
"args": {
|
||||
|
@ -98,8 +102,6 @@ agent_config = AgentConfig(
|
|||
}
|
||||
],
|
||||
)
|
||||
|
||||
agent = Agent(client, agent_config)
|
||||
session_id = agent.create_session("rag_session")
|
||||
|
||||
|
||||
|
@ -136,6 +138,14 @@ response = agent.create_turn(
|
|||
)
|
||||
```
|
||||
|
||||
You can print the response as shown below.
|
||||
```python
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
|
||||
for log in EventLogger().log(response):
|
||||
log.print()
|
||||
```
|
||||
|
||||
### Unregistering Vector DBs
|
||||
|
||||
If you need to clean up and unregister vector databases, you can do so as follows:
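
A minimal sketch (assuming the client exposes `vector_dbs.unregister` and `vector_dbs.list`, mirroring the `register` call shown earlier):

```python
# Unregister a single vector database by its id
client.vector_dbs.unregister(vector_db_id="my_documents")

# Or clean up every registered vector database
for vector_db in client.vector_dbs.list():
    client.vector_dbs.unregister(vector_db_id=vector_db.identifier)
```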
|
||||
|
|
|
@ -5,7 +5,7 @@ An example of this would be a "db_access" tool group that contains tools for int
|
|||
|
||||
Tools are treated like any other resource in Llama Stack, such as models. You can register them, have providers for them, etc.
|
||||
|
||||
When instatiating an agent, you can provide it a list of tool groups that it has access to. Agent gets the corresponding tool definitions for the specified tool groups and passes them along to the model.
|
||||
When instantiating an agent, you can provide it a list of tool groups that it has access to. The agent gets the corresponding tool definitions for the specified tool groups and passes them along to the model.
|
||||
|
||||
Refer to the [Building AI Applications](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb) notebook for more examples on how to use tools.
|
||||
|
||||
|
@ -60,7 +60,7 @@ Features:
|
|||
- Disabled dangerous system operations
|
||||
- Configurable execution timeouts
|
||||
|
||||
> ⚠️ Important: The code interpreter tool can operate in a controlled enviroment locally or on Podman containers. To ensure proper functionality in containerised environments:
|
||||
> ⚠️ Important: The code interpreter tool can operate in a controlled environment locally or on Podman containers. To ensure proper functionality in containerized environments:
|
||||
> - The container requires privileged access (e.g., --privileged).
|
||||
> - Users without sufficient permissions may encounter permission errors. (`bwrap: Can't mount devpts on /newroot/dev/pts: Permission denied`)
|
||||
> - 🔒 Security Warning: Privileged mode grants elevated access and bypasses security restrictions. Use only in local, isolated, or controlled environments.
|
||||
|
@ -127,15 +127,11 @@ MCP tools require:
|
|||
|
||||
## Adding Custom Tools
|
||||
|
||||
When you want to use tools other than the built-in tools, you can implement a python function and decorate it with `@client_tool`.
|
||||
When you want to use tools other than the built-in tools, you just need to implement a Python function with a docstring. The docstring is used to describe the tool and its parameters, and is passed
|
||||
along to the generative model.
|
||||
|
||||
To define a custom tool, you need to use the `@client_tool` decorator.
|
||||
```python
|
||||
from llama_stack_client.lib.agents.client_tool import client_tool
|
||||
|
||||
|
||||
# Example tool definition
|
||||
@client_tool
|
||||
def my_tool(input: int) -> int:
|
||||
"""
|
||||
Runs my awesome tool.
|
||||
|
@ -149,15 +145,7 @@ def my_tool(input: int) -> int:
|
|||
Once defined, simply pass the tool to the agent. `Agent` will take care of the rest (calling the model with the tool definition, executing the tool, and returning the result to the model for the next iteration).
|
||||
```python
|
||||
# Example agent config with client provided tools
|
||||
client_tools = [
|
||||
my_tool,
|
||||
]
|
||||
|
||||
agent_config = AgentConfig(
|
||||
...,
|
||||
client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
|
||||
)
|
||||
agent = Agent(client, agent_config, client_tools)
|
||||
agent = Agent(client, ..., tools=[my_tool])
|
||||
```
|
||||
|
||||
Refer to [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py) for an example of how to use client provided tools.
|
||||
|
@ -194,10 +182,10 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
|
|||
|
||||
```python
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
|
||||
# Configure the AI agent with necessary parameters
|
||||
agent_config = AgentConfig(
|
||||
# Instantiate the AI agent with the given configuration
|
||||
agent = Agent(
|
||||
client,
|
||||
name="code-interpreter",
|
||||
description="A code interpreter agent for executing Python code snippets",
|
||||
instructions="""
|
||||
|
@ -205,14 +193,10 @@ agent_config = AgentConfig(
|
|||
Always show the generated code, never generate your own code, and never anticipate results.
|
||||
""",
|
||||
model="meta-llama/Llama-3.2-3B-Instruct",
|
||||
toolgroups=["builtin::code_interpreter"],
|
||||
tools=["builtin::code_interpreter"],
|
||||
max_infer_iters=5,
|
||||
enable_session_persistence=False,
|
||||
)
|
||||
|
||||
# Instantiate the AI agent with the given configuration
|
||||
agent = Agent(client, agent_config)
|
||||
|
||||
# Start a session
|
||||
session_id = agent.create_session("tool_session")
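# Illustrative continuation (not part of the original snippet): run a turn through
# the code interpreter agent using the create_turn/EventLogger pattern shown in
# the agent examples above. The prompt below is just a placeholder.
from llama_stack_client.lib.agents.event_logger import EventLogger

response = agent.create_turn(
    messages=[
        {"role": "user", "content": "Write and run Python code that prints the first 10 Fibonacci numbers."}
    ],
    session_id=session_id,
)
for log in EventLogger().log(response):
    log.print()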
|
||||
|
||||
|
|
|
@ -24,17 +24,58 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
|
|||
- Associated with `Benchmark` resource.
|
||||
|
||||
|
||||
Use the following decision tree to decide how to use LlamaStack Evaluation flow.
|
||||

|
||||
## Open-benchmark Eval
|
||||
|
||||
### List of open benchmarks Llama Stack supports
|
||||
|
||||
Llama Stack pre-registers several popular open benchmarks so you can easily evaluate model performance via the CLI.
|
||||
|
||||
The list of open-benchmarks we currently support:
|
||||
- [MMLU-COT](https://arxiv.org/abs/2009.03300) (Measuring Massive Multitask Language Understanding): Benchmark designed to comprehensively evaluate the breadth and depth of a model's academic and professional understanding
|
||||
- [GPQA-COT](https://arxiv.org/abs/2311.12022) (A Graduate-Level Google-Proof Q&A Benchmark): A challenging benchmark of 448 multiple-choice questions written by domain experts in biology, physics, and chemistry.
|
||||
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): Benchmark designed to assess a model's ability to answer short, fact-seeking questions.
|
||||
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI): Benchmark designed to evaluate multimodal models.
|
||||
|
||||
|
||||
```{admonition} Note on Benchmark v.s. Application Evaluation
|
||||
:class: tip
|
||||
- **Benchmark Evaluation** is a well-defined eval-task consisting of `dataset` and `scoring_function`. The generation (inference or agent) will be done as part of evaluation.
|
||||
- **Application Evaluation** assumes users already have app inputs & generated outputs. Evaluation will purely focus on scoring the generated outputs via scoring functions (e.g. LLM-as-judge).
|
||||
You can follow this [contributing guide](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) to add more open benchmarks to Llama Stack.
|
||||
|
||||
### Run evaluation on open-benchmarks via CLI
|
||||
|
||||
We provide built-in functionality to run the supported open benchmarks using the llama-stack-client CLI.
|
||||
|
||||
#### Spin up Llama Stack server
|
||||
|
||||
Spin up the Llama Stack server with the 'open-benchmark' template:
|
||||
```
|
||||
llama stack run llama_stack/templates/open-benchmark/run.yaml
|
||||
|
||||
```
|
||||
|
||||
#### Run eval CLI
|
||||
There are three required inputs to run a benchmark eval:
|
||||
- `list of benchmark_ids`: The list of benchmark ids to run evaluation on
|
||||
- `model-id`: The model id to evaluate on
|
||||
- `output_dir`: Path to store the evaluation results
|
||||
```
|
||||
llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
|
||||
--model_id <model id to evaluate on> \
|
||||
--output_dir <directory to store the evaluation results>
|
||||
```
|
||||
|
||||
You can run
|
||||
```
|
||||
llama-stack-client eval run-benchmark help
|
||||
```
|
||||
to see descriptions of all the flags that `eval run-benchmark` supports.
|
||||
|
||||
|
||||
In the output log, you can find the path to the file that contains your evaluation results. Open that file to see your aggregate
|
||||
evaluation results.
|
||||
|
||||
|
||||
|
||||
## What's Next?
|
||||
|
||||
- Check out our Colab notebook on working examples with evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).
|
||||
- Check out our Colab notebook on working examples with running benchmark evaluations [here](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb#scrollTo=mxLCsP4MvFqP).
|
||||
- Check out our [Building Applications - Evaluation](../building_applications/evals.md) guide for more details on how to use the Evaluation APIs to evaluate your applications.
|
||||
- Check out our [Evaluation Reference](../references/evals_reference/index.md) for more details on the APIs.
|
||||
|
|
|
@ -1,5 +1,13 @@
|
|||
# Core Concepts
|
||||
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
:hidden:
|
||||
|
||||
evaluation_concepts
|
||||
```
|
||||
|
||||
Given Llama Stack's service-oriented philosophy, a few concepts and workflows arise which may not feel completely natural in the LLM landscape, especially if you are coming with a background in other frameworks.
|
||||
|
||||
|
||||
|
@ -26,7 +34,7 @@ We are working on adding a few more APIs to complete the application lifecycle.
|
|||
|
||||
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
|
||||
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
|
||||
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
|
||||
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
|
||||
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
|
||||
|
||||
Providers come in two flavors:
|
||||
|
|
|
@ -13,16 +13,18 @@
|
|||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
from docutils import nodes
|
||||
import tomli # Import tomli for TOML parsing
|
||||
from pathlib import Path
|
||||
import requests
|
||||
import json
|
||||
|
||||
# Read version from pyproject.toml
|
||||
with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
|
||||
pyproject = tomli.load(f)
|
||||
llama_stack_version = pyproject["project"]["version"]
|
||||
pypi_url = "https://pypi.org/pypi/llama-stack/json"
|
||||
version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"]
|
||||
print(f"{version_tag=}")
|
||||
|
||||
# generate the full link including text and url here
|
||||
llama_stack_version_url = f"https://github.com/meta-llama/llama-stack/releases/tag/v{llama_stack_version}"
|
||||
llama_stack_version_url = f"https://github.com/meta-llama/llama-stack/releases/tag/v{version_tag}"
|
||||
llama_stack_version_link = f"<a href='{llama_stack_version_url}'>release notes</a>"
|
||||
|
||||
project = "llama-stack"
|
||||
|
@ -77,7 +79,7 @@ myst_enable_extensions = [
|
|||
|
||||
myst_substitutions = {
|
||||
"docker_hub": "https://hub.docker.com/repository/docker/llamastack",
|
||||
"llama_stack_version": llama_stack_version,
|
||||
"llama_stack_version": version_tag,
|
||||
"llama_stack_version_link": llama_stack_version_link,
|
||||
}
|
||||
|
||||
|
|
|
@ -17,25 +17,31 @@ Here are some example PRs to help you get started:
|
|||
|
||||
## Testing the Provider
|
||||
|
||||
Before running tests, you must have the required dependencies installed. This depends on the providers or distributions you are testing. For example, if you are testing the `together` distribution, you should install dependencies via `llama stack build --template together`.
|
||||
|
||||
### 1. Integration Testing
|
||||
- Create integration tests that use real provider instances and configurations
|
||||
- For remote services, test actual API interactions
|
||||
- Avoid mocking at the provider level since adapter layers tend to be thin
|
||||
- Reference examples in {repopath}`tests/api`
|
||||
|
||||
### 2. Unit Testing (Optional)
|
||||
- Add unit tests for provider-specific functionality
|
||||
- See examples in {repopath}`llama_stack/providers/tests/inference/test_text_inference.py`
|
||||
Integration tests are located in {repopath}`tests/integration`. These tests use the Python client SDK APIs (from the `llama_stack_client` package) to test functionality. Since these tests use client APIs, they can be run either by pointing to an instance of the Llama Stack server or "inline" by using `LlamaStackAsLibraryClient`.
|
||||
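For illustration, a minimal sketch of the "inline" mode is shown below. The template name and any required API keys are assumptions about your setup, and the import path may differ slightly between versions:

```python
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

# Run the stack "inline" in this process instead of pointing at a running server.
client = LlamaStackAsLibraryClient("together")
client.initialize()

# The same client-SDK calls used by the integration tests work against this client.
print(client.models.list())
```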
|
||||
Consult {repopath}`tests/integration/README.md` for more details on how to run the tests.
|
||||
|
||||
Note that each provider's `sample_run_config()` method (in the configuration class for that provider)
|
||||
typically references some environment variables for specifying API keys and the like. You can set these in the environment or pass these via the `--env` flag to the test command.
|
||||
|
||||
|
||||
### 2. Unit Testing
|
||||
|
||||
Unit tests are located in {repopath}`tests/unit`. Provider-specific unit tests are located in {repopath}`tests/unit/providers`. These tests are all run automatically as part of the CI process.
|
||||
|
||||
|
||||
### 3. Additional end-to-end testing
|
||||
|
||||
### 3. End-to-End Testing
|
||||
1. Start a Llama Stack server with your new provider
|
||||
2. Test using client requests
|
||||
3. Verify compatibility with existing client scripts in the [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main) repository
|
||||
4. Document which scripts are compatible with your provider
|
||||
2. Verify compatibility with existing client scripts in the [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main) repository
|
||||
3. Document which scripts are compatible with your provider
|
||||
|
||||
## Submitting Your PR
|
||||
|
||||
1. Ensure all tests pass
|
||||
2. Include a comprehensive test plan in your PR summary
|
||||
3. Document any known limitations or considerations
|
||||
4. Submit your pull request for review
|
||||
|
|
|
@ -4,6 +4,37 @@
|
|||
This guide will walk you through the steps to get started with building a Llama Stack distribution from scratch with your choice of API providers.
|
||||
|
||||
|
||||
### Setting your log level
|
||||
|
||||
To specify the desired logging level, users can set the environment variable `LLAMA_STACK_LOGGING` using the following format:
|
||||
|
||||
`LLAMA_STACK_LOGGING=server=debug;core=info`
|
||||
|
||||
Each category in the following list:
|
||||
|
||||
- all
|
||||
- core
|
||||
- server
|
||||
- router
|
||||
- inference
|
||||
- agents
|
||||
- safety
|
||||
- eval
|
||||
- tools
|
||||
- client
|
||||
|
||||
can be set to any of the following log levels:
|
||||
|
||||
- debug
|
||||
- info
|
||||
- warning
|
||||
- error
|
||||
- critical
|
||||
|
||||
The default global log level is `info`. `all` sets the log level for all components.
|
||||
|
||||
A user can also set `LLAMA_STACK_LOG_FILE` which will pipe the logs to the specified path as well as to the terminal. An example would be: `export LLAMA_STACK_LOG_FILE=server.log`
|
||||
|
||||
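For example, to turn on debug logging for the server while keeping everything else at `info`, and to also capture the logs to a file (note the quotes, since the value contains a `;`):
```
export LLAMA_STACK_LOGGING="server=debug;core=info"
export LLAMA_STACK_LOG_FILE=server.log
llama stack run <path/to/your/run.yaml>
```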
### Llama Stack Build
|
||||
|
||||
In order to build your own distribution, we recommend you clone the `llama-stack` repository.
|
||||
|
@ -22,25 +53,25 @@ The main points to consider are:
|
|||
|
||||
```
|
||||
llama stack build -h
|
||||
|
||||
usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--list-templates]
|
||||
[--image-type {conda,container,venv}] [--image-name IMAGE_NAME] [--print-deps-only]
|
||||
usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--list-templates] [--image-type {conda,container,venv}] [--image-name IMAGE_NAME] [--print-deps-only] [--run]
|
||||
|
||||
Build a Llama stack container
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml.
|
||||
If this argument is not provided, you will be prompted to enter information interactively
|
||||
--template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates
|
||||
--list-templates Show the available templates for building a Llama Stack distribution
|
||||
--config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will
|
||||
be prompted to enter information interactively (default: None)
|
||||
--template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates (default: None)
|
||||
--list-templates Show the available templates for building a Llama Stack distribution (default: False)
|
||||
--image-type {conda,container,venv}
|
||||
Image Type to use for the build. This can be either conda or container or venv. If not specified, will use the image type from the template config.
|
||||
Image Type to use for the build. This can be either conda or container or venv. If not specified, will use the image type from the template config. (default:
|
||||
conda)
|
||||
--image-name IMAGE_NAME
|
||||
[for image-type=conda] Name of the conda environment to use for the build. If
|
||||
not specified, currently active Conda environment will be used. If no Conda
|
||||
environment is active, you must specify a name.
|
||||
--print-deps-only Print the dependencies for the stack only, without building the stack
|
||||
[for image-type=conda|venv] Name of the conda or virtual environment to use for the build. If not specified, currently active Conda environment will be used if
|
||||
found. (default: None)
|
||||
--print-deps-only Print the dependencies for the stack only, without building the stack (default: False)
|
||||
--run Run the stack after building using the same image type, name, and other applicable arguments (default: False)
|
||||
|
||||
```
|
||||
|
||||
After this step is complete, a file named `<name>-build.yaml` and template file `<name>-run.yaml` will be generated and saved at the output file path specified at the end of the command.
|
||||
|
@ -183,8 +214,8 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con
|
|||
|
||||
```
|
||||
llama stack run -h
|
||||
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE]
|
||||
[--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
|
||||
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
|
||||
[--image-type {conda,container,venv}]
|
||||
config
|
||||
|
||||
Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
|
||||
|
@ -194,17 +225,17 @@ positional arguments:
|
|||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321
|
||||
--port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
|
||||
--image-name IMAGE_NAME
|
||||
Name of the image to run. Defaults to the current conda environment
|
||||
--disable-ipv6 Disable IPv6 support
|
||||
--env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.
|
||||
Name of the image to run. Defaults to the current conda environment (default: None)
|
||||
--disable-ipv6 Disable IPv6 support (default: False)
|
||||
--env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: [])
|
||||
--tls-keyfile TLS_KEYFILE
|
||||
Path to TLS key file for HTTPS
|
||||
Path to TLS key file for HTTPS (default: None)
|
||||
--tls-certfile TLS_CERTFILE
|
||||
Path to TLS certificate file for HTTPS
|
||||
Path to TLS certificate file for HTTPS (default: None)
|
||||
--image-type {conda,container,venv}
|
||||
Image Type used during the build. This can be either conda or container or venv.
|
||||
Image Type used during the build. This can be either conda or container or venv. (default: conda)
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -17,26 +17,4 @@ $ llama-stack-client configure --endpoint https://llamastack-preview.fireworks.a
|
|||
$ llama-stack-client models list
|
||||
```
|
||||
|
||||
You will see outputs:
|
||||
```
|
||||
$ llama-stack-client models list
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
| identifier | llama_model | provider_id | metadata |
|
||||
+==============================+==============================+===============+============+
|
||||
| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | fireworks0 | {} |
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
| Llama3.1-70B-Instruct | Llama3.1-70B-Instruct | fireworks0 | {} |
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
| Llama3.1-405B-Instruct | Llama3.1-405B-Instruct | fireworks0 | {} |
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
| Llama3.2-1B-Instruct | Llama3.2-1B-Instruct | fireworks0 | {} |
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
| Llama3.2-3B-Instruct | Llama3.2-3B-Instruct | fireworks0 | {} |
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
| Llama3.2-11B-Vision-Instruct | Llama3.2-11B-Vision-Instruct | fireworks0 | {} |
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
| Llama3.2-90B-Vision-Instruct | Llama3.2-90B-Vision-Instruct | fireworks0 | {} |
|
||||
+------------------------------+------------------------------+---------------+------------+
|
||||
```
|
||||
|
||||
Check out the [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python/blob/main/docs/cli_reference.md) repo for more details on how to use the `llama-stack-client` CLI. Check out [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main) for example applications built on top of Llama Stack.
|
||||
|
|
|
@ -40,7 +40,6 @@ The following models are available by default:
|
|||
- `accounts/fireworks/models/llama-v3p1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||
- `accounts/fireworks/models/llama-v3p1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||
- `accounts/fireworks/models/llama-v3p1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||
- `accounts/fireworks/models/llama-v3p2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
||||
- `accounts/fireworks/models/llama-v3p2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||
- `accounts/fireworks/models/llama-v3p2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||
- `accounts/fireworks/models/llama-v3p2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||
|
|
|
@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
|
|||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
|
||||
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
|
||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||
|
||||
|
||||
You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.
|
||||
|
@ -130,7 +130,7 @@ llama stack run ./run-with-safety.yaml \
|
|||
### (Optional) Update Model Serving Configuration
|
||||
|
||||
```{note}
|
||||
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
|
||||
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/models.py) for the supported Ollama models.
|
||||
```
|
||||
|
||||
To serve a new model with `ollama`
|
||||
|
|
|
@ -184,7 +184,6 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
from llama_stack_client.types import Document
|
||||
|
||||
|
||||
|
@ -241,13 +240,14 @@ client.tool_runtime.rag_tool.insert(
|
|||
chunk_size_in_tokens=512,
|
||||
)
|
||||
|
||||
agent_config = AgentConfig(
|
||||
rag_agent = Agent(
|
||||
client,
|
||||
model=os.environ["INFERENCE_MODEL"],
|
||||
# Define instructions for the agent ( aka system prompt)
|
||||
instructions="You are a helpful assistant",
|
||||
enable_session_persistence=False,
|
||||
# Define tools available to the agent
|
||||
toolgroups=[
|
||||
tools=[
|
||||
{
|
||||
"name": "builtin::rag/knowledge_search",
|
||||
"args": {
|
||||
|
@ -256,8 +256,6 @@ agent_config = AgentConfig(
|
|||
}
|
||||
],
|
||||
)
|
||||
|
||||
rag_agent = Agent(client, agent_config)
|
||||
session_id = rag_agent.create_session("test-session")
|
||||
|
||||
user_prompts = [
|
||||
|
|
|
@ -68,6 +68,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
|
|||
| FAISS | Single Node |
|
||||
| SQLite-Vec| Single Node |
|
||||
| Chroma | Hosted and Single Node |
|
||||
| Milvus | Hosted and Single Node |
|
||||
| Postgres (PGVector) | Hosted and Single Node |
|
||||
| Weaviate | Hosted |
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
|
||||
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
|
||||
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
|
||||
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
|
||||
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
|
||||
|
||||
Providers come in two flavors:
|
||||
|
@ -55,5 +55,6 @@ vector_io/sqlite-vec
|
|||
vector_io/chromadb
|
||||
vector_io/pgvector
|
||||
vector_io/qdrant
|
||||
vector_io/milvus
|
||||
vector_io/weaviate
|
||||
```
|
||||
|
|
31
docs/source/providers/vector_io/mivus.md
Normal file
31
docs/source/providers/vector_io/mivus.md
Normal file
|
@ -0,0 +1,31 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# Milvus
|
||||
|
||||
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly within a Milvus database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
||||
## Features
|
||||
|
||||
- Easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
|
||||
## Usage
|
||||
|
||||
To use Milvus in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
|
||||
2. Configure your Llama Stack project to use Milvus.
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
||||
You can install Milvus using pymilvus:
|
||||
|
||||
```bash
|
||||
pip install pymilvus
|
||||
```
|
||||
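As a rough configuration sketch (the provider id and embedding model below are assumptions; use the values from your distribution's run config), registering a Milvus-backed vector database and inserting documents through the client could look like:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.types import Document

client = LlamaStackClient(base_url="http://localhost:8321")

# Register a vector database backed by the Milvus provider.
client.vector_dbs.register(
    vector_db_id="my_milvus_db",
    provider_id="milvus",  # assumption: matches the provider id in your run config
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)

# Insert documents via the built-in RAG tool; they are chunked, embedded, and stored in Milvus.
client.tool_runtime.rag_tool.insert(
    documents=[
        Document(
            document_id="doc-1",
            content="Milvus stores and indexes vectors for similarity search.",
            mime_type="text/plain",
            metadata={},
        )
    ],
    vector_db_id="my_milvus_db",
    chunk_size_in_tokens=512,
)
```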
## Documentation
|
||||
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
|
|
@ -24,19 +24,9 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
|
|||
- Associated with `Benchmark` resource.
|
||||
|
||||
|
||||
Use the following decision tree to decide how to use LlamaStack Evaluation flow.
|
||||

|
||||
|
||||
|
||||
```{admonition} Note on Benchmark v.s. Application Evaluation
|
||||
:class: tip
|
||||
- **Benchmark Evaluation** is a well-defined eval-task consisting of `dataset` and `scoring_function`. The generation (inference or agent) will be done as part of evaluation.
|
||||
- **Application Evaluation** assumes users already have app inputs & generated outputs. Evaluation will purely focus on scoring the generated outputs via scoring functions (e.g. LLM-as-judge).
|
||||
```
|
||||
|
||||
## Evaluation Examples Walkthrough
|
||||
|
||||
[](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing)
|
||||
[](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb)
|
||||
|
||||
It is best to open this notebook in Colab to follow along with the examples.
|
||||
|
||||
|
@ -63,20 +53,29 @@ eval_rows = ds.to_pandas().to_dict(orient="records")
|
|||
- Run evaluate on the dataset
|
||||
|
||||
```python
|
||||
from rich.pretty import pprint
|
||||
from tqdm import tqdm
|
||||
|
||||
SYSTEM_PROMPT_TEMPLATE = """
|
||||
You are an expert in Agriculture whose job is to answer questions from the user using images.
|
||||
You are an expert in {subject} whose job is to answer questions from the user using images.
|
||||
|
||||
First, reason about the correct answer.
|
||||
|
||||
Then write the answer in the following format where X is exactly one of A,B,C,D:
|
||||
|
||||
Answer: X
|
||||
|
||||
Make sure X is one of A,B,C,D.
|
||||
|
||||
If you are uncertain of the correct answer, guess the most likely one.
|
||||
"""
|
||||
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": SYSTEM_PROMPT_TEMPLATE,
|
||||
"content": SYSTEM_PROMPT_TEMPLATE.format(subject=subset),
|
||||
}
|
||||
|
||||
# register the evaluation benchmark task with the dataset and scoring function
|
||||
client.benchmarks.register(
|
||||
benchmark_id="meta-reference::mmmu",
|
||||
dataset_id=f"mmmu-{subset}-{split}",
|
||||
|
@ -88,13 +87,14 @@ response = client.eval.evaluate_rows(
|
|||
input_rows=eval_rows,
|
||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||
benchmark_config={
|
||||
"type": "benchmark",
|
||||
"eval_candidate": {
|
||||
"type": "model",
|
||||
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
"sampling_params": {
|
||||
"strategy": {
|
||||
"type": "greedy",
|
||||
"type": "top_p",
|
||||
"temperature": 1.0,
|
||||
"top_p": 0.95,
|
||||
},
|
||||
"max_tokens": 4096,
|
||||
"repeat_penalty": 1.0,
|
||||
|
@ -103,6 +103,7 @@ response = client.eval.evaluate_rows(
|
|||
},
|
||||
},
|
||||
)
|
||||
pprint(response)
|
||||
```
|
||||
|
||||
#### 1.2. Running SimpleQA
|
||||
|
@ -115,10 +116,9 @@ simpleqa_dataset_id = "huggingface::simpleqa"
|
|||
_ = client.datasets.register(
|
||||
dataset_id=simpleqa_dataset_id,
|
||||
provider_id="huggingface",
|
||||
url={"uri": "https://huggingface.co/datasets/llamastack/evals"},
|
||||
url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
|
||||
metadata={
|
||||
"path": "llamastack/evals",
|
||||
"name": "evals__simpleqa",
|
||||
"path": "llamastack/simpleqa",
|
||||
"split": "train",
|
||||
},
|
||||
dataset_schema={
|
||||
|
@ -146,7 +146,6 @@ response = client.eval.evaluate_rows(
|
|||
input_rows=eval_rows.rows,
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
benchmark_config={
|
||||
"type": "benchmark",
|
||||
"eval_candidate": {
|
||||
"type": "model",
|
||||
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
|
@ -160,6 +159,7 @@ response = client.eval.evaluate_rows(
|
|||
},
|
||||
},
|
||||
)
|
||||
pprint(response)
|
||||
```
|
||||
|
||||
|
||||
|
@ -170,19 +170,17 @@ response = client.eval.evaluate_rows(
|
|||
|
||||
```python
|
||||
agent_config = {
|
||||
"model": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"instructions": "You are a helpful assistant",
|
||||
"model": "meta-llama/Llama-3.3-70B-Instruct",
|
||||
"instructions": "You are a helpful assistant that have access to tool to search the web. ",
|
||||
"sampling_params": {
|
||||
"strategy": {
|
||||
"type": "greedy",
|
||||
},
|
||||
},
|
||||
"tools": [
|
||||
{
|
||||
"type": "brave_search",
|
||||
"engine": "tavily",
|
||||
"api_key": userdata.get("TAVILY_SEARCH_API_KEY"),
|
||||
"type": "top_p",
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.9,
|
||||
}
|
||||
},
|
||||
"toolgroups": [
|
||||
"builtin::websearch",
|
||||
],
|
||||
"tool_choice": "auto",
|
||||
"tool_prompt_format": "json",
|
||||
|
@ -196,24 +194,21 @@ response = client.eval.evaluate_rows(
|
|||
input_rows=eval_rows.rows,
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
benchmark_config={
|
||||
"type": "benchmark",
|
||||
"eval_candidate": {
|
||||
"type": "agent",
|
||||
"config": agent_config,
|
||||
},
|
||||
},
|
||||
)
|
||||
pprint(response)
|
||||
```
|
||||
|
||||
### 3. Agentic Application Dataset Scoring
|
||||
- Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets.
|
||||
[](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb)
|
||||
|
||||
- In this example, we will work with an example RAG dataset and couple of scoring functions for evaluation.
|
||||
- `llm-as-judge::base`: LLM-As-Judge with custom judge prompt & model.
|
||||
- `braintrust::factuality`: Factuality scorer from [braintrust](https://github.com/braintrustdata/autoevals).
|
||||
- `basic::subset_of`: Basic checking if generated answer is a subset of expected answer.
|
||||
Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets.
|
||||
|
||||
- Please checkout our [Llama Stack Playground](https://llama-stack.readthedocs.io/en/latest/playground/index.html) for an interactive interface to upload datasets and run scorings.
|
||||
In this example, we will work with an example RAG dataset you built previously, label it with an annotation, and use LLM-As-Judge with a custom judge prompt for scoring. Please check out our [Llama Stack Playground](https://llama-stack.readthedocs.io/en/latest/playground/index.html) for an interactive interface to upload datasets and run scorings.
|
||||
|
||||
```python
|
||||
judge_model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8"
|
||||
|
@ -280,18 +275,25 @@ response = client.scoring.score(
|
|||
The following examples give the quick steps to start running evaluations using the llama-stack-client CLI.
|
||||
|
||||
#### Benchmark Evaluation CLI
|
||||
Usage: There are 2 inputs necessary for running a benchmark eval
|
||||
- `eval-task-id`: the identifier associated with the eval task. Each `Benchmark` is parametrized by
|
||||
- `dataset_id`: the identifier associated with the dataset.
|
||||
- `List[scoring_function_id]`: list of scoring function identifiers.
|
||||
- `eval-task-config`: specifies the configuration of the model / agent to evaluate on.
|
||||
There are 3 necessary inputs for running a benchmark eval
|
||||
- `list of benchmark_ids`: The list of benchmark ids to run evaluation on
|
||||
- `model-id`: The model id to evaluate on
|
||||
- `output_dir`: Path to store the evaluation results
|
||||
```
|
||||
llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
|
||||
--model_id <model id to evaluate on> \
|
||||
--output_dir <directory to store the evaluate results> \
|
||||
```
|
||||
|
||||
You can run
|
||||
```
|
||||
llama-stack-client eval run-benchmark help
|
||||
```
|
||||
to see descriptions of all the flags for running a benchmark eval
|
||||
|
||||
|
||||
```
|
||||
llama-stack-client eval run_benchmark <eval-task-id> \
|
||||
--eval-task-config ~/benchmark_config.json \
|
||||
--visualize
|
||||
```
|
||||
In the output log, you can find the path to the file that contains your evaluation results. Open that file to see your aggregate
|
||||
evaluation results.
|
||||
|
||||
|
||||
#### Application Evaluation CLI
|
||||
|
@ -317,28 +319,9 @@ The `BenchmarkConfig` are user specified config to define:
|
|||
2. Optionally scoring function params to allow customization of scoring function behaviour. This is useful to parameterize generic scoring functions such as LLMAsJudge with custom `judge_model` / `judge_prompt`.
|
||||
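Before the full examples below, here is a hedged sketch of just the `scoring_params` piece as a Python dict (the nested field names for the LLM-as-judge params are assumptions based on the description above, not an authoritative schema):

```python
# Sketch only: shows where per-scoring-function params plug into a BenchmarkConfig.
benchmark_config = {
    "type": "benchmark",
    "eval_candidate": {
        "type": "model",
        "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
        "sampling_params": {"strategy": {"type": "greedy"}},
    },
    "scoring_params": {
        "llm-as-judge::base": {
            "type": "llm_as_judge",          # assumption
            "judge_model": "meta-llama/Llama-3.1-405B-Instruct",
            "prompt_template": "...",        # your custom judge prompt (placeholder)
        }
    },
}
```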
|
||||
|
||||
**Example Benchmark BenchmarkConfig**
|
||||
**Example BenchmarkConfig**
|
||||
```json
|
||||
{
|
||||
"type": "benchmark",
|
||||
"eval_candidate": {
|
||||
"type": "model",
|
||||
"model": "Llama3.2-3B-Instruct",
|
||||
"sampling_params": {
|
||||
"strategy": {
|
||||
"type": "greedy",
|
||||
},
|
||||
"max_tokens": 0,
|
||||
"repetition_penalty": 1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example Application BenchmarkConfig**
|
||||
```json
|
||||
{
|
||||
"type": "app",
|
||||
"eval_candidate": {
|
||||
"type": "model",
|
||||
"model": "Llama3.1-405B-Instruct",
|
||||
|
@ -362,3 +345,52 @@ The `BenchmarkConfig` are user specified config to define:
|
|||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Open-benchmark Contributing Guide
|
||||
|
||||
### Create the new dataset for your new benchmark
|
||||
An eval open-benchmark essentially contains 2 parts:
|
||||
- `raw data`: The raw dataset associated with the benchmark. You typically need to search the original paper that introduces the benchmark and find the canonical dataset (usually hosted on Hugging Face)
|
||||
- `prompt template`: How to ask the candidate model to generate the answer (the prompt template plays a critical role in the evaluation results). Typically, you can find the reference prompt template associated with the benchmark in the benchmark author's repo ([example](https://github.com/idavidrein/gpqa/blob/main/prompts/chain_of_thought.txt)) or in other popular open source repos ([example](https://github.com/openai/simple-evals/blob/0a6e8f62e52bc5ae915f752466be3af596caf392/common.py#L14))
|
||||
|
||||
To create a new open-benchmark in Llama Stack, you need to combine the prompt template and the raw data into the `chat_completion_input` column in the evaluation dataset.
|
||||
|
||||
Llama Stack enforces that the evaluation dataset schema contains at least 3 columns (a sketch of one converted row follows this list):
|
||||
- `chat_completion_input`: The actual input to the model to run the generation for eval
|
||||
- `input_query`: The raw input from the raw dataset without the prompt template
|
||||
- `expected_answer`: The ground truth for scoring functions to calculate the score from.
|
||||
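For illustration only (the question, answer, and the exact serialization of `chat_completion_input` are made up here; consult the example convert script below for the authoritative format), a single converted row might look like:

```python
import json

# One hypothetical row in the Llama Stack eval dataset format.
prompt_template = "Answer the following question concisely.\n\nQuestion: {question}"
raw_example = {"question": "Which planet is known as the Red Planet?", "answer": "Mars"}

row = {
    # Raw input from the source dataset, without the prompt template applied.
    "input_query": raw_example["question"],
    # The actual model input: prompt template combined with the raw data,
    # serialized as a chat message list (serialization details may differ).
    "chat_completion_input": json.dumps(
        [{"role": "user", "content": prompt_template.format(question=raw_example["question"])}]
    ),
    # Ground truth used by the scoring function.
    "expected_answer": raw_example["answer"],
}
```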
|
||||
|
||||
You need to write a script ([example convert script](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840)) to convert the raw benchmark dataset into the Llama Stack eval dataset format and upload the converted dataset to Hugging Face ([example benchmark dataset](https://huggingface.co/datasets/llamastack/mmmu)).
|
||||
|
||||
|
||||
### Find scoring function for your new benchmark
|
||||
The purpose of a scoring function is to calculate the score for each example based on the candidate model's generation result and the expected_answer. It also aggregates the scores across all examples to produce the final evaluation results.
|
||||
|
||||
|
||||
First, check whether the existing [llama stack scoring functions](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline/scoring) fulfill your need. If not, you need to write a new scoring function based on what the benchmark author or other open source repos describe.
|
||||
|
||||
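Conceptually, a per-row scorer is just a function from (generation, expected_answer) to a score plus an aggregation over all rows. The sketch below shows the idea only and is not the actual Llama Stack scoring-function interface:

```python
def score_row(generated_answer: str, expected_answer: str) -> dict:
    # Toy exact-match scorer; real scoring functions live under
    # llama_stack/providers/inline/scoring and follow that interface.
    return {"score": 1.0 if generated_answer.strip() == expected_answer.strip() else 0.0}


def aggregate(score_rows: list[dict]) -> dict:
    # Simple mean aggregation across all scored examples.
    return {"accuracy": sum(r["score"] for r in score_rows) / max(len(score_rows), 1)}
```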
### Add new benchmark into template
|
||||
First, add the evaluation dataset associated with your benchmark under the `datasets` resource in the [open-benchmark](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/open-benchmark/run.yaml) template.
|
||||
|
||||
Second, add the new benchmark you just created under the `benchmarks` resource in the same template. To add the new benchmark, you need to provide:
|
||||
- `benchmark_id`: identifier of the benchmark
|
||||
- `dataset_id`: identifier of the dataset associated with your benchmark
|
||||
- `scoring_functions`: scoring function to calculate the score based on generation results and expected_answer
|
||||
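If you want to try the benchmark without editing the template first, the same three fields can also be registered at runtime through the client, mirroring the `client.benchmarks.register(...)` call shown earlier in this guide (the ids below are hypothetical):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

client.benchmarks.register(
    benchmark_id="my-new-benchmark",          # hypothetical benchmark id
    dataset_id="my-new-benchmark-dataset",    # the dataset you registered above (hypothetical id)
    scoring_functions=["basic::subset_of"],   # the scoring function(s) chosen in the previous step
)
```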
|
||||
|
||||
### Test the new benchmark
|
||||
|
||||
Spin up the Llama Stack server with the 'open-benchmark' template:
|
||||
```
|
||||
llama stack run llama_stack/templates/open-benchmark/run.yaml
|
||||
|
||||
```
|
||||
|
||||
Run the eval benchmark CLI with your new benchmark id:
|
||||
```
|
||||
llama-stack-client eval run-benchmark <new_benchmark_id> \
|
||||
--model_id <model id to evaluate on> \
|
||||
--output_dir <directory to store the evaluate results> \
|
||||
```
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# llama (server-side) CLI Reference
|
||||
|
||||
The `llama` CLI tool helps you setup and use the Llama Stack. It should be available on your path after installing the `llama-stack` package.
|
||||
The `llama` CLI tool helps you set up and use the Llama Stack. The CLI is available on your path after installing the `llama-stack` package.
|
||||
|
||||
## Installation
|
||||
|
||||
|
@ -27,9 +27,9 @@ You have two ways to install Llama Stack:
|
|||
|
||||
|
||||
## `llama` subcommands
|
||||
1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face.
|
||||
2. `model`: Lists available models and their properties.
|
||||
3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](../../distributions/building_distro).
|
||||
1. `download`: Supports downloading models from Meta or Hugging Face. [Downloading models](#downloading-models)
|
||||
2. `model`: Lists available models and their properties. [Understanding models](#understand-the-models)
|
||||
3. `stack`: Allows you to build a stack using the `llama stack` distribution and run a Llama Stack server. You can read more about how to build a Llama Stack distribution in the [Build your own Distribution](../../distributions/building_distro) documentation.
|
||||
|
||||
### Sample Usage
|
||||
|
||||
|
@ -117,7 +117,7 @@ You should see a table like this:
|
|||
+----------------------------------+------------------------------------------+----------------+
|
||||
```
|
||||
|
||||
To download models, you can use the llama download command.
|
||||
To download models, you can use the `llama download` command.
|
||||
|
||||
### Downloading from [Meta](https://llama.meta.com/llama-downloads/)
|
||||
|
||||
|
@ -191,7 +191,7 @@ You should see a table like this:
|
|||
The `llama model` command helps you explore the model’s interface.
|
||||
|
||||
1. `download`: Download the model from different sources. (meta, huggingface)
|
||||
2. `list`: Lists all the models available for download with hardware requirements to deploy the models.
|
||||
2. `list`: Lists all the models available for download with hardware requirements for deploying the models.
|
||||
3. `prompt-format`: Show llama model message formats.
|
||||
4. `describe`: Describes all the properties of the model.
|
||||
|
||||
|
@ -262,13 +262,12 @@ llama model prompt-format -m Llama3.2-3B-Instruct
|
|||

|
||||
|
||||
|
||||
|
||||
You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
|
||||
|
||||
**NOTE**: Outputs in terminal are color printed to show special tokens.
|
||||
|
||||
### Remove model
|
||||
You can run `llama model remove` to remove unecessary model:
|
||||
You can run `llama model remove` to remove an unnecessary model:
|
||||
|
||||
```
|
||||
llama model remove -m Llama-Guard-3-8B-int8
|
||||
|
|
|
@ -294,8 +294,9 @@
|
|||
" # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
|
||||
" webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
|
||||
"\n",
|
||||
" # Define the agent configuration, including the model and tool setup\n",
|
||||
" agent_config = AgentConfig(\n",
|
||||
" # Create an agent instance with the client and configuration\n",
|
||||
" agent = Agent(\n",
|
||||
" client, \n",
|
||||
" model=MODEL_NAME,\n",
|
||||
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
|
||||
" sampling_params={\n",
|
||||
|
@ -303,17 +304,12 @@
|
|||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" tools=[webSearchTool.get_tool_definition()],\n",
|
||||
" tool_choice=\"auto\",\n",
|
||||
" tool_prompt_format=\"python_list\",\n",
|
||||
" tools=[webSearchTool],\n",
|
||||
" input_shields=input_shields,\n",
|
||||
" output_shields=output_shields,\n",
|
||||
" enable_session_persistence=False,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Create an agent instance with the client and configuration\n",
|
||||
" agent = Agent(client, agent_config, [webSearchTool])\n",
|
||||
"\n",
|
||||
" # Create a session for interaction and print the session ID\n",
|
||||
" session_id = agent.create_session(\"test-session\")\n",
|
||||
" print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",
|
||||
|
|
|
@ -110,12 +110,12 @@
|
|||
"from llama_stack_client import LlamaStackClient\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def agent_example():\n",
|
||||
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
|
||||
" agent_config = AgentConfig(\n",
|
||||
" agent = Agent(\n",
|
||||
" client, \n",
|
||||
" model=MODEL_NAME,\n",
|
||||
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
|
||||
" sampling_params={\n",
|
||||
|
@ -130,14 +130,7 @@
|
|||
" \"api_key\": BRAVE_SEARCH_API_KEY,\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" tool_choice=\"auto\",\n",
|
||||
" tool_prompt_format=\"function_tag\",\n",
|
||||
" input_shields=[],\n",
|
||||
" output_shields=[],\n",
|
||||
" enable_session_persistence=False,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" agent = Agent(client, agent_config)\n",
|
||||
" session_id = agent.create_session(\"test-session\")\n",
|
||||
" print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",
|
||||
"\n",
|
||||
|
|
|
@ -40,7 +40,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
|||
ollama run llama3.2:3b-instruct-fp16 --keepalive -1m
|
||||
```
|
||||
**Note**:
|
||||
- The supported models for llama stack for now is listed in [here](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L43)
|
||||
- The models currently supported by Llama Stack are listed [here](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/models.py)
|
||||
- `keepalive -1m` is used so that ollama continues to keep the model in memory indefinitely. Otherwise, ollama frees up memory and you would have to run `ollama run` again.
|
||||
|
||||
---
|
||||
|
|
|
@ -103,7 +103,6 @@
|
|||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import (\n",
|
||||
" AgentConfig,\n",
|
||||
" AgentConfigToolSearchToolDefinition,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
|
@ -117,7 +116,8 @@
|
|||
") -> Agent:\n",
|
||||
" \"\"\"Create an agent with specified tools.\"\"\"\n",
|
||||
" print(\"Using the following model: \", model)\n",
|
||||
" agent_config = AgentConfig(\n",
|
||||
" return Agent(\n",
|
||||
" client, \n",
|
||||
" model=model,\n",
|
||||
" instructions=instructions,\n",
|
||||
" sampling_params={\n",
|
||||
|
@ -126,12 +126,7 @@
|
|||
" },\n",
|
||||
" },\n",
|
||||
" tools=tools,\n",
|
||||
" tool_choice=\"auto\",\n",
|
||||
" tool_prompt_format=\"json\",\n",
|
||||
" enable_session_persistence=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return Agent(client, agent_config)\n"
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -360,9 +355,9 @@
|
|||
" # Create the agent with the tool\n",
|
||||
" weather_tool = WeatherTool()\n",
|
||||
"\n",
|
||||
" agent_config = AgentConfig(\n",
|
||||
" agent = Agent(\n",
|
||||
" client=client, \n",
|
||||
" model=LLAMA31_8B_INSTRUCT,\n",
|
||||
" # model=model_name,\n",
|
||||
" instructions=\"\"\"\n",
|
||||
" You are a weather assistant that can provide weather information.\n",
|
||||
" Always specify the location clearly in your responses.\n",
|
||||
|
@ -373,16 +368,9 @@
|
|||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" tools=[weather_tool.get_tool_definition()],\n",
|
||||
" tool_choice=\"auto\",\n",
|
||||
" tool_prompt_format=\"json\",\n",
|
||||
" input_shields=[],\n",
|
||||
" output_shields=[],\n",
|
||||
" enable_session_persistence=True,\n",
|
||||
" tools=[weather_tool],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" agent = Agent(client=client, agent_config=agent_config, custom_tools=[weather_tool])\n",
|
||||
"\n",
|
||||
" return agent\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
|
|
@ -41,16 +41,36 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
"""An attachment to an agent turn.
|
||||
|
||||
:param content: The content of the attachment.
|
||||
:param mime_type: The MIME type of the attachment.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""A document to be used by an agent.
|
||||
|
||||
:param content: The content of the document.
|
||||
:param mime_type: The MIME type of the document.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class StepCommon(BaseModel):
|
||||
"""A common step in an agent turn.
|
||||
|
||||
:param turn_id: The ID of the turn.
|
||||
:param step_id: The ID of the step.
|
||||
:param started_at: The time the step started.
|
||||
:param completed_at: The time the step completed.
|
||||
"""
|
||||
|
||||
turn_id: str
|
||||
step_id: str
|
||||
started_at: Optional[datetime] = None
|
||||
|
@ -58,6 +78,14 @@ class StepCommon(BaseModel):
|
|||
|
||||
|
||||
class StepType(Enum):
|
||||
"""Type of the step in an agent turn.
|
||||
|
||||
:cvar inference: The step is an inference step that calls an LLM.
|
||||
:cvar tool_execution: The step is a tool execution step that executes a tool call.
|
||||
:cvar shield_call: The step is a shield call step that checks for safety violations.
|
||||
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
|
||||
"""
|
||||
|
||||
inference = "inference"
|
||||
tool_execution = "tool_execution"
|
||||
shield_call = "shield_call"
|
||||
|
@ -66,6 +94,11 @@ class StepType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class InferenceStep(StepCommon):
|
||||
"""An inference step in an agent turn.
|
||||
|
||||
:param model_response: The response from the LLM.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
||||
|
@ -74,6 +107,12 @@ class InferenceStep(StepCommon):
|
|||
|
||||
@json_schema_type
|
||||
class ToolExecutionStep(StepCommon):
|
||||
"""A tool execution step in an agent turn.
|
||||
|
||||
:param tool_calls: The tool calls to execute.
|
||||
:param tool_responses: The tool responses from the tool calls.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
|
||||
tool_calls: List[ToolCall]
|
||||
tool_responses: List[ToolResponse]
|
||||
|
@ -81,13 +120,25 @@ class ToolExecutionStep(StepCommon):
|
|||
|
||||
@json_schema_type
|
||||
class ShieldCallStep(StepCommon):
|
||||
"""A shield call step in an agent turn.
|
||||
|
||||
:param violation: The violation from the shield call.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
|
||||
violation: Optional[SafetyViolation]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryRetrievalStep(StepCommon):
|
||||
"""A memory retrieval step in an agent turn.
|
||||
|
||||
:param vector_db_ids: The IDs of the vector databases to retrieve context from.
|
||||
:param inserted_context: The context retrieved from the vector databases.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
|
||||
# TODO: should this be List[str]?
|
||||
vector_db_ids: str
|
||||
inserted_context: InterleavedContent
|
||||
|
||||
|
@ -148,7 +199,7 @@ AgentToolGroup = register_schema(
|
|||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
|
||||
|
||||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
|
@ -183,6 +234,23 @@ class AgentConfig(AgentConfigCommon):
|
|||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Agent(BaseModel):
|
||||
agent_id: str
|
||||
agent_config: AgentConfig
|
||||
created_at: datetime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentsResponse(BaseModel):
|
||||
data: List[Agent]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentSessionsResponse(BaseModel):
|
||||
data: List[Session]
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: Optional[str] = None
|
||||
|
||||
|
@ -302,7 +370,7 @@ class AgentTurnResumeRequest(BaseModel):
|
|||
agent_id: str
|
||||
session_id: str
|
||||
turn_id: str
|
||||
tool_responses: List[ToolResponseMessage]
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]]
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
||||
|
@ -335,7 +403,13 @@ class Agents(Protocol):
|
|||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgentCreateResponse: ...
|
||||
) -> AgentCreateResponse:
|
||||
"""Create an agent with the given configuration.
|
||||
|
||||
:param agent_config: The configuration for the agent.
|
||||
:returns: An AgentCreateResponse with the agent ID.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
|
||||
async def create_agent_turn(
|
||||
|
@ -352,7 +426,19 @@ class Agents(Protocol):
|
|||
documents: Optional[List[Document]] = None,
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||
"""Create a new turn for an agent.
|
||||
|
||||
:param agent_id: The ID of the agent to create the turn for.
|
||||
:param session_id: The ID of the session to create the turn for.
|
||||
:param messages: List of messages to start the turn with.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param documents: (Optional) List of documents to create the turn with.
|
||||
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
|
||||
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
|
||||
:returns: If stream=False, returns a Turn object.
|
||||
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
|
||||
"""
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
|
@ -363,7 +449,7 @@ class Agents(Protocol):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: List[ToolResponseMessage],
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
|
||||
stream: Optional[bool] = False,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||
"""Resume an agent turn with executed tool call responses.
|
||||
|
@ -374,6 +460,7 @@ class Agents(Protocol):
|
|||
:param session_id: The ID of the session to resume.
|
||||
:param turn_id: The ID of the turn to resume.
|
||||
:param tool_responses: The tool call responses to resume the turn with.
|
||||
NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
|
||||
:param stream: Whether to stream the response.
|
||||
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
|
||||
"""
|
||||
|
@ -388,7 +475,15 @@ class Agents(Protocol):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
) -> Turn: ...
|
||||
) -> Turn:
|
||||
"""Retrieve an agent turn by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to get the turn for.
|
||||
:param session_id: The ID of the session to get the turn for.
|
||||
:param turn_id: The ID of the turn to get.
|
||||
:returns: A Turn.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
|
||||
|
@ -400,14 +495,30 @@ class Agents(Protocol):
|
|||
session_id: str,
|
||||
turn_id: str,
|
||||
step_id: str,
|
||||
) -> AgentStepResponse: ...
|
||||
) -> AgentStepResponse:
|
||||
"""Retrieve an agent step by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to get the step for.
|
||||
:param session_id: The ID of the session to get the step for.
|
||||
:param turn_id: The ID of the turn to get the step for.
|
||||
:param step_id: The ID of the step to get.
|
||||
:returns: An AgentStepResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session", method="POST")
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse: ...
|
||||
) -> AgentSessionCreateResponse:
|
||||
"""Create a new session for an agent.
|
||||
|
||||
:param agent_id: The ID of the agent to create the session for.
|
||||
:param session_name: The name of the session to create.
|
||||
:returns: An AgentSessionCreateResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
|
||||
async def get_agents_session(
|
||||
|
@ -415,17 +526,64 @@ class Agents(Protocol):
|
|||
session_id: str,
|
||||
agent_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session: ...
|
||||
) -> Session:
|
||||
"""Retrieve an agent session by its ID.
|
||||
|
||||
:param session_id: The ID of the session to get.
|
||||
:param agent_id: The ID of the agent to get the session for.
|
||||
:param turn_ids: (Optional) List of turn IDs to filter the session by.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
|
||||
async def delete_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Delete an agent session by its ID.
|
||||
|
||||
:param session_id: The ID of the session to delete.
|
||||
:param agent_id: The ID of the agent to delete the session for.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE")
|
||||
async def delete_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Delete an agent by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents", method="GET")
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
"""List all agents.
|
||||
|
||||
:returns: A ListAgentsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="GET")
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
"""Describe an agent by its ID.
|
||||
|
||||
:param agent_id: ID of the agent.
|
||||
:returns: An Agent object describing the agent.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET")
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
"""List all session(s) of a given agent.
|
||||
|
||||
:param agent_id: The ID of the agent to list sessions for.
|
||||
:returns: A ListAgentSessionsResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -40,7 +40,7 @@ class BatchInference(Protocol):
|
|||
self,
|
||||
model: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> BatchCompletionResponse: ...
|
||||
|
@ -50,7 +50,7 @@ class BatchInference(Protocol):
|
|||
self,
|
||||
model: str,
|
||||
messages_batch: List[List[Message]],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = list,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
|
@ -14,6 +14,14 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
@json_schema_type
|
||||
class PaginatedRowsResult(BaseModel):
|
||||
"""
|
||||
A paginated list of rows from a dataset.
|
||||
|
||||
:param rows: The rows in the current page.
|
||||
:param total_count: The total number of rows in the dataset.
|
||||
:param next_page_token: The token to get the next page of rows.
|
||||
"""
|
||||
|
||||
# the rows obey the DatasetSchema for the given dataset
|
||||
rows: List[Dict[str, Any]]
|
||||
total_count: int
|
||||
|
@ -36,7 +44,15 @@ class DatasetIO(Protocol):
|
|||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult: ...
|
||||
) -> PaginatedRowsResult:
|
||||
"""Get a paginated list of rows from a dataset.
|
||||
|
||||
:param dataset_id: The ID of the dataset to get the rows from.
|
||||
:param rows_in_page: The number of rows to get per page.
|
||||
:param page_token: The token to get the next page of rows.
|
||||
:param filter_condition: (Optional) A condition to filter the rows by.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasetio/rows", method="POST")
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
|
||||
|
|
|
@ -19,6 +19,13 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
|
||||
@json_schema_type
|
||||
class ModelCandidate(BaseModel):
|
||||
"""A model candidate for evaluation.
|
||||
|
||||
:param model: The model ID to evaluate.
|
||||
:param sampling_params: The sampling parameters for the model.
|
||||
:param system_message: (Optional) The system message providing instructions or context to the model.
|
||||
"""
|
||||
|
||||
type: Literal["model"] = "model"
|
||||
model: str
|
||||
sampling_params: SamplingParams
|
||||
|
@ -27,6 +34,11 @@ class ModelCandidate(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AgentCandidate(BaseModel):
|
||||
"""An agent candidate for evaluation.
|
||||
|
||||
:param config: The configuration for the agent candidate.
|
||||
"""
|
||||
|
||||
type: Literal["agent"] = "agent"
|
||||
config: AgentConfig
|
||||
|
||||
|
@ -39,6 +51,13 @@ EvalCandidate = register_schema(
|
|||
|
||||
@json_schema_type
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""A benchmark configuration for evaluation.
|
||||
|
||||
:param eval_candidate: The candidate to evaluate.
|
||||
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
|
||||
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
|
||||
"""
|
||||
|
||||
eval_candidate: EvalCandidate
|
||||
scoring_params: Dict[str, ScoringFnParams] = Field(
|
||||
description="Map between scoring function id and parameters for each scoring function you want to run",
|
||||
|
@ -53,18 +72,32 @@ class BenchmarkConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class EvaluateResponse(BaseModel):
|
||||
"""The response from an evaluation.
|
||||
|
||||
:param generations: The generations from the evaluation.
|
||||
:param scores: The scores from the evaluation.
|
||||
"""
|
||||
|
||||
generations: List[Dict[str, Any]]
|
||||
# each key in the dict is a scoring function name
|
||||
scores: Dict[str, ScoringResult]
|
||||
|
||||
|
||||
class Eval(Protocol):
|
||||
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job: ...
|
||||
) -> Job:
|
||||
"""Run an evaluation on a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: The job that was created to run the evaluation.
|
||||
"""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
|
||||
async def evaluate_rows(
|
||||
|
@ -73,13 +106,40 @@ class Eval(Protocol):
|
|||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse: ...
|
||||
) -> EvaluateResponse:
|
||||
"""Evaluate a list of rows on a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param input_rows: The rows to evaluate.
|
||||
:param scoring_functions: The scoring functions to use for the evaluation.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: EvaluateResponse object containing generations and scores
|
||||
"""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: ...
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
|
||||
"""Get the status of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the status of.
|
||||
:return: The status of the evaluation job.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None: ...
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
"""Cancel a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to cancel.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
"""Get the result of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the result of.
|
||||
:return: The result of the job.
|
||||
"""
|
||||
|
|
|
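Taken together, the documented routes describe a simple lifecycle: start a job, poll it, fetch its result. A minimal sketch of driving the protocol above (the `eval_api` object and the `job.job_id` field are illustrative assumptions; only the method signatures come from the diff):

```python
# Hypothetical driver for the Eval protocol documented above.
async def run_benchmark(eval_api, benchmark_id: str, benchmark_config) -> None:
    job = await eval_api.run_eval(benchmark_id=benchmark_id, benchmark_config=benchmark_config)
    status = await eval_api.job_status(benchmark_id=benchmark_id, job_id=job.job_id)  # job_id field assumed
    if status is not None:
        result = await eval_api.job_result(benchmark_id=benchmark_id, job_id=job.job_id)
        print(result.generations, result.scores)  # EvaluateResponse fields from the diff
```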
@@ -278,7 +278,7 @@ ResponseFormat = register_schema(
class CompletionRequest(BaseModel):
    model: str
    content: InterleavedContent
    sampling_params: Optional[SamplingParams] = SamplingParams()
    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
    response_format: Optional[ResponseFormat] = None
    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None

@@ -357,7 +357,7 @@ class ToolConfig(BaseModel):
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    sampling_params: Optional[SamplingParams] = SamplingParams()
    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)

    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)

@@ -444,7 +444,7 @@ class Inference(Protocol):
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,

@@ -467,7 +467,7 @@ class Inference(Protocol):
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
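The recurring change in this file swaps `SamplingParams()` defaults for `Field(default_factory=SamplingParams)` on the request models and for `None` on the protocol methods, so the default is constructed per request rather than evaluated once at class/function definition time. A self-contained sketch of the function-signature half of that pitfall (illustrative only, not project code):

```python
# Sketch (not project code) of why mutable defaults in function signatures are
# replaced with None: the default expression runs once, at definition time.
class Params:
    def __init__(self):
        self.stop = []

def bad_complete(params=Params()):      # one Params object shared by all calls
    params.stop.append("###")
    return len(params.stop)

def good_complete(params=None):         # mirrors the diff: build it per call
    if params is None:
        params = Params()
    params.stop.append("###")
    return len(params.stop)

assert bad_complete() == 1 and bad_complete() == 2    # state leaks across calls
assert good_complete() == 1 and good_complete() == 1  # fresh default each call
```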
@@ -17,6 +17,13 @@ ScoringResultRow = Dict[str, Any]

@json_schema_type
class ScoringResult(BaseModel):
    """
    A scoring result for a single row.

    :param score_rows: The scoring result for each row. Each row is a map of column name to value.
    :param aggregated_results: Map of metric name to aggregated value
    """

    score_rows: List[ScoringResultRow]
    # aggregated metrics to value
    aggregated_results: Dict[str, Any]

@@ -30,6 +37,12 @@ class ScoreBatchResponse(BaseModel):

@json_schema_type
class ScoreResponse(BaseModel):
    """
    The response from scoring.

    :param results: A map of scoring function name to ScoringResult.
    """

    # each key in the dict is a scoring function name
    results: Dict[str, ScoringResult]

@@ -55,4 +68,11 @@ class Scoring(Protocol):
        self,
        input_rows: List[Dict[str, Any]],
        scoring_functions: Dict[str, Optional[ScoringFnParams]],
    ) -> ScoreResponse: ...
    ) -> ScoreResponse:
        """Score a list of rows.

        :param input_rows: The rows to score.
        :param scoring_functions: The scoring functions to use for the scoring.
        :return: ScoreResponse object containing rows and aggregated results
        """
        ...
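For orientation, a short sketch of how the documented shapes fit together when calling `score()` (the `scoring_api` object, the scoring-function id, and the row fields are illustrative assumptions):

```python
# Hypothetical driver for the Scoring protocol documented above.
async def score_rows(scoring_api):
    rows = [{"input_query": "2+2", "generated_answer": "4", "expected_answer": "4"}]
    response = await scoring_api.score(
        input_rows=rows,
        scoring_functions={"basic::equality": None},  # None -> default params for that fn
    )
    result = response.results["basic::equality"]      # ScoringResult
    return result.score_rows, result.aggregated_results
```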
@ -64,7 +64,7 @@ class ModelDescribe(Subcommand):
|
|||
]
|
||||
|
||||
if model.recommended_sampling_params is not None:
|
||||
sampling_params = model.recommended_sampling_params.dict()
|
||||
sampling_params = model.recommended_sampling_params.model_dump()
|
||||
for k in ("max_tokens", "repetition_penalty"):
|
||||
del sampling_params[k]
|
||||
rows.append(
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.cli.subcommand import Subcommand
|
|||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class ModelPromptFormat(Subcommand):
|
||||
|
@ -44,6 +44,12 @@ class ModelPromptFormat(Subcommand):
|
|||
default="llama3_1",
|
||||
help="Model Family (llama3_1, llama3_X, etc.)",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-l",
|
||||
"--list",
|
||||
action="store_true",
|
||||
help="List all available models",
|
||||
)
|
||||
|
||||
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
|
||||
import importlib.resources
|
||||
|
|
|
@ -39,7 +39,7 @@ from llama_stack.distribution.resolver import InvalidProviderError
|
|||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
@ -170,7 +170,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
|
||||
if build_config.image_type == ImageType.container.value and not args.image_name:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
|
||||
cprint(
|
||||
"Please specify --image-name when building a container from a config file",
|
||||
color="red",
|
||||
|
@ -226,7 +226,7 @@ def _generate_run_config(
|
|||
"""
|
||||
apis = list(build_config.distribution_spec.providers.keys())
|
||||
run_config = StackRunConfig(
|
||||
container_image=(image_name if build_config.image_type == ImageType.container.value else None),
|
||||
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
|
||||
image_name=image_name,
|
||||
apis=apis,
|
||||
providers={},
|
||||
|
@ -279,16 +279,16 @@ def _run_stack_build_command_from_build_config(
|
|||
template_name: Optional[str] = None,
|
||||
config_path: Optional[str] = None,
|
||||
) -> str:
|
||||
if build_config.image_type == ImageType.container.value:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
if template_name:
|
||||
image_name = f"distribution-{template_name}"
|
||||
else:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a container image without a template")
|
||||
elif build_config.image_type == ImageType.conda.value:
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a conda image")
|
||||
elif build_config.image_type == ImageType.venv.value:
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
|
||||
image_name = "__system__"
|
||||
if not image_name:
|
||||
|
|
|
@ -16,7 +16,7 @@ class StackBuild(Subcommand):
|
|||
"build",
|
||||
prog="llama stack build",
|
||||
description="Build a Llama stack container",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_stack_build_command)
|
||||
|
|
|
@ -5,15 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
class StackRun(Subcommand):
|
||||
|
@ -23,7 +23,7 @@ class StackRun(Subcommand):
|
|||
"run",
|
||||
prog="llama stack run",
|
||||
description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_stack_run_cmd)
|
||||
|
@ -37,12 +37,13 @@ class StackRun(Subcommand):
|
|||
self.parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321",
|
||||
help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT.",
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--image-name",
|
||||
type=str,
|
||||
default=os.environ.get("CONDA_DEFAULT_ENV"),
|
||||
help="Name of the image to run. Defaults to the current conda environment",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
|
@ -79,12 +80,8 @@ class StackRun(Subcommand):
|
|||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
||||
from llama_stack.distribution.build import ImageType
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
BUILDS_BASE_DIR,
|
||||
DISTRIBS_BASE_DIR,
|
||||
)
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
|
||||
|
||||
config_file = Path(args.config)
|
||||
|
@ -97,14 +94,6 @@ class StackRun(Subcommand):
|
|||
if config_file.exists():
|
||||
template_name = args.config
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to conda dir
|
||||
config_file = Path(BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to container dir
|
||||
config_file = Path(BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
||||
|
|
|
@ -16,7 +16,7 @@ from termcolor import cprint
|
|||
from llama_stack.distribution.datatypes import BuildConfig, Provider
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.utils.exec import run_command, run_with_pty
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -95,7 +95,7 @@ def build_image(
|
|||
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
|
||||
if build_config.image_type == ImageType.container.value:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
|
||||
args = [
|
||||
script,
|
||||
|
@ -104,7 +104,7 @@ def build_image(
|
|||
container_base,
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == ImageType.conda.value:
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
|
||||
args = [
|
||||
script,
|
||||
|
@ -112,7 +112,7 @@ def build_image(
|
|||
str(build_file_path),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == ImageType.venv.value:
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
|
||||
args = [
|
||||
script,
|
||||
|
|
|
@ -39,7 +39,7 @@ def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provi
|
|||
return Provider(
|
||||
provider_id=provider.provider_id,
|
||||
provider_type=provider.provider_type,
|
||||
config=cfg.dict(),
|
||||
config=cfg.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -32,7 +32,10 @@ from termcolor import cprint
|
|||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.request_headers import (
|
||||
preserve_headers_context_async_generator,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.stack import (
|
||||
|
@ -160,6 +163,9 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
except StopAsyncIteration:
|
||||
pass
|
||||
finally:
|
||||
pending = asyncio.all_tasks(loop)
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
loop.close()
|
||||
|
||||
return sync_generator()
|
||||
|
@ -262,21 +268,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
if not self.endpoint_impls:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
# Create headers with provider data if available
|
||||
headers = {}
|
||||
if self.provider_data:
|
||||
set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
|
||||
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
|
||||
|
||||
if stream:
|
||||
response = await self._call_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream_cls=stream_cls,
|
||||
)
|
||||
else:
|
||||
response = await self._call_non_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
)
|
||||
return response
|
||||
# Use context manager for provider data
|
||||
with request_provider_data_context(headers):
|
||||
if stream:
|
||||
response = await self._call_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream_cls=stream_cls,
|
||||
)
|
||||
else:
|
||||
response = await self._call_non_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
)
|
||||
return response
|
||||
|
||||
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
|
||||
"""Find the matching endpoint implementation for a given method and path.
|
||||
|
@ -374,9 +384,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
finally:
|
||||
await end_trace()
|
||||
|
||||
# Wrap the generator to preserve context across iterations
|
||||
wrapped_gen = preserve_headers_context_async_generator(gen())
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=gen(),
|
||||
content=wrapped_gen,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
|
|
|
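The switch from `set_request_provider_data` to `request_provider_data_context` plus `preserve_headers_context_async_generator` exists because a context variable set only for the request scope is already gone by the time a streaming generator is consumed. A self-contained illustration of that behavior (a simplified stand-in, not project code):

```python
import asyncio
import contextvars

var = contextvars.ContextVar("provider_data", default=None)

async def events():
    for i in range(2):
        # Each resume runs in the caller's *current* context, so once the
        # request-scoped value has been reset this reads the default (None).
        yield (i, var.get())

async def main():
    token = var.set({"api_key": "..."})   # request context begins
    gen = events()
    var.reset(token)                      # request context ends before consumption
    print([item async for item in gen])   # prints [(0, None), (1, None)]

asyncio.run(main())
```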
@@ -4,16 +4,62 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import contextvars
import json
import logging
import threading
from typing import Any, Dict
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar

from .utils.dynamic import instantiate_class_type

log = logging.getLogger(__name__)

_THREAD_LOCAL = threading.local()
# Context variable for request provider data
_provider_data_var = contextvars.ContextVar("provider_data", default=None)


class RequestProviderDataContext(ContextManager):
    """Context manager for request provider data"""

    def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
        self.provider_data = provider_data
        self.token = None

    def __enter__(self):
        # Save the current value and set the new one
        self.token = _provider_data_var.set(self.provider_data)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the previous value
        if self.token is not None:
            _provider_data_var.reset(self.token)


T = TypeVar("T")


def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """
    Wraps an async generator to preserve request headers context variables across iterations.

    This ensures that context variables set during generator creation are
    available during each iteration of the generator, even if the original
    context manager has exited.
    """
    # Capture the current context value right now
    context_value = _provider_data_var.get()

    async def wrapper():
        while True:
            # Set context before each anext() call
            _ = _provider_data_var.set(context_value)
            try:
                item = await gen.__anext__()
                yield item
            except StopAsyncIteration:
                break

    return wrapper()


class NeedsRequestProviderData:

@@ -26,7 +72,7 @@ class NeedsRequestProviderData:
        if not validator_class:
            raise ValueError(f"Provider {provider_type} does not have a validator")

        val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
        val = _provider_data_var.get()
        if not val:
            return None

@@ -36,25 +82,32 @@ class NeedsRequestProviderData:
            return provider_data
        except Exception as e:
            log.error(f"Error parsing provider data: {e}")
            return None


def set_request_provider_data(headers: Dict[str, str]):
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
    """Parse provider data from request headers"""
    keys = [
        "X-LlamaStack-Provider-Data",
        "x-llamastack-provider-data",
    ]
    val = None
    for key in keys:
        val = headers.get(key, None)
        if val:
            break

    if not val:
        return
        return None

    try:
        val = json.loads(val)
        return json.loads(val)
    except json.JSONDecodeError:
        log.error("Provider data not encoded as a JSON object!", val)
        return
        log.error("Provider data not encoded as a JSON object!")
        return None

    _THREAD_LOCAL.provider_data_header_value = val

def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
    """Context manager that sets request provider data from headers for the duration of the context"""
    provider_data = parse_request_provider_data(headers)
    return RequestProviderDataContext(provider_data)
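A sketch of how the new helpers compose in a request handler (the handler shape is an assumption; the two functions and the module path come from the diff above):

```python
from llama_stack.distribution.request_headers import (
    preserve_headers_context_async_generator,
    request_provider_data_context,
)

async def handle_request(headers, make_event_stream):
    # Provider data parsed from the headers is visible only inside this block...
    with request_provider_data_context(headers):
        stream = preserve_headers_context_async_generator(make_event_stream())
    # ...but the wrapped generator re-applies it on every iteration, so it can
    # be consumed after the context manager has already exited.
    async for event in stream:
        yield event
```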
@ -7,7 +7,6 @@ import importlib
|
|||
import inspect
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
@ -35,6 +34,7 @@ from llama_stack.distribution.datatypes import (
|
|||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
BenchmarksProtocolPrivate,
|
||||
|
@ -50,6 +50,8 @@ from llama_stack.providers.datatypes import (
|
|||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class InvalidProviderError(Exception):
|
||||
pass
|
||||
|
@ -184,7 +186,7 @@ def validate_and_prepare_providers(
|
|||
specs = {}
|
||||
for provider in providers:
|
||||
if not provider.provider_id or provider.provider_id == "__disabled__":
|
||||
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||
continue
|
||||
|
||||
validate_provider(provider, api, provider_registry)
|
||||
|
@ -206,11 +208,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
|
|||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
logcat.error("core", p.deprecation_error)
|
||||
logger.error(p.deprecation_error)
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
elif p.deprecation_warning:
|
||||
logcat.warning(
|
||||
"core",
|
||||
logger.warning(
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
)
|
||||
|
||||
|
@ -244,9 +245,10 @@ def sort_providers_by_deps(
|
|||
)
|
||||
)
|
||||
|
||||
logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
|
||||
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
logcat.debug("core", f" {api_str} => {provider.provider_id}")
|
||||
logger.debug(f" {api_str} => {provider.provider_id}")
|
||||
logger.debug("")
|
||||
return sorted_providers
|
||||
|
||||
|
||||
|
@ -387,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
obj_params = set(obj_sig.parameters)
|
||||
obj_params.discard("self")
|
||||
if not (proto_params <= obj_params):
|
||||
logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
missing_methods.append((name, "signature_mismatch"))
|
||||
else:
|
||||
# Check if the method is actually implemented in the class
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
|
@ -52,8 +51,11 @@ from llama_stack.apis.tools import (
|
|||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class VectorIORouter(VectorIO):
|
||||
"""Routes to an provider based on the vector db identifier"""
|
||||
|
@ -62,15 +64,15 @@ class VectorIORouter(VectorIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing VectorIORouter")
|
||||
logger.debug("Initializing VectorIORouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "VectorIORouter.initialize")
|
||||
logger.debug("VectorIORouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "VectorIORouter.shutdown")
|
||||
logger.debug("VectorIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_vector_db(
|
||||
|
@ -81,10 +83,7 @@ class VectorIORouter(VectorIO):
|
|||
provider_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
|
||||
)
|
||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
await self.routing_table.register_vector_db(
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
|
@ -99,8 +98,7 @@ class VectorIORouter(VectorIO):
|
|||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
logger.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
@ -111,7 +109,7 @@ class VectorIORouter(VectorIO):
|
|||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryChunksResponse:
|
||||
logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
|
||||
|
||||
|
@ -122,15 +120,15 @@ class InferenceRouter(Inference):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing InferenceRouter")
|
||||
logger.debug("Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "InferenceRouter.initialize")
|
||||
logger.debug("InferenceRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "InferenceRouter.shutdown")
|
||||
logger.debug("InferenceRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_model(
|
||||
|
@ -141,8 +139,7 @@ class InferenceRouter(Inference):
|
|||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
logger.debug(
|
||||
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
@ -151,7 +148,7 @@ class InferenceRouter(Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = None,
|
||||
|
@ -160,10 +157,11 @@ class InferenceRouter(Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
logcat.debug(
|
||||
"core",
|
||||
logger.debug(
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
@ -217,13 +215,14 @@ class InferenceRouter(Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
logcat.debug(
|
||||
"core",
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
logger.debug(
|
||||
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
|
||||
)
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
|
@ -253,7 +252,7 @@ class InferenceRouter(Inference):
|
|||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
|
||||
logger.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
@ -273,15 +272,15 @@ class SafetyRouter(Safety):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing SafetyRouter")
|
||||
logger.debug("Initializing SafetyRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "SafetyRouter.initialize")
|
||||
logger.debug("SafetyRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "SafetyRouter.shutdown")
|
||||
logger.debug("SafetyRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_shield(
|
||||
|
@ -291,7 +290,7 @@ class SafetyRouter(Safety):
|
|||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Shield:
|
||||
logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
|
||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def run_shield(
|
||||
|
@ -300,7 +299,7 @@ class SafetyRouter(Safety):
|
|||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
|
@ -313,15 +312,15 @@ class DatasetIORouter(DatasetIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing DatasetIORouter")
|
||||
logger.debug("Initializing DatasetIORouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "DatasetIORouter.initialize")
|
||||
logger.debug("DatasetIORouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "DatasetIORouter.shutdown")
|
||||
logger.debug("DatasetIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def get_rows_paginated(
|
||||
|
@ -331,8 +330,7 @@ class DatasetIORouter(DatasetIO):
|
|||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
logcat.debug(
|
||||
"core",
|
||||
logger.debug(
|
||||
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
|
||||
|
@ -343,7 +341,7 @@ class DatasetIORouter(DatasetIO):
|
|||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||
dataset_id=dataset_id,
|
||||
rows=rows,
|
||||
|
@ -355,15 +353,15 @@ class ScoringRouter(Scoring):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ScoringRouter")
|
||||
logger.debug("Initializing ScoringRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "ScoringRouter.initialize")
|
||||
logger.debug("ScoringRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "ScoringRouter.shutdown")
|
||||
logger.debug("ScoringRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def score_batch(
|
||||
|
@ -372,7 +370,7 @@ class ScoringRouter(Scoring):
|
|||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
|
||||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
||||
|
@ -393,10 +391,7 @@ class ScoringRouter(Scoring):
|
|||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
|
||||
)
|
||||
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
|
||||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
|
@ -414,15 +409,15 @@ class EvalRouter(Eval):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing EvalRouter")
|
||||
logger.debug("Initializing EvalRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "EvalRouter.initialize")
|
||||
logger.debug("EvalRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "EvalRouter.shutdown")
|
||||
logger.debug("EvalRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def run_eval(
|
||||
|
@ -430,7 +425,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
|
||||
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
benchmark_config=benchmark_config,
|
||||
|
@ -443,7 +438,7 @@ class EvalRouter(Eval):
|
|||
scoring_functions: List[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=input_rows,
|
||||
|
@ -456,7 +451,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||
|
||||
async def job_cancel(
|
||||
|
@ -464,7 +459,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> None:
|
||||
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
|
@ -475,7 +470,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
|
@ -488,7 +483,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def query(
|
||||
|
@ -497,7 +492,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_ids: List[str],
|
||||
query_config: Optional[RAGQueryConfig] = None,
|
||||
) -> RAGQueryResult:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||
content, vector_db_ids, query_config
|
||||
)
|
||||
|
@ -508,9 +503,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
|
||||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
|
@ -520,7 +514,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ToolRuntimeRouter")
|
||||
logger.debug("Initializing ToolRuntimeRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
||||
|
@ -529,15 +523,15 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "ToolRuntimeRouter.initialize")
|
||||
logger.debug("ToolRuntimeRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "ToolRuntimeRouter.shutdown")
|
||||
logger.debug("ToolRuntimeRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
|
@ -546,5 +540,5 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||
|
|
|
@ -6,12 +6,9 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
@ -28,10 +25,12 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.request_headers import (
|
||||
preserve_headers_context_async_generator,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
|
@ -39,6 +38,7 @@ from llama_stack.distribution.stack import (
|
|||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
|
@ -54,8 +54,7 @@ from .endpoints import get_all_api_endpoints
|
|||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
|
||||
logcat.init()
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||
|
@ -117,78 +116,32 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
|||
)
|
||||
|
||||
|
||||
def handle_signal(app, signum, _) -> None:
|
||||
async def shutdown(app):
|
||||
"""Initiate a graceful shutdown of the application.
|
||||
|
||||
Handled by the lifespan context manager. The shutdown process involves
|
||||
shutting down all implementations registered in the application.
|
||||
"""
|
||||
Handle incoming signals and initiate a graceful shutdown of the application.
|
||||
|
||||
This function is intended to be used as a signal handler for various signals
|
||||
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
|
||||
indicating the received signal and initiate a shutdown process.
|
||||
|
||||
Args:
|
||||
app: The application instance containing implementations to be shut down.
|
||||
signum (int): The signal number received.
|
||||
frame: The current stack frame (not used in this function).
|
||||
|
||||
The shutdown process involves:
|
||||
- Shutting down all implementations registered in the application.
|
||||
- Gathering all running asyncio tasks.
|
||||
- Cancelling all gathered tasks.
|
||||
- Waiting for all tasks to finish.
|
||||
- Stopping the event loop.
|
||||
|
||||
Note:
|
||||
This function schedules the shutdown process as an asyncio task and does
|
||||
not block the current execution.
|
||||
"""
|
||||
signame = signal.Signals(signum).name
|
||||
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||
|
||||
async def shutdown():
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
# Gracefully shut down implementations
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logcat.info("server", f"Shutting down {impl_name}")
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logcat.warning("server", f"No shutdown method for {impl_name}")
|
||||
except asyncio.TimeoutError:
|
||||
logcat.exception("server", f"Shutdown timeout for {impl_name}")
|
||||
except Exception as e:
|
||||
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
|
||||
|
||||
# Gather all running tasks
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
|
||||
|
||||
# Cancel all tasks
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to finish
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logcat.exception("server", "Timeout while waiting for tasks to finish")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
loop.stop()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(shutdown())
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logcat.info("server", "Starting up")
|
||||
logger.info("Starting up")
|
||||
yield
|
||||
logcat.info("server", "Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
logger.info("Shutting down")
|
||||
await shutdown(app)
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
|
@ -204,15 +157,14 @@ async def maybe_await(value):
|
|||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
event_gen = await event_gen
|
||||
async for item in event_gen:
|
||||
async for item in await event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
logcat.info("server", "Generator cancelled")
|
||||
logger.info("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
logcat.exception("server", "Error in sse_generator")
|
||||
logger.exception("Error in sse_generator")
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
|
@ -224,18 +176,20 @@ async def sse_generator(event_gen):
|
|||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
set_request_provider_data(request.headers)
|
||||
# Use context manager for request provider data
|
||||
with request_provider_data_context(request.headers):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
logcat.exception("server", f"Error in {func.__name__}")
|
||||
raise translate_exception(e) from e
|
||||
try:
|
||||
if is_streaming:
|
||||
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error executing endpoint {route=} {method=}")
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
|
@ -264,7 +218,7 @@ class TracingMiddleware:
|
|||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
path = scope["path"]
|
||||
path = scope.get("path", "")
|
||||
await start_trace(path, {"__location__": "server"})
|
||||
try:
|
||||
return await self.app(scope, receive, send)
|
||||
|
@ -313,8 +267,6 @@ class ClientVersionMiddleware:
|
|||
|
||||
|
||||
def main():
|
||||
logcat.init()
|
||||
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
parser.add_argument(
|
||||
|
@ -354,10 +306,10 @@ def main():
|
|||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logcat.info("server", f"Setting CLI environment variable {key} => {value}")
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logcat.error("server", f"Error: {str(e)}")
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.yaml_config:
|
||||
|
@ -365,12 +317,12 @@ def main():
|
|||
config_file = Path(args.yaml_config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
logcat.info("server", f"Using config file: {config_file}")
|
||||
logger.info(f"Using config file: {config_file}")
|
||||
elif args.template:
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
logcat.info("server", f"Using template {args.template} config file: {config_file}")
|
||||
logger.info(f"Using template {args.template} config file: {config_file}")
|
||||
else:
|
||||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
|
||||
|
@ -378,10 +330,9 @@ def main():
|
|||
config = replace_env_vars(yaml.safe_load(fp))
|
||||
config = StackRunConfig(**config)
|
||||
|
||||
logcat.info("server", "Run configuration:")
|
||||
logger.info("Run configuration:")
|
||||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
for log_line in yaml.dump(safe_config, indent=2).split("\n"):
|
||||
logcat.info("server", log_line)
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
|
@ -391,7 +342,7 @@ def main():
|
|||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
except InvalidProviderError as e:
|
||||
logcat.error("server", f"Error: {str(e)}")
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if Api.telemetry in impls:
|
||||
|
@ -436,12 +387,10 @@ def main():
|
|||
)
|
||||
)
|
||||
|
||||
logcat.debug("server", f"serving APIs: {apis_to_serve}")
|
||||
logger.debug(f"serving APIs: {apis_to_serve}")
|
||||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
|
||||
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
|
||||
|
@ -463,15 +412,16 @@ def main():
|
|||
"ssl_keyfile": keyfile,
|
||||
"ssl_certfile": certfile,
|
||||
}
|
||||
logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
|
||||
logcat.info("server", f"Listening on {listen_host}:{port}")
|
||||
logger.info(f"Listening on {listen_host}:{port}")
|
||||
|
||||
uvicorn_config = {
|
||||
"app": app,
|
||||
"host": listen_host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
}
|
||||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
|
|
@ -7,12 +7,11 @@
|
|||
import importlib.resources
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batch_inference import BatchInference
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
|
@ -33,12 +32,16 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
VectorDBs,
|
||||
|
@ -99,9 +102,8 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
|||
objects_to_process = response.data if hasattr(response, "data") else response
|
||||
|
||||
for obj in objects_to_process:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
||||
logger.debug(
|
||||
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
||||
)
|
||||
|
||||
|
||||
|
@ -228,3 +230,53 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
|
|||
run_config = yaml.safe_load(path.open())
|
||||
|
||||
return StackRunConfig(**replace_env_vars(run_config))
|
||||
|
||||
|
||||
def run_config_from_adhoc_config_spec(
|
||||
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
|
||||
) -> StackRunConfig:
|
||||
"""
|
||||
Create an adhoc distribution from a list of API providers.
|
||||
|
||||
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
|
||||
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
|
||||
"""
|
||||
|
||||
api_providers = adhoc_config_spec.replace(";", ",").split(",")
|
||||
provider_registry = provider_registry or get_provider_registry()
|
||||
|
||||
distro_dir = tempfile.mkdtemp()
|
||||
provider_configs_by_api = {}
|
||||
for api_provider in api_providers:
|
||||
api_str, provider = api_provider.split("=")
|
||||
api = Api(api_str)
|
||||
|
||||
providers_by_type = provider_registry[api]
|
||||
provider_spec = providers_by_type.get(provider)
|
||||
if not provider_spec:
|
||||
provider_spec = providers_by_type.get(f"inline::{provider}")
|
||||
if not provider_spec:
|
||||
provider_spec = providers_by_type.get(f"remote::{provider}")
|
||||
|
||||
if not provider_spec:
|
||||
raise ValueError(
|
||||
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
|
||||
)
|
||||
|
||||
# call method "sample_run_config" on the provider spec config class
|
||||
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
||||
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
|
||||
|
||||
provider_configs_by_api[api_str] = [
|
||||
Provider(
|
||||
provider_id=provider,
|
||||
provider_type=provider_spec.provider_type,
|
||||
config=provider_config,
|
||||
)
|
||||
]
|
||||
config = StackRunConfig(
|
||||
image_name="distro-test",
|
||||
apis=list(provider_configs_by_api.keys()),
|
||||
providers=provider_configs_by_api,
|
||||
)
|
||||
return config
|
||||
|
|
|
@ -100,12 +100,15 @@ esac
|
|||
|
||||
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
||||
set -x
|
||||
|
||||
$PYTHON_BINARY -m llama_stack.distribution.server.server \
|
||||
--yaml-config "$yaml_config" \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
set -x
|
||||
|
||||
# Check if container command is available
|
||||
if ! is_command_available $CONTAINER_BINARY; then
|
||||
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||
|
@ -141,8 +144,6 @@ elif [[ "$env_type" == "container" ]]; then
|
|||
version_tag=$(curl -s $URL | jq -r '.info.version')
|
||||
fi
|
||||
|
||||
set -x
|
||||
|
||||
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
|
||||
-p $port:$port \
|
||||
$env_vars \
|
||||
|
|
|
@@ -17,7 +17,7 @@ llama stack run together
2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page).

```bash
$ llama-stack-client datasets register \
llama-stack-client datasets register \
--dataset-id "mmlu" \
--provider-id "huggingface" \
--url "https://huggingface.co/datasets/llamastack/evals" \

@@ -26,7 +26,7 @@ $ llama-stack-client datasets register \
```

```bash
$ llama-stack-client benchmarks register \
llama-stack-client benchmarks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \
@@ -7,7 +7,6 @@
import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file

@@ -124,13 +123,14 @@ def rag_chat_page():
    else:
        strategy = {"type": "greedy"}

    agent_config = AgentConfig(
    agent = Agent(
        llama_stack_api.client,
        model=selected_model,
        instructions=system_prompt,
        sampling_params={
            "strategy": strategy,
        },
        toolgroups=[
        tools=[
            dict(
                name="builtin::rag/knowledge_search",
                args={

@@ -138,12 +138,7 @@ def rag_chat_page():
                },
            )
        ],
        tool_choice="auto",
        tool_prompt_format="json",
        enable_session_persistence=False,
    )

    agent = Agent(llama_stack_api.client, agent_config)
    session_id = agent.create_session("rag-session")

    # Chat input
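The Playground change above switches from building an `AgentConfig` to passing the configuration directly to the `Agent` constructor. A condensed sketch of the new call shape, with placeholder values where the diff elides them:

```python
from llama_stack_client.lib.agents.agent import Agent

# Hypothetical values for illustration; the real page fills these from the UI.
agent = Agent(
    llama_stack_api.client,
    model=selected_model,
    instructions=system_prompt,
    sampling_params={"strategy": {"type": "greedy"}},
    tools=[
        dict(
            name="builtin::rag/knowledge_search",
            args={},  # tool arguments are elided in the hunk above
        )
    ],
)
session_id = agent.create_session("rag-session")
```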
|
@@ -13,6 +13,4 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"

BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"

RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"

@@ -20,14 +20,14 @@ import importlib
import json
from pathlib import Path

from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.distribution.utils.image_types import LlamaStackImageType


def formulate_run_args(image_type, image_name, config, template_name) -> list:
    env_name = ""
    if image_type == ImageType.container.value or config.container_image:
    if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
        env_name = f"distribution-{template_name}" if template_name else config.container_image
    elif image_type == ImageType.conda.value:
    elif image_type == LlamaStackImageType.CONDA.value:
        current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
        env_name = image_name or current_conda_env
        if not env_name:

@@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
import enum


class ImageType(Enum):
    container = "container"
    conda = "conda"
    venv = "venv"
class LlamaStackImageType(enum.Enum):
    CONTAINER = "container"
    CONDA = "conda"
    VENV = "venv"
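A small sketch of how call sites compare against the renamed enum (mirroring the `exec.py` hunk above); the `image_type` string here is just an example value:

```python
from llama_stack.distribution.utils.image_types import LlamaStackImageType

image_type = "container"  # e.g. parsed from CLI arguments
if image_type == LlamaStackImageType.CONTAINER.value:
    print("resolving a container image name")
elif image_type == LlamaStackImageType.CONDA.value:
    print("falling back to the active conda environment")
```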
198  llama_stack/log.py  Normal file

@@ -0,0 +1,198 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from logging.config import dictConfig
|
||||
from typing import Dict
|
||||
|
||||
from rich.console import Console
|
||||
from rich.errors import MarkupError
|
||||
from rich.logging import RichHandler
|
||||
from termcolor import cprint
|
||||
|
||||
# Default log level
|
||||
DEFAULT_LOG_LEVEL = logging.INFO
|
||||
|
||||
# Predefined categories
|
||||
CATEGORIES = [
|
||||
"core",
|
||||
"server",
|
||||
"router",
|
||||
"inference",
|
||||
"agents",
|
||||
"safety",
|
||||
"eval",
|
||||
"tools",
|
||||
"client",
|
||||
]
|
||||
|
||||
# Initialize category levels with default level
|
||||
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
|
||||
|
||||
|
||||
def parse_environment_config(env_config: str) -> Dict[str, int]:
|
||||
"""
|
||||
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
|
||||
|
||||
Parameters:
|
||||
env_config (str): The value of the LLAMA_STACK_LOGGING environment variable.
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
category_levels = {}
|
||||
for pair in env_config.split(";"):
|
||||
if not pair.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
category, level = pair.split("=", 1)
|
||||
category = category.strip().lower()
|
||||
level = level.strip().upper() # Convert to uppercase for logging._nameToLevel
|
||||
|
||||
level_value = logging._nameToLevel.get(level)
|
||||
if level_value is None:
|
||||
logging.warning(
|
||||
f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'."
|
||||
)
|
||||
continue
|
||||
|
||||
if category == "all":
|
||||
# Apply the log level to all categories and the root logger
|
||||
for cat in CATEGORIES:
|
||||
category_levels[cat] = level_value
|
||||
# Set the root logger's level to the specified level
|
||||
category_levels["root"] = level_value
|
||||
elif category in CATEGORIES:
|
||||
category_levels[category] = level_value
|
||||
logging.info(f"Setting '{category}' category to level '{level}'.")
|
||||
else:
|
||||
logging.warning(f"Unknown logging category: {category}. No changes made.")
|
||||
|
||||
except ValueError:
|
||||
logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")
|
||||
|
||||
return category_levels
|
||||
|
||||
|
||||
class CustomRichHandler(RichHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["console"] = Console(width=120)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def emit(self, record):
|
||||
"""Override emit to handle markup errors gracefully."""
|
||||
try:
|
||||
super().emit(record)
|
||||
except MarkupError:
|
||||
original_markup = self.markup
|
||||
self.markup = False
|
||||
try:
|
||||
super().emit(record)
|
||||
finally:
|
||||
self.markup = original_markup
|
||||
|
||||
|
||||
def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
|
||||
"""
|
||||
Configure logging based on the provided category log levels and an optional log file.
|
||||
|
||||
Parameters:
|
||||
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
|
||||
log_file (str): Path to a log file to additionally pipe the logs into
|
||||
"""
|
||||
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
|
||||
|
||||
class CategoryFilter(logging.Filter):
|
||||
"""Ensure category is always present in log records."""
|
||||
|
||||
def filter(self, record):
|
||||
if not hasattr(record, "category"):
|
||||
record.category = "uncategorized" # Default to 'uncategorized' if no category found
|
||||
return True
|
||||
|
||||
# Determine the root logger's level (default to WARNING if not specified)
|
||||
root_level = category_levels.get("root", logging.WARNING)
|
||||
|
||||
handlers = {
|
||||
"console": {
|
||||
"()": CustomRichHandler, # Use custom console handler
|
||||
"formatter": "rich",
|
||||
"rich_tracebacks": True,
|
||||
"show_time": False,
|
||||
"show_path": False,
|
||||
"markup": True,
|
||||
"filters": ["category_filter"],
|
||||
}
|
||||
}
|
||||
|
||||
# Add a file handler if log_file is set
|
||||
if log_file:
|
||||
handlers["file"] = {
|
||||
"class": "logging.FileHandler",
|
||||
"formatter": "rich",
|
||||
"filename": log_file,
|
||||
"mode": "a",
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
|
||||
logging_config = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"rich": {
|
||||
"()": logging.Formatter,
|
||||
"format": log_format,
|
||||
}
|
||||
},
|
||||
"handlers": handlers,
|
||||
"filters": {
|
||||
"category_filter": {
|
||||
"()": CategoryFilter,
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
category: {
|
||||
"handlers": list(handlers.keys()), # Apply all handlers
|
||||
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
|
||||
"propagate": False, # Disable propagation to root logger
|
||||
}
|
||||
for category in CATEGORIES
|
||||
},
|
||||
"root": {
|
||||
"handlers": list(handlers.keys()),
|
||||
"level": root_level, # Set root logger's level dynamically
|
||||
},
|
||||
}
|
||||
dictConfig(logging_config)
|
||||
|
||||
|
||||
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
|
||||
"""
|
||||
Returns a logger with the specified name and category.
|
||||
If no category is provided, defaults to 'uncategorized'.
|
||||
|
||||
Parameters:
|
||||
name (str): The name of the logger (e.g., module or filename).
|
||||
category (str): The category of the logger (default 'uncategorized').
|
||||
|
||||
Returns:
|
||||
logging.LoggerAdapter: Configured logger with category support.
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
||||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
|
||||
_category_levels.update(parse_environment_config(env_config))
|
||||
|
||||
log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
|
||||
|
||||
setup_logging(_category_levels, log_file)
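The new `llama_stack/log.py` module above replaces the ad-hoc `logcat` helper that is deleted below. A short usage sketch based on the functions it defines (`get_logger`, the `LLAMA_STACK_LOGGING` "category=level" syntax, and the optional `LLAMA_STACK_LOG_FILE`):

```python
# e.g. run with:
#   LLAMA_STACK_LOGGING="server=debug;inference=warning" LLAMA_STACK_LOG_FILE=/tmp/stack.log python app.py
from llama_stack.log import get_logger

# The category should be one of the predefined CATEGORIES ("core", "server", ...).
logger = get_logger(name=__name__, category="server")
logger.debug("verbose detail, visible because server=debug")
logger.info("normal startup message")
```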
|
|
@ -1,204 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Category-based logging utility for llama-stack.
|
||||
|
||||
This module provides a wrapper over the standard Python logging module that supports
|
||||
categorized logging with environment variable control.
|
||||
|
||||
Usage:
|
||||
from llama_stack import logcat
|
||||
logcat.info("server", "Starting up...")
|
||||
logcat.debug("inference", "Processing request...")
|
||||
|
||||
Environment variable:
|
||||
LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
|
||||
Example: "server=debug;inference=warning"
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
# ANSI color codes for terminal output
|
||||
COLORS = {
|
||||
"RESET": "\033[0m",
|
||||
"DEBUG": "\033[36m", # Cyan
|
||||
"INFO": "\033[32m", # Green
|
||||
"WARNING": "\033[33m", # Yellow
|
||||
"ERROR": "\033[31m", # Red
|
||||
"CRITICAL": "\033[35m", # Magenta
|
||||
"DIM": "\033[2m", # Dimmed text
|
||||
"YELLOW_DIM": "\033[2;33m", # Dimmed yellow
|
||||
}
|
||||
|
||||
# Static list of valid categories representing various parts of the Llama Stack
|
||||
# server codebase
|
||||
CATEGORIES = [
|
||||
"core",
|
||||
"server",
|
||||
"router",
|
||||
"inference",
|
||||
"agents",
|
||||
"safety",
|
||||
"eval",
|
||||
"tools",
|
||||
"client",
|
||||
]
|
||||
|
||||
_logger = logging.getLogger("llama_stack")
|
||||
_logger.propagate = False
|
||||
|
||||
_default_level = logging.INFO
|
||||
|
||||
# Category-level mapping (can be modified by environment variables)
|
||||
_category_levels: Dict[str, int] = {}
|
||||
|
||||
|
||||
class TerminalStreamHandler(logging.StreamHandler):
|
||||
def __init__(self, stream=None):
|
||||
super().__init__(stream)
|
||||
self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
|
||||
|
||||
def format(self, record):
|
||||
record.is_tty = self.is_tty
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""Custom formatter with colors and fixed-width level names"""
|
||||
|
||||
def format(self, record):
|
||||
levelname = record.levelname
|
||||
# Use only time with milliseconds, not date
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format
|
||||
|
||||
file_info = f"{record.filename}:{record.lineno}"
|
||||
|
||||
# Get category from extra if available
|
||||
category = getattr(record, "category", None)
|
||||
msg = record.getMessage()
|
||||
|
||||
if getattr(record, "is_tty", False):
|
||||
color = COLORS.get(levelname, COLORS["RESET"])
|
||||
if category:
|
||||
category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
|
||||
formatted_msg = (
|
||||
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
|
||||
f"{file_info:<20} {category_formatted}{msg}"
|
||||
)
|
||||
else:
|
||||
formatted_msg = (
|
||||
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
|
||||
f"{file_info:<20} {msg}"
|
||||
)
|
||||
else:
|
||||
if category:
|
||||
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
|
||||
else:
|
||||
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
|
||||
|
||||
return formatted_msg
|
||||
|
||||
|
||||
def init(default_level: int = logging.INFO) -> None:
|
||||
global _default_level, _category_levels, _logger
|
||||
|
||||
_default_level = default_level
|
||||
|
||||
_logger.setLevel(logging.DEBUG)
|
||||
_logger.handlers = [] # Clear existing handlers
|
||||
|
||||
# Add our custom handler with the colored formatter
|
||||
handler = TerminalStreamHandler()
|
||||
formatter = ColoredFormatter()
|
||||
handler.setFormatter(formatter)
|
||||
_logger.addHandler(handler)
|
||||
|
||||
for category in CATEGORIES:
|
||||
_category_levels[category] = default_level
|
||||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
for pair in env_config.split(";"):
|
||||
if not pair.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
category, level = pair.split("=", 1)
|
||||
category = category.strip().lower()
|
||||
level = level.strip().lower()
|
||||
|
||||
level_value = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"warn": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
"critical": logging.CRITICAL,
|
||||
}.get(level)
|
||||
|
||||
if level_value is None:
|
||||
_logger.warning(f"Unknown log level '{level}' for category '{category}'")
|
||||
continue
|
||||
|
||||
if category == "all":
|
||||
for cat in CATEGORIES:
|
||||
_category_levels[cat] = level_value
|
||||
else:
|
||||
if category in CATEGORIES:
|
||||
_category_levels[category] = level_value
|
||||
else:
|
||||
_logger.warning(f"Unknown logging category: {category}")
|
||||
|
||||
except ValueError:
|
||||
_logger.warning(f"Invalid logging configuration: {pair}")
|
||||
|
||||
|
||||
def _should_log(level: int, category: str) -> bool:
|
||||
category = category.lower()
|
||||
if category not in _category_levels:
|
||||
return False
|
||||
category_level = _category_levels[category]
|
||||
return level >= category_level
|
||||
|
||||
|
||||
def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
|
||||
if _should_log(level, category):
|
||||
kwargs.setdefault("extra", {})["category"] = category.lower()
|
||||
getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
|
||||
|
||||
|
||||
def debug(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def info(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.INFO, "info", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warning(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.WARNING, "warning", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warn(category: str, msg: str, *args, **kwargs) -> None:
|
||||
warning(category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def error(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.ERROR, "error", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def critical(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def exception(category: str, msg: str, *args, **kwargs) -> None:
|
||||
if _should_log(logging.ERROR, category):
|
||||
kwargs.setdefault("extra", {})["category"] = category.lower()
|
||||
_logger.exception(msg, *args, stacklevel=2, **kwargs)
|
|
@@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict
from typing import Any, Dict

from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api

from .config import MetaReferenceAgentsImplConfig


async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
    from .agents import MetaReferenceAgentsImpl

    impl = MetaReferenceAgentsImpl(

@ -17,7 +17,6 @@ from urllib.parse import urlparse
|
|||
|
||||
import httpx
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentToolGroup,
|
||||
|
@ -67,6 +66,7 @@ from llama_stack.apis.tools import (
|
|||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolCall,
|
||||
|
@ -88,6 +88,8 @@ MEMORY_QUERY_TOOL = "knowledge_search"
|
|||
WEB_SEARCH_TOOL = "web_search"
|
||||
RAG_TOOL_GROUP = "builtin::rag"
|
||||
|
||||
logger = get_logger(name=__name__, category="agents")
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
def __init__(
|
||||
|
@@ -179,7 +181,7 @@ class ChatAgent(ShieldRunnerMixin):
        return messages

    async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
        with tracing.span("create_and_execute_turn") as span:
        async with tracing.span("create_and_execute_turn") as span:
            span.set_attribute("session_id", request.session_id)
            span.set_attribute("agent_id", self.agent_id)
            span.set_attribute("request", request.model_dump_json())

@@ -189,7 +191,7 @@ class ChatAgent(ShieldRunnerMixin):
            yield chunk

    async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
        with tracing.span("resume_turn") as span:
        async with tracing.span("resume_turn") as span:
            span.set_attribute("agent_id", self.agent_id)
            span.set_attribute("session_id", request.session_id)
            span.set_attribute("turn_id", request.turn_id)

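The pattern above (and in the later `run_shields`, `inference`, and `tool_execution` hunks) converts `tracing.span` from a synchronous to an asynchronous context manager. A minimal sketch of the new call shape; the span name and attribute are placeholders:

```python
from llama_stack.providers.utils.telemetry import tracing

async def do_work(request_id: str) -> None:
    # tracing.span is now entered with "async with" inside async code paths.
    async with tracing.span("do_work") as span:
        span.set_attribute("request_id", request_id)
        # ... actual work happens here ...
```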
@ -216,13 +218,25 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
steps = []
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
if is_resume:
|
||||
messages.extend(request.tool_responses)
|
||||
if isinstance(request.tool_responses[0], ToolResponseMessage):
|
||||
tool_response_messages = request.tool_responses
|
||||
tool_responses = [
|
||||
ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
else:
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
tool_responses = request.tool_responses
|
||||
messages.extend(tool_response_messages)
|
||||
last_turn = turns[-1]
|
||||
last_turn_messages = self.turn_to_messages(last_turn)
|
||||
last_turn_messages = [
|
||||
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||
]
|
||||
last_turn_messages.extend(request.tool_responses)
|
||||
last_turn_messages.extend(tool_response_messages)
|
||||
|
||||
# get steps from the turn
|
||||
steps = last_turn.steps
|
||||
|
@ -238,14 +252,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=x.call_id,
|
||||
tool_name=x.tool_name,
|
||||
content=x.content,
|
||||
)
|
||||
for x in request.tool_responses
|
||||
],
|
||||
tool_responses=tool_responses,
|
||||
completed_at=now,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||
)
|
||||
|
@ -383,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
shields: List[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
with tracing.span("run_shields") as span:
|
||||
async with tracing.span("run_shields") as span:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
|
@ -501,7 +508,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content = ""
|
||||
stop_reason = None
|
||||
|
||||
with tracing.span("inference") as span:
|
||||
async with tracing.span("inference") as span:
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
|
@ -604,7 +611,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.")
|
||||
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||
# Do not continue the tool call loop after this point
|
||||
message.stop_reason = StopReason.end_of_turn
|
||||
|
@ -612,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
break
|
||||
|
||||
if stop_reason == StopReason.out_of_tokens:
|
||||
logcat.info("agents", "out of token budget, exiting.")
|
||||
logger.info("out of token budget, exiting.")
|
||||
yield message
|
||||
break
|
||||
|
||||
|
@ -626,16 +633,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
message.content = [message.content] + output_attachments
|
||||
yield message
|
||||
else:
|
||||
logcat.debug(
|
||||
"agents",
|
||||
f"completion message with EOM (iter: {n_iter}): {str(message)}",
|
||||
)
|
||||
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
logcat.debug(
|
||||
"agents",
|
||||
f"completion message (iter: {n_iter}) from the model: {str(message)}",
|
||||
)
|
||||
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
|
||||
# 1. Start the tool execution step and progress
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
@ -684,7 +685,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_name = tool_call.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
with tracing.span(
|
||||
async with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
|
@ -978,7 +979,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
path = urlparse(uri).path
|
||||
basename = os.path.basename(path)
|
||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
logcat.info("agents", f"Downloading {url} -> {filepath}")
|
||||
logger.info(f"Downloading {url} -> {filepath}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
|
@ -1018,7 +1019,7 @@ async def execute_tool_call_maybe(
|
|||
else:
|
||||
name = name.value
|
||||
|
||||
logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
kwargs={
|
||||
|
@ -1028,7 +1029,7 @@ async def execute_tool_call_maybe(
|
|||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
logcat.debug("agents", f"tool call {name} completed with result: {result}")
|
||||
logger.info(f"tool call {name} completed with result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ import uuid
|
|||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
|
@ -21,12 +22,15 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListAgentSessionsResponse,
|
||||
ListAgentsResponse,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
@ -83,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def get_agent(self, agent_id: str) -> ChatAgent:
|
||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||
agent_config = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
|
@ -119,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_id = await agent.create_session(session_name)
|
||||
return AgentSessionCreateResponse(
|
||||
|
@ -159,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
|
@ -168,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: List[ToolResponseMessage],
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
|
@ -187,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
return turn
|
||||
|
||||
|
@ -209,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
@ -231,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
pass
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
pass
|
||||
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
pass
|
||||
|
|
|
@@ -10,6 +10,7 @@ from typing import List

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.providers.utils.telemetry import tracing

log = logging.getLogger(__name__)

@@ -32,15 +33,14 @@ class ShieldRunnerMixin:
        self.output_shields = output_shields

    async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
        responses = await asyncio.gather(
            *[
                self.safety_api.run_shield(
        async def run_shield_with_span(identifier: str):
            async with tracing.span(f"run_shield_{identifier}"):
                return await self.safety_api.run_shield(
                    shield_id=identifier,
                    messages=messages,
                )
                for identifier in identifiers
            ]
        )

        responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
        for identifier, response in zip(identifiers, responses, strict=False):
            if not response.violation:
                continue

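For context, a hedged example of how an agent invokes this mixin (the attribute name mirrors the shield lists used elsewhere in this diff):

```python
# Inside a ShieldRunnerMixin subclass such as ChatAgent: run every configured
# input shield against the user's messages before inference proceeds.
await self.run_multiple_shields(messages, self.input_shields)
```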
@ -1,411 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentToolGroupWithArgs,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
StepType,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import URL, TextDelta
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import RunShieldResponse
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolGroupsResponse,
|
||||
ListToolsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolHost,
|
||||
ToolInvocationResult,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
MEMORY_QUERY_TOOL,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.agents import (
|
||||
MetaReferenceAgentsImpl,
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class MockInferenceAPI:
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = None,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
async def stream_response():
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta=TextDelta(text=""),
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=TextDelta(text="AI is a fascinating field..."),
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
)
|
||||
|
||||
if stream:
|
||||
return stream_response()
|
||||
else:
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
role="assistant",
|
||||
content="Mock response",
|
||||
stop_reason="end_of_turn",
|
||||
),
|
||||
logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
|
||||
)
|
||||
|
||||
|
||||
class MockSafetyAPI:
|
||||
async def run_shield(self, shield_id: str, messages: List[Message]) -> RunShieldResponse:
|
||||
return RunShieldResponse(violation=None)
|
||||
|
||||
|
||||
class MockVectorIOAPI:
|
||||
def __init__(self):
|
||||
self.chunks = {}
|
||||
|
||||
async def insert_chunks(self, vector_db_id, chunks, ttl_seconds=None):
|
||||
for chunk in chunks:
|
||||
metadata = chunk.metadata
|
||||
self.chunks[vector_db_id][metadata["document_id"]] = chunk
|
||||
|
||||
async def query_chunks(self, vector_db_id, query, params=None):
|
||||
if vector_db_id not in self.chunks:
|
||||
raise ValueError(f"Bank {vector_db_id} not found")
|
||||
|
||||
chunks = list(self.chunks[vector_db_id].values())
|
||||
scores = [1.0] * len(chunks)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class MockToolGroupsAPI:
|
||||
async def register_tool_group(self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None) -> None:
|
||||
pass
|
||||
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
||||
return ToolGroup(
|
||||
identifier=toolgroup_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
)
|
||||
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
return ListToolGroupsResponse(data=[])
|
||||
|
||||
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
|
||||
if toolgroup_id == MEMORY_TOOLGROUP:
|
||||
return ListToolsResponse(
|
||||
data=[
|
||||
Tool(
|
||||
identifier=MEMORY_QUERY_TOOL,
|
||||
provider_resource_id=MEMORY_QUERY_TOOL,
|
||||
toolgroup_id=MEMORY_TOOLGROUP,
|
||||
tool_host=ToolHost.client,
|
||||
description="Mock tool",
|
||||
provider_id="builtin::rag",
|
||||
parameters=[],
|
||||
)
|
||||
]
|
||||
)
|
||||
if toolgroup_id == CODE_INTERPRETER_TOOLGROUP:
|
||||
return ListToolsResponse(
|
||||
data=[
|
||||
Tool(
|
||||
identifier="code_interpreter",
|
||||
provider_resource_id="code_interpreter",
|
||||
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
|
||||
tool_host=ToolHost.client,
|
||||
description="Mock tool",
|
||||
provider_id="builtin::code_interpreter",
|
||||
parameters=[],
|
||||
)
|
||||
]
|
||||
)
|
||||
return ListToolsResponse(data=[])
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
return Tool(
|
||||
identifier=tool_name,
|
||||
provider_resource_id=tool_name,
|
||||
toolgroup_id="mock_group",
|
||||
tool_host=ToolHost.client,
|
||||
description="Mock tool",
|
||||
provider_id="mock_provider",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
async def unregister_tool_group(self, toolgroup_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class MockToolRuntimeAPI:
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return []
|
||||
|
||||
async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
|
||||
return ToolInvocationResult(content={"result": "Mock tool result"})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inference_api():
|
||||
return MockInferenceAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
return MockSafetyAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_io_api():
|
||||
return MockVectorIOAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_groups_api():
|
||||
return MockToolGroupsAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_runtime_api():
|
||||
return MockToolRuntimeAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def get_agents_impl(
|
||||
mock_inference_api,
|
||||
mock_safety_api,
|
||||
mock_vector_io_api,
|
||||
mock_tool_runtime_api,
|
||||
mock_tool_groups_api,
|
||||
):
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
config=MetaReferenceAgentsImplConfig(
|
||||
persistence_store=SqliteKVStoreConfig(
|
||||
db_name=sqlite_file.name,
|
||||
),
|
||||
),
|
||||
inference_api=mock_inference_api,
|
||||
safety_api=mock_safety_api,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
tool_runtime_api=mock_tool_runtime_api,
|
||||
tool_groups_api=mock_tool_groups_api,
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def get_chat_agent(get_agents_impl):
|
||||
impl = await get_agents_impl
|
||||
agent_config = AgentConfig(
|
||||
model="test_model",
|
||||
instructions="You are a helpful assistant.",
|
||||
toolgroups=[],
|
||||
tool_choice=ToolChoice.auto,
|
||||
enable_session_persistence=False,
|
||||
input_shields=["test_shield"],
|
||||
)
|
||||
response = await impl.create_agent(agent_config)
|
||||
return await impl.get_agent(response.agent_id)
|
||||
|
||||
|
||||
MEMORY_TOOLGROUP = "builtin::rag"
|
||||
CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def get_chat_agent_with_tools(get_agents_impl, request):
|
||||
impl = await get_agents_impl
|
||||
toolgroups = request.param
|
||||
agent_config = AgentConfig(
|
||||
model="test_model",
|
||||
instructions="You are a helpful assistant.",
|
||||
toolgroups=toolgroups,
|
||||
tool_choice=ToolChoice.auto,
|
||||
enable_session_persistence=False,
|
||||
input_shields=["test_shield"],
|
||||
)
|
||||
response = await impl.create_agent(agent_config)
|
||||
return await impl.get_agent(response.agent_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_agent_create_and_execute_turn(get_chat_agent):
|
||||
chat_agent = await get_chat_agent
|
||||
session_id = await chat_agent.create_session("Test Session")
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=chat_agent.agent_id,
|
||||
session_id=session_id,
|
||||
messages=[UserMessage(content="Hello")],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for response in chat_agent.create_and_execute_turn(request):
|
||||
responses.append(response)
|
||||
|
||||
assert len(responses) > 0
|
||||
assert (
|
||||
len(responses) == 7
|
||||
) # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
|
||||
assert responses[0].event.payload.turn_id is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_multiple_shields_wrapper(get_chat_agent):
|
||||
chat_agent = await get_chat_agent
|
||||
messages = [UserMessage(content="Test message")]
|
||||
shields = ["test_shield"]
|
||||
|
||||
responses = [
|
||||
chunk
|
||||
async for chunk in chat_agent.run_multiple_shields_wrapper(
|
||||
turn_id="test_turn_id",
|
||||
messages=messages,
|
||||
shields=shields,
|
||||
touchpoint="user-input",
|
||||
)
|
||||
]
|
||||
|
||||
assert len(responses) == 2 # StepStart, StepComplete
|
||||
assert responses[0].event.payload.step_type.value == "shield_call"
|
||||
assert not responses[1].event.payload.step_details.violation
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_agent_complex_turn(get_chat_agent):
|
||||
chat_agent = await get_chat_agent
|
||||
session_id = await chat_agent.create_session("Test Session")
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=chat_agent.agent_id,
|
||||
session_id=session_id,
|
||||
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for response in chat_agent.create_and_execute_turn(request):
|
||||
responses.append(response)
|
||||
|
||||
assert len(responses) > 0
|
||||
|
||||
step_types = [
|
||||
response.event.payload.step_type for response in responses if hasattr(response.event.payload, "step_type")
|
||||
]
|
||||
|
||||
assert StepType.shield_call in step_types, "Shield call step is missing"
|
||||
assert StepType.inference in step_types, "Inference step is missing"
|
||||
|
||||
event_types = [
|
||||
response.event.payload.event_type for response in responses if hasattr(response.event.payload, "event_type")
|
||||
]
|
||||
assert "turn_start" in event_types, "Start event is missing"
|
||||
assert "turn_complete" in event_types, "Complete event is missing"
|
||||
|
||||
assert any(isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload) for response in responses), (
|
||||
"Turn complete event is missing"
|
||||
)
|
||||
turn_complete_payload = next(
|
||||
response.event.payload
|
||||
for response in responses
|
||||
if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
|
||||
)
|
||||
turn = turn_complete_payload.turn
|
||||
assert turn.input_messages == request.messages, "Input messages do not match"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"toolgroups, expected_memory, expected_code_interpreter",
|
||||
[
|
||||
([], False, False), # no tools
|
||||
([MEMORY_TOOLGROUP], True, False), # memory only
|
||||
([CODE_INTERPRETER_TOOLGROUP], False, True), # code interpreter only
|
||||
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
|
||||
],
|
||||
)
|
||||
async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, expected_code_interpreter):
|
||||
impl = await get_agents_impl
|
||||
agent_config = AgentConfig(
|
||||
model="test_model",
|
||||
instructions="You are a helpful assistant.",
|
||||
toolgroups=toolgroups,
|
||||
tool_choice=ToolChoice.auto,
|
||||
enable_session_persistence=False,
|
||||
input_shields=["test_shield"],
|
||||
)
|
||||
response = await impl.create_agent(agent_config)
|
||||
chat_agent = await impl.get_agent(response.agent_id)
|
||||
|
||||
tool_defs, _ = await chat_agent._get_tool_defs()
|
||||
tool_defs_names = [t.tool_name for t in tool_defs]
|
||||
if expected_memory:
|
||||
assert MEMORY_QUERY_TOOL in tool_defs_names
|
||||
if expected_code_interpreter:
|
||||
assert BuiltinTool.code_interpreter in tool_defs_names
|
||||
if expected_memory and expected_code_interpreter:
|
||||
# override the tools for turn
|
||||
new_tool_defs, _ = await chat_agent._get_tool_defs(
|
||||
toolgroups_for_turn=[
|
||||
AgentToolGroupWithArgs(
|
||||
name=MEMORY_TOOLGROUP,
|
||||
args={"vector_dbs": ["test_vector_db"]},
|
||||
)
|
||||
]
|
||||
)
|
||||
new_tool_defs_names = [t.tool_name for t in new_tool_defs]
|
||||
assert MEMORY_QUERY_TOOL in new_tool_defs_names
|
||||
assert BuiltinTool.code_interpreter not in new_tool_defs_names
|
|
@@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from .config import LocalFSDatasetIOConfig


async def get_provider_impl(
    config: LocalFSDatasetIOConfig,
    _deps,
    _deps: Dict[str, Any],
):
    from .datasetio import LocalFSDatasetIOImpl

@@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict

from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api

from .config import MetaReferenceEvalConfig


async def get_provider_impl(
    config: MetaReferenceEvalConfig,
    deps: Dict[Api, ProviderSpec],
    deps: Dict[Api, Any],
):
    from .eval import MetaReferenceEvalImpl

|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
|
|
|
@ -136,11 +136,13 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
|
@ -244,7 +246,7 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
@ -253,6 +255,8 @@ class MetaReferenceInferenceImpl(
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
|
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
|
|||
|
||||
async def get_provider_impl(
|
||||
config: SentenceTransformersInferenceConfig,
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ class SentenceTransformersInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: str,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
@ -64,7 +64,7 @@ class SentenceTransformersInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
|
@@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any
from typing import Any, Dict

from .config import VLLMConfig


async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
    from .vllm import VLLMInferenceImpl

    impl = VLLMInferenceImpl(config)

@ -4,20 +4,19 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMConfig(BaseModel):
|
||||
"""Configuration for the vLLM inference provider."""
|
||||
"""Configuration for the vLLM inference provider.
|
||||
|
||||
Note that the model name is no longer part of this static configuration.
|
||||
You can bind an instance of this provider to a specific model with the
|
||||
``models.register()`` API call."""
|
||||
|
||||
model: str = Field(
|
||||
default="Llama3.2-3B-Instruct",
|
||||
description="Model descriptor from `llama model list`",
|
||||
)
|
||||
tensor_parallel_size: int = Field(
|
||||
default=1,
|
||||
description="Number of tensor parallel replicas (number of GPUs to use).",
|
||||
|
@ -26,32 +25,27 @@ class VLLMConfig(BaseModel):
|
|||
default=4096,
|
||||
description="Maximum number of tokens to generate.",
|
||||
)
|
||||
max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.")
|
||||
max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation.")
|
||||
enforce_eager: bool = Field(
|
||||
default=False,
|
||||
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
|
||||
)
|
||||
gpu_memory_utilization: float = Field(
|
||||
default=0.3,
|
||||
description=(
|
||||
"How much GPU memory will be allocated when this provider has finished "
|
||||
"loading, including memory that was already allocated before loading."
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
return {
|
||||
"model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:4096}",
|
||||
"max_model_len": "${env.MAX_MODEL_LEN:4096}",
|
||||
"max_num_seqs": "${env.MAX_NUM_SEQS:4}",
|
||||
"enforce_eager": "${env.ENFORCE_EAGER:False}",
|
||||
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.7}",
|
||||
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
|
||||
}
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = supported_inference_models()
|
||||
|
||||
descriptors = [m.descriptor() for m in permitted_models]
|
||||
repos = [m.huggingface_repo for m in permitted_models]
|
||||
if model not in (descriptors + repos):
|
||||
model_list = "\n\t".join(repos)
|
||||
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
|
||||
return model
|
||||
|
|
170  llama_stack/providers/inline/inference/vllm/openai_utils.py  Normal file

@@ -0,0 +1,170 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import vllm
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
GrammarResponseFormat,
|
||||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
ToolChoice,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
)
|
||||
|
||||
###############################################################################
|
||||
# This file contains OpenAI compatibility code that is currently only used
|
||||
# by the inline vLLM connector. Some or all of this code may be moved to a
|
||||
# central location at a later date.
|
||||
|
||||
|
||||
def _merge_context_into_content(message: Message) -> Message: # type: ignore
|
||||
"""
|
||||
Merge the ``context`` field of a Llama Stack ``Message`` object into
|
||||
the content field for compabilitiy with OpenAI-style APIs.
|
||||
|
||||
Generates a content string that emulates the current behavior
|
||||
of ``llama_models.llama3.api.chat_format.encode_message()``.
|
||||
|
||||
:param message: Message that may include ``context`` field
|
||||
|
||||
:returns: A version of ``message`` with any context merged into the
|
||||
``content`` field.
|
||||
"""
|
||||
if not isinstance(message, UserMessage): # Separate type check for linter
|
||||
return message
|
||||
if message.context is None:
|
||||
return message
|
||||
return UserMessage(
|
||||
role=message.role,
|
||||
# Emumate llama_models.llama3.api.chat_format.encode_message()
|
||||
content=message.content + "\n\n" + message.context,
|
||||
context=None,
|
||||
)
|
||||
|
||||
|
||||
def _llama_stack_tools_to_openai_tools(
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
|
||||
"""
|
||||
Convert the list of available tools from Llama Stack's format to vLLM's
|
||||
version of OpenAI's format.
|
||||
"""
|
||||
if tools is None:
|
||||
return []
|
||||
|
||||
result = []
|
||||
for t in tools:
|
||||
if isinstance(t.tool_name, BuiltinTool):
|
||||
raise NotImplementedError("Built-in tools not yet implemented")
|
||||
if t.parameters is None:
|
||||
parameters = None
|
||||
else: # if t.parameters is not None
|
||||
# Convert the "required" flags to a list of required params
|
||||
required_params = [k for k, v in t.parameters.items() if v.required]
|
||||
parameters = {
|
||||
"type": "object", # Mystery value that shows up in OpenAI docs
|
||||
"properties": {
|
||||
k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
|
||||
},
|
||||
"required": required_params,
|
||||
}
|
||||
|
||||
function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
|
||||
name=t.tool_name, description=t.description, parameters=parameters
|
||||
)
|
||||
|
||||
# Every tool definition is double-boxed in a ChatCompletionToolsParam
|
||||
result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
|
||||
return result
|
||||
|
||||
|
||||
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
|
||||
request: ChatCompletionRequest,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a chat completion request in Llama Stack format into an
|
||||
equivalent set of arguments to pass to an OpenAI-compatible
|
||||
chat completions API.
|
||||
|
||||
:param request: Bundled request parameters in Llama Stack format.
|
||||
|
||||
:returns: Dictionary of key-value pairs to use as an initializer
|
||||
for a dataclass or to be converted directly to JSON and sent
|
||||
over the wire.
|
||||
"""
|
||||
|
||||
converted_messages = [
|
||||
# This mystery async call makes the parent function also be async
|
||||
await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
|
||||
for m in request.messages
|
||||
]
|
||||
converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
|
||||
|
||||
# Llama will try to use built-in tools with no tool catalog, so don't enable
|
||||
# tool choice unless at least one tool is enabled.
|
||||
converted_tool_choice = "none"
|
||||
if (
|
||||
request.tool_config is not None
|
||||
and request.tool_config.tool_choice == ToolChoice.auto
|
||||
and request.tools is not None
|
||||
and len(request.tools) > 0
|
||||
):
|
||||
converted_tool_choice = "auto"
|
||||
|
||||
# TODO: Figure out what to do with the tool_prompt_format argument.
|
||||
# Other connectors appear to drop it quietly.
|
||||
|
||||
# Use Llama Stack shared code to translate sampling parameters.
|
||||
sampling_options = get_sampling_options(request.sampling_params)
|
||||
|
||||
# get_sampling_options() translates repetition penalties to an option that
|
||||
# OpenAI's APIs don't know about.
|
||||
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
|
||||
# For now, translate repetition penalties into a format that vLLM's broken
|
||||
# API will handle correctly. Two wrongs make a right...
|
||||
if "repeat_penalty" in sampling_options:
|
||||
del sampling_options["repeat_penalty"]
|
||||
if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0:
|
||||
sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty
|
||||
|
||||
# Convert a single response format into four different parameters, per
|
||||
# the OpenAI spec
|
||||
guided_decoding_options = dict()
|
||||
if request.response_format is None:
|
||||
# Use defaults
|
||||
pass
|
||||
elif isinstance(request.response_format, JsonSchemaResponseFormat):
|
||||
guided_decoding_options["guided_json"] = request.response_format.json_schema
|
||||
elif isinstance(request.response_format, GrammarResponseFormat):
|
||||
guided_decoding_options["guided_grammar"] = request.response_format.bnf
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'")
|
||||
|
||||
logprob_options = dict()
|
||||
if request.logprobs is not None:
|
||||
logprob_options["logprobs"] = request.logprobs.top_k
|
||||
|
||||
# Marshall together all the arguments for a ChatCompletionRequest
|
||||
request_options = {
|
||||
"model": request.model,
|
||||
"messages": converted_messages,
|
||||
"tools": converted_tools,
|
||||
"tool_choice": converted_tool_choice,
|
||||
"stream": request.stream,
|
||||
**sampling_options,
|
||||
**guided_decoding_options,
|
||||
**logprob_options,
|
||||
}
|
||||
|
||||
return request_options
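The returned dictionary is meant to be splatted straight into vLLM's OpenAI-compatible request dataclass, which is how ``chat_completion()`` consumes it later in this diff. A minimal sketch:

```python
# Sketch (mirrors the call site in chat_completion() below).
async def _example(request: ChatCompletionRequest):
    request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
    return vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
```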
|
|
@ -4,45 +4,71 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
# These vLLM modules contain names that overlap with Llama Stack names, so we import
|
||||
# fully-qualified names
|
||||
import vllm.entrypoints.openai.protocol
|
||||
import vllm.sampling_params
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
InterleavedContentItem,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama import sku_list
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
get_stop_reason,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
@ -50,188 +76,322 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
|
||||
from .config import VLLMConfig
|
||||
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
# Map from Hugging Face model architecture name to appropriate tool parser.
|
||||
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
|
||||
# available parsers.
|
||||
# TODO: Expand this list
|
||||
CONFIG_TYPE_TO_TOOL_PARSER = {
|
||||
"GraniteConfig": "granite",
|
||||
"MllamaConfig": "llama3_json",
|
||||
"LlamaConfig": "llama3_json",
|
||||
}
|
||||
DEFAULT_TOOL_PARSER = "pythonic"
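The table is consulted with the name of the model's Hugging Face config class, and anything unknown falls back to the default parser. A sketch of the lookup (the class name below is a hypothetical value):

```python
hf_config_class_name = "LlamaConfig"  # hypothetical; read from the loaded model's HF config
tool_parser = CONFIG_TYPE_TO_TOOL_PARSER.get(hf_config_class_name, DEFAULT_TOOL_PARSER)
# -> "llama3_json"
```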
|
||||
|
||||
|
||||
def _random_uuid() -> str:
|
||||
logger = get_logger(__name__, category="inference")
|
||||
|
||||
|
||||
def _random_uuid_str() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
def _response_format_to_guided_decoding_params(
|
||||
response_format: Optional[ResponseFormat], # type: ignore
|
||||
) -> vllm.sampling_params.GuidedDecodingParams:
|
||||
"""
|
||||
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
|
||||
|
||||
:param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
|
||||
indicating no constraints.
|
||||
:returns: The equivalent dataclass object for the low-level inference layer of vLLM.
|
||||
"""
|
||||
if response_format is None:
|
||||
# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
|
||||
# value that crashes the executor on some code paths. Use ``None`` instead.
|
||||
return None
|
||||
|
||||
# Llama Stack currently implements fewer types of constrained decoding than vLLM does.
|
||||
# Translate the types that exist and detect if Llama Stack adds new ones.
|
||||
if isinstance(response_format, JsonSchemaResponseFormat):
|
||||
return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
|
||||
elif isinstance(response_format, GrammarResponseFormat):
|
||||
# BNF grammar.
|
||||
# Llama Stack uses the parse tree of the grammar, while vLLM uses the string
|
||||
# representation of the grammar.
|
||||
raise TypeError(
|
||||
"Constrained decoding with BNF grammars is not currently implemented, because the "
|
||||
"reference implementation does not implement it."
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
|
||||
|
||||
|
||||
def _convert_sampling_params(
|
||||
sampling_params: Optional[SamplingParams],
|
||||
response_format: Optional[ResponseFormat], # type: ignore
|
||||
log_prob_config: Optional[LogProbConfig],
|
||||
) -> vllm.SamplingParams:
|
||||
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
|
||||
format."""
|
||||
# In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
|
||||
# Stack dataclasses. These defaults are different from vLLM's defaults.
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if log_prob_config is None:
|
||||
log_prob_config = LogProbConfig()
|
||||
|
||||
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
|
||||
if sampling_params.strategy.top_k == 0:
|
||||
# vLLM treats "k" differently for top-k sampling
|
||||
vllm_top_k = -1
|
||||
else:
|
||||
vllm_top_k = sampling_params.strategy.top_k
|
||||
else:
|
||||
vllm_top_k = -1
|
||||
|
||||
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
|
||||
vllm_top_p = sampling_params.strategy.top_p
|
||||
# Llama Stack only allows temperature with top-P.
|
||||
vllm_temperature = sampling_params.strategy.temperature
|
||||
else:
|
||||
vllm_top_p = 1.0
|
||||
vllm_temperature = 0.0
|
||||
|
||||
# vLLM allows top-p and top-k at the same time.
|
||||
vllm_sampling_params = vllm.SamplingParams.from_optional(
|
||||
max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
|
||||
temperature=vllm_temperature,
|
||||
top_p=vllm_top_p,
|
||||
top_k=vllm_top_k,
|
||||
repetition_penalty=sampling_params.repetition_penalty,
|
||||
guided_decoding=_response_format_to_guided_decoding_params(response_format),
|
||||
logprobs=log_prob_config.top_k,
|
||||
)
|
||||
return vllm_sampling_params
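A quick sketch of the translation for top-k sampling, using the Llama Stack dataclasses imported at the top of this file:

```python
# Illustrative only.
lls_params = SamplingParams(strategy=TopKSamplingStrategy(top_k=40), max_tokens=128)
vllm_params = _convert_sampling_params(lls_params, response_format=None, log_prob_config=None)
# vllm_params.top_k == 40 and vllm_params.max_tokens == 128;
# a Llama Stack top_k of 0 would instead map to vLLM's "disabled" value of -1.
```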
|
||||
|
||||
|
||||
class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||
"""Inference implementation for vLLM."""
|
||||
"""
|
||||
vLLM-based inference model adapter for Llama Stack with support for multiple models.
|
||||
|
||||
Requires the configuration parameters documented in the :class:`VLLMConfig` class.
|
||||
"""
|
||||
|
||||
config: VLLMConfig
|
||||
register_helper: ModelRegistryHelper
|
||||
model_ids: set[str]
|
||||
resolved_model_id: str | None
|
||||
engine: AsyncLLMEngine | None
|
||||
chat: OpenAIServingChat | None
|
||||
is_meta_llama_model: bool
|
||||
|
||||
def __init__(self, config: VLLMConfig):
|
||||
self.config = config
|
||||
logger.info(f"Config is: {self.config}")
|
||||
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
# The following are initialized when paths are bound to this provider
|
||||
self.resolved_model_id = None
|
||||
self.model_ids = set()
|
||||
self.engine = None
|
||||
self.chat = None
|
||||
self.is_meta_llama_model = False
|
||||
|
||||
async def initialize(self):
|
||||
log.info("Initializing vLLM inference provider.")
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM IMPLICIT BASE CLASS.
|
||||
# TODO: Make this class inherit from the new base class ProviderBase once that class exists.
|
||||
|
||||
# Disable usage stats reporting. This would be a surprising thing for most
|
||||
# people to find out was on by default.
|
||||
# https://docs.vllm.ai/en/latest/serving/usage_stats.html
|
||||
if "VLLM_NO_USAGE_STATS" not in os.environ:
|
||||
os.environ["VLLM_NO_USAGE_STATS"] = "1"
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Callback that is invoked through many levels of indirection during provider class
|
||||
instantiation, sometime after __init__() is called and before any model registration
|
||||
methods or methods connected to a REST API are called.
|
||||
|
||||
model = resolve_model(self.config.model)
|
||||
if model is None:
|
||||
raise ValueError(f"Unknown model {self.config.model}")
|
||||
It's not clear what assumptions the class can make about the platform's initialization
|
||||
state here that can't be made during __init__(), and vLLM can't be started until we know
|
||||
what model it's supposed to be serving, so nothing happens here currently.
|
||||
"""
|
||||
pass
|
||||
|
||||
if model.huggingface_repo is None:
|
||||
raise ValueError(f"Model {self.config.model} needs a huggingface repo")
|
||||
|
||||
# TODO -- there are a ton of options supported here ...
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model.huggingface_repo,
|
||||
tokenizer=model.huggingface_repo,
|
||||
tensor_parallel_size=self.config.tensor_parallel_size,
|
||||
enforce_eager=self.config.enforce_eager,
|
||||
gpu_memory_utilization=self.config.gpu_memory_utilization,
|
||||
guided_decoding_backend="lm-format-enforcer",
|
||||
)
|
||||
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shut down the vLLM inference adapter."""
|
||||
log.info("Shutting down vLLM inference provider.")
|
||||
if self.engine:
|
||||
async def shutdown(self) -> None:
|
||||
logger.info(f"Shutting down inline vLLM inference provider {self}.")
|
||||
if self.engine is not None:
|
||||
self.engine.shutdown_background_loop()
|
||||
self.engine = None
|
||||
self.chat = None
|
||||
self.model_ids = set()
|
||||
self.resolved_model_id = None
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE
|
||||
|
||||
# Note that the return type of the superclass method is WRONG
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
"""
|
||||
Callback that is called when the server associates an inference endpoint
|
||||
with an inference provider.
|
||||
Callback that is called when the server associates an inference endpoint with an
|
||||
inference provider.
|
||||
|
||||
:param model: Object that encapsulates parameters necessary for identifying
|
||||
a specific LLM.
|
||||
:param model: Object that encapsulates parameters necessary for identifying a specific
|
||||
LLM.
|
||||
|
||||
:returns: The input ``Model`` object. It may or may not be permissible
|
||||
to change fields before returning this object.
|
||||
:returns: The input ``Model`` object. It may or may not be permissible to change fields
|
||||
before returning this object.
|
||||
"""
|
||||
log.info(f"Registering model {model.identifier} with vLLM inference provider.")
|
||||
# The current version of this provider is hard-coded to serve only
|
||||
# the model specified in the YAML config file.
|
||||
configured_model = resolve_model(self.config.model)
|
||||
registered_model = resolve_model(model.model_id)
|
||||
logger.debug(f"In register_model({model})")
|
||||
|
||||
# First attempt to interpret the model coordinates as a Llama model name
|
||||
resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
|
||||
if resolved_llama_model is not None:
|
||||
# Load from Hugging Face repo into default local cache dir
|
||||
model_id_for_vllm = resolved_llama_model.huggingface_repo
|
||||
|
||||
# Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
|
||||
# Don't set self.is_meta_llama_model until we actually load the model.
|
||||
is_meta_llama_model = True
|
||||
else: # if resolved_llama_model is None
|
||||
# Not a Llama model name. Pass the model id through to vLLM's loader
|
||||
model_id_for_vllm = model.provider_model_id
|
||||
is_meta_llama_model = False
|
||||
|
||||
if self.resolved_model_id is not None:
|
||||
if model_id_for_vllm != self.resolved_model_id:
|
||||
raise ValueError(
|
||||
f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
|
||||
f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
|
||||
f"copies of the provider instead."
|
||||
)
|
||||
else:
|
||||
# Model already loaded
|
||||
logger.info(
|
||||
f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
|
||||
)
|
||||
self.model_ids.add(model.model_id)
|
||||
return model
|
||||
|
||||
logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
|
||||
if is_meta_llama_model:
|
||||
logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
|
||||
self.is_meta_llama_model = is_meta_llama_model
|
||||
|
||||
# If we get here, this is the first time registering a model.
|
||||
# Preload so that the first inference request won't time out.
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model_id_for_vllm,
|
||||
tokenizer=model_id_for_vllm,
|
||||
tensor_parallel_size=self.config.tensor_parallel_size,
|
||||
enforce_eager=self.config.enforce_eager,
|
||||
gpu_memory_utilization=self.config.gpu_memory_utilization,
|
||||
max_num_seqs=self.config.max_num_seqs,
|
||||
max_model_len=self.config.max_model_len,
|
||||
)
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
# vLLM currently requires the user to specify the tool parser manually. To choose a tool
|
||||
# parser, we need to determine what model architecture is being used. For now, we infer
|
||||
# that information from what config class the model uses.
|
||||
low_level_model_config = self.engine.engine.get_model_config()
|
||||
hf_config = low_level_model_config.hf_config
|
||||
hf_config_class_name = hf_config.__class__.__name__
|
||||
if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
|
||||
tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
|
||||
else:
|
||||
# No info -- choose a default so we can at least attempt tool
|
||||
# use.
|
||||
tool_parser = DEFAULT_TOOL_PARSER
|
||||
logger.debug(f"{hf_config_class_name=}")
|
||||
logger.debug(f"{tool_parser=}")
|
||||
|
||||
# Wrap the lower-level engine in an OpenAI-compatible chat API
|
||||
model_config = await self.engine.get_model_config()
|
||||
self.chat = OpenAIServingChat(
|
||||
engine_client=self.engine,
|
||||
model_config=model_config,
|
||||
models=OpenAIServingModels(
|
||||
engine_client=self.engine,
|
||||
model_config=model_config,
|
||||
base_model_paths=[
|
||||
# The layer below us will only see resolved model IDs
|
||||
BaseModelPath(model_id_for_vllm, model_id_for_vllm)
|
||||
],
|
||||
),
|
||||
response_role="assistant",
|
||||
request_logger=None, # Use default logging
|
||||
chat_template=None, # Use default template from model checkpoint
|
||||
enable_auto_tools=True,
|
||||
tool_parser=tool_parser,
|
||||
chat_template_content_format="auto",
|
||||
)
|
||||
self.resolved_model_id = model_id_for_vllm
|
||||
self.model_ids.add(model.model_id)
|
||||
|
||||
logger.info(f"Finished preloading model: {model_id_for_vllm}")
|
||||
|
||||
if configured_model.core_model_id != registered_model.core_model_id:
|
||||
raise ValueError(
|
||||
f"Requested model '{model.identifier}' is different from "
|
||||
f"model '{self.config.model}' that this provider "
|
||||
f"is configured to serve"
|
||||
)
|
||||
return model
|
||||
|
||||
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
|
||||
if sampling_params is None:
|
||||
return VLLMSamplingParams(max_tokens=self.config.max_tokens)
|
||||
|
||||
options = get_sampling_options(sampling_params)
|
||||
if "repeat_penalty" in options:
|
||||
options["repetition_penalty"] = options["repeat_penalty"]
|
||||
del options["repeat_penalty"]
|
||||
|
||||
return VLLMSamplingParams(**options)
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
"""
|
||||
Callback that is called when the server removes an inference endpoint from an inference
|
||||
provider.
|
||||
|
||||
:param model_id: The same external ID that the higher layers of the stack previously passed
|
||||
to :func:`register_model()`
|
||||
"""
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
|
||||
)
|
||||
self.model_ids.remove(model_id)
|
||||
|
||||
if len(self.model_ids) == 0:
|
||||
# Last model was just unregistered. Shut down the connection to vLLM and free up
|
||||
# resources.
|
||||
# Note that this operation may cause in-flight chat completion requests on the
|
||||
# now-unregistered model to return errors.
|
||||
self.resolved_model_id = None
|
||||
self.chat = None
|
||||
self.engine.shutdown_background_loop()
|
||||
self.engine = None
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM Inference INTERFACE
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> CompletionResponse | CompletionResponseStreamChunk:
|
||||
raise NotImplementedError("Completion not implemented for vLLM")
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
)
|
||||
if not isinstance(content, str):
|
||||
raise NotImplementedError("Multimodal input not currently supported")
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
assert self.engine is not None
|
||||
converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
logger.debug(f"{converted_sampling_params=}")
|
||||
|
||||
log.info("Sampling params: %s", sampling_params)
|
||||
request_id = _random_uuid()
|
||||
|
||||
prompt = await chat_completion_request_to_prompt(request, self.config.model)
|
||||
vllm_sampling_params = self._sampling_params(request.sampling_params)
|
||||
results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, results_generator)
|
||||
return self._streaming_completion(content, converted_sampling_params)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, results_generator)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
) -> ChatCompletionResponse:
|
||||
outputs = [o async for o in results_generator]
|
||||
final_output = outputs[-1]
|
||||
|
||||
assert final_output is not None
|
||||
outputs = final_output.outputs
|
||||
finish_reason = outputs[-1].stop_reason
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=finish_reason,
|
||||
text="".join([output.text for output in outputs]),
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
) -> AsyncGenerator:
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
cur = []
|
||||
async for chunk in results_generator:
|
||||
if not chunk.outputs:
|
||||
log.warning("Empty chunk received")
|
||||
continue
|
||||
|
||||
output = chunk.outputs[-1]
|
||||
|
||||
new_tokens = output.token_ids[len(cur) :]
|
||||
text = tokenizer.decode(new_tokens)
|
||||
cur.extend(new_tokens)
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=output.finish_reason,
|
||||
text=text,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
streaming_result = None
|
||||
# Consume the stream; each chunk carries the cumulative text, so keep only the last one.
async for streaming_result in self._streaming_completion(content, converted_sampling_params):
|
||||
pass
|
||||
return CompletionResponse(
|
||||
content=streaming_result.delta,
|
||||
stop_reason=streaming_result.stop_reason,
|
||||
logprobs=streaming_result.logprobs,
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
@ -242,3 +402,391 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message], # type: ignore
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None, # type: ignore
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
sampling_params = sampling_params or SamplingParams()
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
)
|
||||
|
||||
# Convert to Llama Stack internal format for consistency
|
||||
request = ChatCompletionRequest(
|
||||
model=self.resolved_model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
if self.is_meta_llama_model:
|
||||
# Bypass vLLM chat templating layer for Meta Llama models, because the
|
||||
# templating layer in Llama Stack currently produces better results.
|
||||
logger.debug(
|
||||
f"Routing {self.resolved_model_id} chat completion through "
|
||||
f"Llama Stack's templating layer instead of vLLM's."
|
||||
)
|
||||
return await self._chat_completion_for_meta_llama(request)
|
||||
|
||||
logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
|
||||
|
||||
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
|
||||
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
|
||||
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
|
||||
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
|
||||
|
||||
logger.debug(f"Converted request: {chat_completion_request}")
|
||||
|
||||
vllm_result = await self.chat.create_chat_completion(chat_completion_request)
|
||||
logger.debug(f"Result from vLLM: {vllm_result}")
|
||||
if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
|
||||
raise ValueError(f"Error from vLLM layer: {vllm_result}")
|
||||
|
||||
# Return type depends on "stream" argument
|
||||
if stream:
|
||||
if not isinstance(vllm_result, AsyncGenerator):
|
||||
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
|
||||
# vLLM client returns a stream of strings, which need to be parsed.
|
||||
# Stream comes in the form of an async generator.
|
||||
return self._convert_streaming_results(vllm_result)
|
||||
else:
|
||||
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
|
||||
raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
|
||||
return self._convert_non_streaming_results(vllm_result)
|
||||
|
||||
###########################################################################
|
||||
# INTERNAL METHODS
|
||||
|
||||
async def _streaming_completion(
|
||||
self, content: str, sampling_params: vllm.SamplingParams
|
||||
) -> AsyncIterator[CompletionResponseStreamChunk]:
|
||||
"""Internal implementation of :func:`completion()` API for the streaming case. Assumes
|
||||
that arguments have been validated upstream.
|
||||
|
||||
:param content: Must be a string
|
||||
:param sampling_params: Parameters from the public API's ``response_format``
|
||||
and ``sampling_params`` arguments, converted to VLLM format
|
||||
"""
|
||||
# We run against the vLLM generate() call directly instead of using the OpenAI-compatible
|
||||
# layer, because doing so simplifies the code here.
|
||||
|
||||
# The vLLM engine requires a unique identifier for each call to generate()
|
||||
request_id = _random_uuid_str()
|
||||
|
||||
# The vLLM generate() API is streaming-only and returns an async generator.
|
||||
# The generator returns objects of type vllm.RequestOutput.
|
||||
results_generator = self.engine.generate(content, sampling_params, request_id)
|
||||
|
||||
# Need to know the model's EOS token ID for the conversion code below.
|
||||
# AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
|
||||
# we drill down to the LLMEngine inside the AsyncLLMEngine.
|
||||
# Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
|
||||
# and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
|
||||
llm_engine = self.engine.engine
|
||||
tokenizer_group = llm_engine.tokenizer
|
||||
eos_token_id = tokenizer_group.tokenizer.eos_token_id
|
||||
|
||||
request_output: vllm.RequestOutput = None
|
||||
async for request_output in results_generator:
|
||||
# Check for weird inference failures
|
||||
if request_output.outputs is None or len(request_output.outputs) == 0:
|
||||
# This should never happen, but raise a clear error if it does
|
||||
raise ValueError("Inference produced empty result")
|
||||
|
||||
# If we get here, then request_output contains the final output of the generate() call.
|
||||
# The result may include multiple alternate outputs, but Llama Stack APIs only allow
|
||||
# us to return one.
|
||||
output: vllm.CompletionOutput = request_output.outputs[0]
|
||||
completion_string = output.text
|
||||
|
||||
# Convert logprobs from vLLM's format to Llama Stack's format
|
||||
logprobs = [
|
||||
TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
|
||||
for logprob_dict in output.logprobs
|
||||
]
|
||||
|
||||
# The final output chunk should be labeled with the reason that the overall generate()
|
||||
# call completed.
|
||||
logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
|
||||
if output.stop_reason is None:
|
||||
stop_reason = None # Still going
|
||||
elif output.stop_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif output.stop_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
elif isinstance(output.stop_reason, int):
|
||||
# If the model config specifies multiple end-of-sequence tokens, then vLLM
|
||||
# will return the token ID of the EOS token in the stop_reason field.
|
||||
stop_reason = StopReason.end_of_turn
|
||||
else:
|
||||
raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")
|
||||
|
||||
# vLLM's protocol outputs the stop token, then sets end of message on the next step for
|
||||
# some reason.
|
||||
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
|
||||
|
||||
# Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
|
||||
# provide one if it runs out of tokens.
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=completion_string,
|
||||
stop_reason=StopReason.out_of_tokens,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
def _convert_non_streaming_results(
|
||||
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
|
||||
equivalent Llama Stack object.
|
||||
|
||||
The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
|
||||
Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
|
||||
the fields that aren't currently present in the Llama Stack dataclass.
|
||||
"""
|
||||
|
||||
# There may be multiple responses, but we can only pass through the first one.
|
||||
if len(vllm_result.choices) == 0:
|
||||
raise ValueError("Don't know how to convert response object without any responses")
|
||||
vllm_message = vllm_result.choices[0].message
|
||||
vllm_finish_reason = vllm_result.choices[0].finish_reason
|
||||
|
||||
converted_message = CompletionMessage(
|
||||
role=vllm_message.role,
|
||||
# Llama Stack API won't accept None for content field.
|
||||
content=("" if vllm_message.content is None else vllm_message.content),
|
||||
stop_reason=get_stop_reason(vllm_finish_reason),
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id=t.id,
|
||||
tool_name=t.function.name,
|
||||
# vLLM function args come back as a string. Llama Stack expects JSON.
|
||||
arguments=json.loads(t.function.arguments),
|
||||
)
|
||||
for t in vllm_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Convert logprobs
|
||||
|
||||
logger.debug(f"Converted message: {converted_message}")
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=converted_message,
|
||||
)
|
||||
|
||||
async def _chat_completion_for_meta_llama(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
"""
|
||||
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
|
||||
chat template instead of using vLLM's version of that template. The Llama Stack version
|
||||
of the chat template currently produces more reliable outputs.
|
||||
|
||||
Once vLLM's support for Meta Llama models has matured more, we should consider routing
|
||||
Meta Llama requests through the vLLM chat completions API instead of using this method.
|
||||
"""
|
||||
formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
# Note that this function call modifies `request` in place.
|
||||
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
|
||||
|
||||
model_id = list(self.model_ids)[0] # Any model ID will do here
|
||||
completion_response_or_iterator = await self.completion(
|
||||
model_id=model_id,
|
||||
content=prompt,
|
||||
sampling_params=request.sampling_params,
|
||||
response_format=request.response_format,
|
||||
stream=request.stream,
|
||||
logprobs=request.logprobs,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
if not isinstance(completion_response_or_iterator, AsyncIterator):
|
||||
raise TypeError(
|
||||
f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
|
||||
)
|
||||
return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
|
||||
|
||||
# elsif not request.stream:
|
||||
if not isinstance(completion_response_or_iterator, CompletionResponse):
|
||||
raise TypeError(
|
||||
f"Received unexpected result type {type(completion_response_or_iterator)}for non-streaming request."
|
||||
)
|
||||
completion_response: CompletionResponse = completion_response_or_iterator
|
||||
raw_message = formatter.decode_assistant_message_from_content(
|
||||
completion_response.content, completion_response.stop_reason
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=completion_response.logprobs,
|
||||
)
|
||||
|
||||
async def _chat_completion_for_meta_llama_streaming(
|
||||
self, results_iterator: AsyncIterator, request: ChatCompletionRequest
|
||||
) -> AsyncIterator:
|
||||
"""
|
||||
Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
|
||||
method to keep asyncio happy.
|
||||
"""
|
||||
|
||||
# Convert to OpenAI format, then use shared code to convert to Llama Stack format.
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
chunk: CompletionResponseStreamChunk # Make Pylance happy
|
||||
last_text_len = 0
|
||||
async for chunk in results_iterator:
|
||||
if chunk.stop_reason == StopReason.end_of_turn:
|
||||
finish_reason = "stop"
|
||||
elif chunk.stop_reason == StopReason.end_of_message:
|
||||
finish_reason = "eos"
|
||||
elif chunk.stop_reason == StopReason.out_of_tokens:
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = None
|
||||
|
||||
# Convert delta back to an actual delta
|
||||
text_delta = chunk.delta[last_text_len:]
|
||||
last_text_len = len(chunk.delta)
|
||||
|
||||
logger.debug(f"{text_delta=}; {finish_reason=}")
|
||||
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
logger.debug(f"Returning chunk: {chunk}")
|
||||
yield chunk
|
||||
|
||||
async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
|
||||
"""
|
||||
Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
|
||||
API into a second async iterator that returns Llama Stack objects.
|
||||
|
||||
:param vllm_result: Stream of strings that need to be parsed
|
||||
"""
|
||||
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
|
||||
# those chunks and output them at the end.
|
||||
# This data structure holds the current set of partial tool calls.
|
||||
index_to_tool_call: Dict[int, Dict] = dict()
|
||||
|
||||
# The Llama Stack event stream must always start with a start event. Use an empty one to
|
||||
# simplify logic below
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=None,
|
||||
)
|
||||
)
|
||||
|
||||
converted_stop_reason = None
|
||||
async for chunk_str in vllm_result:
|
||||
# Due to OpenAI compatibility, each event in the stream will start with "data: " and
|
||||
# end with "\n\n".
|
||||
_prefix = "data: "
|
||||
_suffix = "\n\n"
|
||||
if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
|
||||
raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")
|
||||
|
||||
# In between the "data: " and newlines is an event record
|
||||
data_str = chunk_str[len(_prefix) : -len(_suffix)]
|
||||
|
||||
# The end of the stream is indicated with "[DONE]"
|
||||
if data_str == "[DONE]":
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Anything that is not "[DONE]" should be a JSON record
|
||||
parsed_chunk = json.loads(data_str)
|
||||
|
||||
logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
|
||||
|
||||
# The result may contain multiple completions, but Llama Stack APIs only support
|
||||
# returning one.
|
||||
first_choice = parsed_chunk["choices"][0]
|
||||
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
|
||||
delta_record = first_choice["delta"]
|
||||
|
||||
if "content" in delta_record:
|
||||
# Text delta
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=TextDelta(text=delta_record["content"]),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
elif "tool_calls" in delta_record:
|
||||
# Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
|
||||
# calls, so buffer until we get a "tool calls" stop reason
|
||||
for tc in delta_record["tool_calls"]:
|
||||
index = tc["index"]
|
||||
if index not in index_to_tool_call:
|
||||
# First time this tool call is showing up
|
||||
index_to_tool_call[index] = dict()
|
||||
tool_call = index_to_tool_call[index]
|
||||
if "id" in tc:
|
||||
tool_call["call_id"] = tc["id"]
|
||||
if "function" in tc:
|
||||
if "name" in tc["function"]:
|
||||
tool_call["tool_name"] = tc["function"]["name"]
|
||||
if "arguments" in tc["function"]:
|
||||
# Arguments comes in as pieces of a string
|
||||
if "arguments_str" not in tool_call:
|
||||
tool_call["arguments_str"] = ""
|
||||
tool_call["arguments_str"] += tc["function"]["arguments"]
|
||||
else:
|
||||
raise ValueError(f"Don't know how to parse event delta: {delta_record}")
|
||||
|
||||
if first_choice["finish_reason"] == "tool_calls":
|
||||
# Special OpenAI code for "tool calls complete".
|
||||
# Output the buffered tool calls. Llama Stack requires a separate event per tool
|
||||
# call.
|
||||
for tool_call_record in index_to_tool_call.values():
|
||||
# Arguments come in as a string. Parse the completed string.
|
||||
tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
|
||||
del tool_call_record["arguments_str"]
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# If we get here, we've lost the connection with the vLLM event stream before it ended
|
||||
# normally.
|
||||
raise ValueError("vLLM event stream ended without [DONE] message.")
|
||||
|
|
|
@ -4,9 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import TorchtunePostTrainingConfig
|
||||
|
||||
|
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: TorchtunePostTrainingConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .post_training import TorchtunePostTrainingImpl
|
||||
|
||||
|
|
|
@ -43,6 +43,9 @@ class TorchtunePostTrainingImpl:
|
|||
self.jobs = {}
|
||||
self.checkpoints_dict = {}
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps):
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
|
||||
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||
|
||||
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps):
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
|
||||
from .llama_guard import LlamaGuardSafetyImpl
|
||||
|
||||
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import PromptGuardConfig # noqa: F401
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps):
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
|
||||
from .prompt_guard import PromptGuardSafetyImpl
|
||||
|
||||
impl = PromptGuardSafetyImpl(config, deps)
|
||||
|
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: BasicScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .scoring import BasicScoringImpl
|
||||
|
||||
|
|
|
@ -23,10 +23,11 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
|
||||
|
||||
|
||||
class BasicScoringImpl(
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
|
||||
|
||||
|
||||
regex_parser_math_response = ScoringFn(
|
||||
identifier="basic::regex_parser_math_response",
|
||||
description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="regex-parser-math-response",
|
||||
params=RegexParserScoringFnParams(
|
||||
parsing_regexes=MATH_ANSWER_REGEXES,
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
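A quick check of the answer-extraction regex against a typical model completion (the text is illustrative):

```python
import re

text = r"Therefore, the final answer is: $\boxed{42}$"
match = re.match(MATH_ANSWER_REGEXES[0], text)
# match.group("X") == "42"
```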
|
|
@ -12,6 +12,7 @@ from llama_stack.apis.scoring_functions import (
|
|||
)
|
||||
|
||||
MULTILINGUAL_ANSWER_REGEXES = [
|
||||
r"The best answer is ",
|
||||
r"Answer\s*:",
|
||||
r"Answer\s*:", # Korean invisible character
|
||||
r"উত্তর\s*:",
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
|
||||
from .fn_defs.regex_parser_math_response import (
|
||||
regex_parser_math_response,
|
||||
)
|
||||
|
||||
|
||||
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for math benchmarks that parses the answer from the generated response using the configured regexes and checks whether it matches expected_answer.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
regex_parser_math_response.identifier: regex_parser_math_response,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
|
||||
f"RegexParserScoringFnParams not found for {fn_def}."
|
||||
)
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
parsing_regexes = fn_def.params.parsing_regexes
|
||||
assert len(parsing_regexes) == 1, (
|
||||
"Only one parsing regex is supported for regex_parser_math_response scoring function."
|
||||
)
|
||||
parsing_regexes = fn_def.params.parsing_regexes[0]
|
||||
|
||||
normalized_generated_answer = normalize_final_answer(
|
||||
first_answer(generated_answer),
|
||||
parsing_regexes,
|
||||
match_first=True,
|
||||
)
|
||||
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
|
||||
|
||||
normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
|
||||
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
|
||||
|
||||
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
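A hypothetical input row for the scorer above; the field names match those read in ``score_row()``. Both sides normalize to the same numeric string, so the row scores 1.0:

```python
# Illustrative input row.
row = {
    "generated_answer": r"So the final answer is: $\boxed{\frac{1}{2}}$",
    "expected_answer": "0.5",
}
# After extraction and normalization both answers reduce to the same value
# (e.g. "5.00e-01"), so score_row() returns {"score": 1.0}.
```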
|
330
llama_stack/providers/inline/scoring/basic/utils/math_utils.py
Normal file
|
@ -0,0 +1,330 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
from typing import Sequence
|
||||
|
||||
from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
|
||||
|
||||
# from minerva
|
||||
SUBSTITUTIONS = [
|
||||
("an ", ""),
|
||||
("a ", ""),
|
||||
(".$", "$"),
|
||||
("\\$", ""),
|
||||
(r"\ ", ""),
|
||||
(" ", ""),
|
||||
("mbox", "text"),
|
||||
(",\\text{and}", ","),
|
||||
("\\text{and}", ","),
|
||||
("\\text{m}", "\\text{}"),
|
||||
]
|
||||
|
||||
REMOVED_EXPRESSIONS = [
|
||||
"square",
|
||||
"ways",
|
||||
"integers",
|
||||
"dollars",
|
||||
"mph",
|
||||
"inches",
|
||||
"ft",
|
||||
"hours",
|
||||
"km",
|
||||
"units",
|
||||
"\\ldots",
|
||||
"sue",
|
||||
"points",
|
||||
"feet",
|
||||
"minutes",
|
||||
"digits",
|
||||
"cents",
|
||||
"degrees",
|
||||
"cm",
|
||||
"gm",
|
||||
"pounds",
|
||||
"meters",
|
||||
"meals",
|
||||
"edges",
|
||||
"students",
|
||||
"childrentickets",
|
||||
"multiples",
|
||||
"\\text{s}",
|
||||
"\\text{.}",
|
||||
"\\text{\ns}",
|
||||
"\\text{}^2",
|
||||
"\\text{}^3",
|
||||
"\\text{\n}",
|
||||
"\\text{}",
|
||||
r"\mathrm{th}",
|
||||
r"^\circ",
|
||||
r"^{\circ}",
|
||||
r"\;",
|
||||
r",\!",
|
||||
"{,}",
|
||||
'"',
|
||||
"\\dots",
|
||||
]
|
||||
|
||||
|
||||
def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
|
||||
if isinstance(expression, float):
|
||||
return expression
|
||||
new_expression = f"{expression}"
|
||||
regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}")
|
||||
for match in re.finditer(regex, expression):
|
||||
try:
|
||||
value = float(match.group(1)) / float(match.group(2))
|
||||
new_expression = new_expression.replace(
|
||||
match.group(),
|
||||
f"{{value:{fmt}}}".format(value=value),
|
||||
1,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
return new_expression
|
||||
|
||||
|
||||
def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
|
||||
try:
|
||||
with time_limit(seconds=5):
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
value = parse_latex(expression).evalf() # type: ignore
|
||||
return f"{{value:{fmt}}}".format(value=value)
|
||||
except Exception:
|
||||
return expression
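Worked examples for the two evaluators above. The LaTeX path depends on sympy's LaTeX parser; as the except clause shows, any failure returns the input unchanged:

```python
try_evaluate_frac(r"\frac{3}{4}")   # -> "7.50e-01"
try_evaluate_latex(r"\sqrt{2}")     # -> "1.41e+00" when parse_latex is available,
                                    #    otherwise the original string
```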
|
||||
|
||||
|
||||
def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
|
||||
for marker in markers:
|
||||
text = text.split(marker)[0]
|
||||
return text
|
||||
|
||||
|
||||
def extract_result_from_boxed(answer: str) -> str:
|
||||
box_start = "\\boxed"
|
||||
# format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
|
||||
start = answer.rfind(box_start)
|
||||
if start < 0:
|
||||
return ""
|
||||
answer = answer[start + len(box_start) :].strip()
|
||||
ends_with_curly = answer.startswith("{")
|
||||
i = 0
|
||||
open_braces = 0
|
||||
while i < len(answer):
|
||||
if answer[i] == "{":
|
||||
open_braces += 1
|
||||
elif answer[i] == "}":
|
||||
open_braces -= 1
|
||||
if open_braces == 0:
|
||||
if ends_with_curly:
|
||||
answer = answer[: i + 1].strip()
|
||||
break
|
||||
elif answer[i] == "$":
|
||||
answer = answer[:i].strip()
|
||||
break
|
||||
i += 1
|
||||
else:
|
||||
return ""
|
||||
# remove extra curly braces
|
||||
while True:
|
||||
if answer.startswith("{") and answer.endswith("}"):
|
||||
answer = answer[1:-1].strip()
|
||||
else:
|
||||
break
|
||||
return answer
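Behavior sketches for the two helpers above, using made-up answer strings:

```python
first_answer("The answer is 12.\nQ: And the next one?")            # -> "The answer is 12.\n"
extract_result_from_boxed(r"... so we get $\boxed{\frac{1}{2}}$")  # -> "\frac{1}{2}"
extract_result_from_boxed(r"the answer is \boxed 7$")              # -> "7"
extract_result_from_boxed("no boxed value here")                   # -> ""
```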
|
||||
|
||||
|
||||
# from minerva paper + _normalise_result from xavierm
def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str:
    """Extract and normalize a final answer to a quantitative reasoning question."""
    match = re.findall(regex_pattern, final_answer)
    extraction: str
    if len(match) > 0:
        if match_first:
            extraction = match[0]
        else:
            extraction = match[-1]
    else:
        extraction = extract_result_from_boxed(final_answer)

    if len(extraction) == 0:
        return final_answer
    else:
        final_answer = extraction
    final_answer = final_answer.split("=")[-1]
    for before, after in SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, "")
    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
    # Normalize shorthand TeX:
    # \fracab -> \frac{a}{b}
    # \frac{abc}{bef} -> \frac{abc}{bef}
    # \fracabc -> \frac{a}{b}c
    # \sqrta -> \sqrt{a}
    # \sqrtab -> sqrt{a}b
    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
    final_answer = final_answer.replace("$", "")
    # Normalize 100,000 -> 100000
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")
    # If the final answer is a single letter in parentheses, remove the parentheses
    # Example: (a) -> a (but not (ab) -> ab)
    if re.match(r"\([a-zA-Z]\)", final_answer):
        final_answer = final_answer[1]
    return _normalise_result(final_answer)

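# Illustrative only, not part of the original file: the exact output of
# normalize_final_answer depends on the SUBSTITUTIONS and REMOVED_EXPRESSIONS
# tables defined earlier in this file, so treat these as expected shapes
# rather than guarantees:
#   normalize_final_answer("The answer is $\\frac{1}{2}$.", r"answer is (.*)$")
#       # expected -> "\\frac{1}{2}"  (regex hit, then LaTeX cleanup)
#   normalize_final_answer("... so we get \\boxed{3/4}", r"final answer: (.*)$")
#       # expected -> "\\frac{3}{4}"  (regex misses, falls back to the boxed value)
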
def _normalise_result(string: str) -> str:
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("cfrac", "frac")
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\le", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace(r"\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    string = string.split("=")[-1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string


def _remove_right_units(string: str) -> str:
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    try:
        if "\\text{ " in string:
            splits = string.split("\\text{ ")
            assert len(splits) == 2
            return splits[0]
        else:
            return string
    except AssertionError:
        return string


def _fix_sqrt(string: str) -> str:
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if len(split) == 0:
            return string
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string


def _fix_fracs(string: str) -> str:
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if len(substr) == 0:
                return string
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except AssertionError:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string


def _fix_a_slash_b(string: str) -> str:
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        ia = int(a)
        ib = int(b)
        assert string == "{}/{}".format(ia, ib)
        new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
        return new_string
    except (ValueError, AssertionError):
        return string

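The private helpers above each canonicalize one LaTeX quirk. The sketch below is illustrative only, assuming the functions exactly as listed; it is not part of the diff.

def _demo_normalisation() -> None:
    # Quick sanity checks traced from the helpers defined above.
    assert _fix_sqrt("\\sqrt3") == "\\sqrt{3}"           # sqrt3 -> sqrt{3}
    assert _fix_fracs("\\frac12") == "\\frac{1}{2}"      # \frac12 -> \frac{1}{2}
    assert _fix_a_slash_b("3/4") == "\\frac{3}{4}"       # bare a/b -> \frac{a}{b}
    assert _remove_right_units("5 \\text{ cm}") == "5 "  # trailing units dropped
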
@@ -3,11 +3,11 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Dict
+from typing import Any, Dict
 
 from pydantic import BaseModel
 
-from llama_stack.distribution.datatypes import Api, ProviderSpec
+from llama_stack.distribution.datatypes import Api
 
 from .config import BraintrustScoringConfig
 
@@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
 
 async def get_provider_impl(
     config: BraintrustScoringConfig,
-    deps: Dict[Api, ProviderSpec],
+    deps: Dict[Api, Any],
 ):
     from .braintrust import BraintrustScoringImpl
 
@@ -3,16 +3,16 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Dict
+from typing import Any, Dict
 
-from llama_stack.distribution.datatypes import Api, ProviderSpec
+from llama_stack.distribution.datatypes import Api
 
 from .config import LlmAsJudgeScoringConfig
 
 
 async def get_provider_impl(
     config: LlmAsJudgeScoringConfig,
-    deps: Dict[Api, ProviderSpec],
+    deps: Dict[Api, Any],
 ):
     from .scoring import LlmAsJudgeScoringImpl
 
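In both __init__ hunks above, `deps` is loosened from `Dict[Api, ProviderSpec]` to `Dict[Api, Any]`, which matches a caller that passes already-resolved API implementations rather than provider specs. A minimal sketch of such a caller follows; the wiring function and its arguments are hypothetical and the `Api` enum members are assumed, so treat this as illustration rather than the actual resolver code.

from typing import Any, Dict

from llama_stack.distribution.datatypes import Api


async def wire_up_llm_as_judge(config, inference_impl: Any, datasetio_impl: Any, datasets_impl: Any):
    # Hypothetical wiring: the values are live API implementations, hence Dict[Api, Any].
    deps: Dict[Api, Any] = {
        Api.inference: inference_impl,
        Api.datasetio: datasetio_impl,
        Api.datasets: datasets_impl,
    }
    return await get_provider_impl(config, deps)
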
@@ -25,7 +25,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
 from .config import LlmAsJudgeScoringConfig
 from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
 
-LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
+LLM_JUDGE_FN = LlmAsJudgeScoringFn
 
 
 class LlmAsJudgeScoringImpl(
@@ -43,23 +43,17 @@ class LlmAsJudgeScoringImpl(
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
         self.inference_api = inference_api
-        self.scoring_fn_id_impls = {}
 
     async def initialize(self) -> None:
-        for fn in LLM_JUDGE_FNS:
-            impl = fn(inference_api=self.inference_api)
-            for fn_defs in impl.get_supported_scoring_fn_defs():
-                self.scoring_fn_id_impls[fn_defs.identifier] = impl
-            self.llm_as_judge_fn = impl
+        impl = LLM_JUDGE_FN(inference_api=self.inference_api)
+        self.llm_as_judge_fn = impl
 
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFn]:
-        scoring_fn_defs_list = [
-            fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
-        ]
+        scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
 
-        for f in scoring_fn_defs_list:
+        for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
             assert f.identifier.startswith("llm-as-judge"), (
                 "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
             )
@@ -67,7 +61,7 @@ class LlmAsJudgeScoringImpl(
         return scoring_fn_defs_list
 
     async def register_scoring_function(self, function_def: ScoringFn) -> None:
-        raise NotImplementedError("Register scoring function not implemented yet")
+        self.llm_as_judge_fn.register_scoring_fn_def(function_def)
 
     async def score_batch(
         self,
@@ -102,9 +96,7 @@ class LlmAsJudgeScoringImpl(
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions.keys():
-            if scoring_fn_id not in self.scoring_fn_id_impls:
-                raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
+            scoring_fn = self.llm_as_judge_fn
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
             score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
             agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
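The scoring hunks above replace the per-identifier registry (`scoring_fn_id_impls`) with a single `LlmAsJudgeScoringFn` instance: every requested `scoring_fn_id` is routed to that one implementation, and `register_scoring_function` now forwards to it instead of raising. Below is a condensed, hypothetical sketch of the resulting flow; the function name and return shape are assumptions, since the hunk cuts off before results are packaged.

from typing import Any, Dict, List, Optional


async def score_with_single_judge(
    judge_fn,                                    # an LLM_JUDGE_FN(...) instance
    input_rows: List[Dict[str, Any]],
    scoring_functions: Dict[str, Optional[Any]],
) -> Dict[str, Any]:
    # No registry lookup (and no "not supported" ValueError) anymore.
    res: Dict[str, Any] = {}
    for scoring_fn_id, params in scoring_functions.items():
        score_results = await judge_fn.score(input_rows, scoring_fn_id, params)
        res[scoring_fn_id] = await judge_fn.aggregate(score_results, scoring_fn_id, params)
    return res
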
@@ -6,7 +6,7 @@
 import re
 from typing import Any, Dict, Optional
 
-from llama_stack.apis.inference.inference import Inference
+from llama_stack.apis.inference.inference import Inference, UserMessage
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@@ -58,10 +58,9 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
         judge_response = await self.inference_api.chat_completion(
             model_id=fn_def.params.judge_model,
             messages=[
-                {
-                    "role": "user",
-                    "content": judge_input_msg,
-                }
+                UserMessage(
+                    content=judge_input_msg,
+                ),
             ],
         )
         content = judge_response.completion_message.content
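The final hunk swaps the raw `{"role": "user", ...}` dict for the typed `UserMessage` model when calling `chat_completion`. A minimal sketch of the call shape after the change; the wrapper function and its arguments are placeholders, not code from the diff.

from llama_stack.apis.inference.inference import Inference, UserMessage


async def ask_judge(inference_api: Inference, judge_model: str, prompt: str) -> str:
    # Typed UserMessage replaces the raw {"role": "user", ...} dict used before.
    judge_response = await inference_api.chat_completion(
        model_id=judge_model,
        messages=[UserMessage(content=prompt)],
    )
    return judge_response.completion_message.content
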
Some files were not shown because too many files have changed in this diff.