Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-20 11:47:00 +00:00)

Commit b1d941e1f0: Merge branch 'main' into nvidia-e2e-notebook

447 changed files with 6462 additions and 64778 deletions
26 .github/actions/setup-ollama/action.yml (vendored, new file)

@@ -0,0 +1,26 @@
+name: Setup Ollama
+description: Start Ollama and cache model
+inputs:
+  models:
+    description: Comma-separated list of models to pull
+    default: "llama3.2:3b-instruct-fp16,all-minilm:latest"
+runs:
+  using: "composite"
+  steps:
+    - name: Install and start Ollama
+      shell: bash
+      run: |
+        # the ollama installer also starts the ollama service
+        curl -fsSL https://ollama.com/install.sh | sh
+
+        # Do NOT cache models - pulling the cache is actually slower than just pulling the model.
+        # It takes ~45 seconds to pull the models from the cache and unpack it, but only 30 seconds to
+        # pull them directly.
+        # Maybe this is because the cache is being pulled at the same time by all the matrix jobs?
+    - name: Pull requested models
+      if: inputs.models != ''
+      shell: bash
+      run: |
+        for model in $(echo "${{ inputs.models }}" | tr ',' ' '); do
+          ollama pull "$model"
+        done
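For reference, a workflow job would consume this composite action roughly as follows. This is a minimal sketch, assuming a standard Ubuntu runner; the single-model `models` value is illustrative, not a project default:

```yaml
jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      # Invoke the composite action defined above; `models` overrides the default list.
      - name: Setup ollama
        uses: ./.github/actions/setup-ollama
        with:
          models: "llama3.2:3b-instruct-fp16"
```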
15 .github/workflows/integration-tests.yml (vendored)

@@ -38,19 +38,8 @@ jobs:
         python-version: "3.10"
         activate-environment: true

-    - name: Install and start Ollama
-      run: |
-        # the ollama installer also starts the ollama service
-        curl -fsSL https://ollama.com/install.sh | sh
-
-        # Do NOT cache models - pulling the cache is actually slower than just pulling the model.
-        # It takes ~45 seconds to pull the models from the cache and unpack it, but only 30 seconds to
-        # pull them directly.
-        # Maybe this is because the cache is being pulled at the same time by all the matrix jobs?
-    - name: Pull Ollama models (instruct and embed)
-      run: |
-        ollama pull llama3.2:3b-instruct-fp16
-        ollama pull all-minilm:latest
-
+    - name: Setup ollama
+      uses: ./.github/actions/setup-ollama

     - name: Set Up Environment and Install Dependencies
       run: |
2 .github/workflows/pre-commit.yml (vendored)

@@ -27,6 +27,8 @@ jobs:
           .pre-commit-config.yaml

     - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
+      env:
+        SKIP: no-commit-to-branch

     - name: Verify if there are any diff files after pre-commit
       run: |
2 .github/workflows/providers-build.yml (vendored)

@@ -153,7 +153,7 @@ jobs:
       uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

     - name: Set up Python
-      uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
+      uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
       with:
         python-version: '3.10'
.pre-commit-config.yaml

@@ -15,6 +15,18 @@ repos:
       args: ['--maxkb=1000']
     - id: end-of-file-fixer
       exclude: '^(.*\.svg)$'
+    - id: no-commit-to-branch
+    - id: check-yaml
+      args: ["--unsafe"]
+    - id: detect-private-key
+    - id: requirements-txt-fixer
+    - id: mixed-line-ending
+      args: [--fix=lf] # Forces to replace line ending by LF (line feed)
+    - id: check-executables-have-shebangs
+    - id: check-json
+    - id: check-shebang-scripts-are-executable
+    - id: check-symlinks
+    - id: check-toml

 - repo: https://github.com/Lucas-C/pre-commit-hooks
   rev: v1.5.4
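To reproduce the CI check locally, the standard pre-commit CLI runs these same hooks; `SKIP` here mirrors the environment variable the workflow above sets. A sketch of typical usage, not a documented project script:

```bash
pip install pre-commit
pre-commit install                                   # register the git hook once per clone
SKIP=no-commit-to-branch pre-commit run --all-files  # same hook skip as the CI job
```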
594 CHANGELOG.md

@@ -3,28 +3,28 @@

# v0.2.3
Published on: 2025-04-25T22:46:21Z

## Highlights

* OpenAI compatible inference endpoints and client-SDK support. `client.chat.completions.create()` now works.
* significant improvements and functionality added to the nVIDIA distribution
* many improvements to the test verification suite.
* new inference providers: Ramalama, IBM WatsonX
* many improvements to the Playground UI

---

# v0.2.2
Published on: 2025-04-13T01:19:49Z

## Main changes

- Bring Your Own Provider (@leseb) - use out-of-tree provider code to execute the distribution server
- OpenAI compatible inference API in progress (@bbrowning)
- Provider verifications (@ehhuang)
- Many updates and fixes to playground
- Several llama4 related fixes

---
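The OpenAI-compatibility highlight in v0.2.3 above implies client code along these lines. A minimal sketch, assuming a local server on the default port 8321 and an illustrative model id; the response shape follows the OpenAI convention:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# OpenAI-compatible chat completion, per the v0.2.3 highlight above.
response = client.chat.completions.create(
    model="llama3.2:3b-instruct-fp16",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```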
@@ -38,10 +38,10 @@ Published on: 2025-04-05T23:13:00Z

# v0.2.0
Published on: 2025-04-05T19:04:29Z

## Llama 4 Support

Checkout more at https://www.llama.com

---
@@ -49,58 +49,58 @@ Checkout more at https://www.llama.com

# v0.1.9
Published on: 2025-03-29T00:52:23Z

### Build and Test Agents
* Agents: Entire document context with attachments
* RAG: Documentation with sqlite-vec faiss comparison
* Getting started: Fixes to getting started notebook.

### Agent Evals and Model Customization
* (**New**) Post-training: Add nemo customizer

### Better Engineering
* Moved sqlite-vec to non-blocking calls
* Don't return a payload on file delete

---

# v0.1.8
Published on: 2025-03-24T01:28:50Z

# v0.1.8 Release Notes

### Build and Test Agents
* Safety: Integrated NVIDIA as a safety provider.
* VectorDB: Added Qdrant as an inline provider.
* Agents: Added support for multiple tool groups in agents.
* Agents: Simplified imports for Agents in client package

### Agent Evals and Model Customization
* Introduced DocVQA and IfEval benchmarks.

### Deploying and Monitoring Agents
* Introduced a Containerfile and image workflow for the Playground.
* Implemented support for Bearer (API Key) authentication.
* Added attribute-based access control for resources.
* Fixes on docker deployments: use --pull always and standardized the default port to 8321
* Deprecated: /v1/inspect/providers use /v1/providers/ instead

### Better Engineering
* Consolidated scripts under the ./scripts directory.
* Addressed mypy violations in various modules.
* Added Dependabot scans for Python dependencies.
* Implemented a scheduled workflow to update the changelog automatically.
* Enforced concurrency to reduce CI loads.

### New Contributors
* @cmodi-meta made their first contribution in https://github.com/meta-llama/llama-stack/pull/1650
* @jeffmaury made their first contribution in https://github.com/meta-llama/llama-stack/pull/1671
* @derekhiggins made their first contribution in https://github.com/meta-llama/llama-stack/pull/1698
* @Bobbins228 made their first contribution in https://github.com/meta-llama/llama-stack/pull/1745

**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.1.7...v0.1.8

---
@@ -108,73 +108,73 @@ Published on: 2025-03-24T01:28:50Z

# v0.1.7
Published on: 2025-03-14T22:30:51Z

## 0.1.7 Release Notes

### Build and Test Agents
* Inference: ImageType is now refactored to LlamaStackImageType
* Inference: Added tests to measure TTFT
* Inference: Bring back usage metrics
* Agents: Added endpoint for get agent, list agents and list sessions
* Agents: Automated conversion of type hints in client tool for lite llm format
* Agents: Deprecated ToolResponseMessage in agent.resume API
* Added Provider API for listing and inspecting provider info

### Agent Evals and Model Customization
* Eval: Added new eval benchmarks Math 500 and BFCL v3
* Deploy and Monitoring of Agents
* Telemetry: Fix tracing to work across coroutines

### Better Engineering
* Display code coverage for unit tests
* Updated call sites (inference, tool calls, agents) to move to async non blocking calls
* Unit tests also run on Python 3.11, 3.12, and 3.13
* Added ollama inference to Integration tests CI
* Improved documentation across examples, testing, CLI, updated providers table

---

# v0.1.6
Published on: 2025-03-08T04:35:08Z

## 0.1.6 Release Notes

### Build and Test Agents
* Inference: Fixed support for inline vllm provider
* (**New**) Agent: Build & Monitor Agent Workflows with Llama Stack + Anthropic's Best Practice [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb)
* (**New**) Agent: Revamped agent [documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html) with more details and examples
* Agent: Unify tools and Python SDK Agents API
* Agent: AsyncAgent Python SDK wrapper supporting async client tool calls
* Agent: Support python functions without @client_tool decorator as client tools
* Agent: deprecation for allow_resume_turn flag, and remove need to specify tool_prompt_format
* VectorIO: MilvusDB support added

### Agent Evals and Model Customization
* (**New**) Agent: Llama Stack RAG Lifecycle [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb)
* Eval: Documentation for eval, scoring, adding new benchmarks
* Eval: Distribution template to run benchmarks on llama & non-llama models
* Eval: Ability to register new custom LLM-as-judge scoring functions
* (**New**) Looking for contributors for open benchmarks. See [documentation](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) for details.

### Deploy and Monitoring of Agents
* Better support for different log levels across all components for better monitoring

### Better Engineering
* Enhance OpenAPI spec to include Error types across all APIs
* Moved all tests to /tests and created unit tests to run on each PR
* Removed all dependencies on llama-models repo

---

# v0.1.5.1
Published on: 2025-02-28T22:37:44Z

## 0.1.5.1 Release Notes
* Fixes for security risk in https://github.com/meta-llama/llama-stack/pull/1327 and https://github.com/meta-llama/llama-stack/pull/1328

**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.1.5...v0.1.5.1

---
@@ -182,176 +182,176 @@ Published on: 2025-02-28T22:37:44Z

# v0.1.5
Published on: 2025-02-28T18:14:01Z

## 0.1.5 Release Notes
### Build Agents
* Inference: Support more non-llama models (openai, anthropic, gemini)
* Inference: Can use the provider's model name in addition to the HF alias
* Inference: Fixed issues with calling tools that weren't specified in the prompt
* RAG: Improved system prompt for RAG and no more need for hard-coded rag-tool calling
* Embeddings: Added support for Nemo retriever embedding models
* Tools: Added support for MCP tools in Ollama Distribution
* Distributions: Added new Groq distribution

### Customize Models
* Save post-trained checkpoint in SafeTensor format to allow Ollama inference provider to use the post-trained model

### Monitor agents
* More comprehensive logging of agent steps including client tools
* Telemetry inputs/outputs are now structured and queryable
* Ability to retrieve agents session, turn, step by ids

### Better Engineering
* Moved executorch Swift code out of this repo into the llama-stack-client-swift repo, similar to kotlin
* Move most logging to use logger instead of prints
* Completed text /chat-completion and /completion tests

---

# v0.1.4
Published on: 2025-02-25T00:02:43Z

## v0.1.4 Release Notes
Here are the key changes coming as part of this release:

### Build and Test Agents
* Inference: Added support for non-llama models
* Inference: Added option to list all downloaded models and remove models
* Agent: Introduce new api agents.resume_turn to include client side tool execution in the same turn
* Agent: AgentConfig introduces new variable "tool_config" that allows for better tool configuration and system prompt overrides
* Agent: Added logging for agent step start and completion times
* Agent: Added support for logging for tool execution metadata
* Embedding: Updated /inference/embeddings to support asymmetric models, truncation and variable sized outputs
* Embedding: Updated embedding models for Ollama, Together, and Fireworks with available defaults
* VectorIO: Improved performance of sqlite-vec using chunked writes
### Agent Evals and Model Customization
* Deprecated api /eval-tasks. Use /eval/benchmark instead
* Added CPU training support for TorchTune
### Deploy and Monitoring of Agents
* Consistent view of client and server tool calls in telemetry
### Better Engineering
* Made tests more data-driven for consistent evaluation
* Fixed documentation links and improved API reference generation
* Various small fixes for build scripts and system reliability

---

# v0.1.3
Published on: 2025-02-14T20:24:32Z

## v0.1.3 Release

Here are some key changes that are coming as part of this release.

### Build and Test Agents
Streamlined the initial development experience
- Added support for llama stack run --image-type venv
- Enhanced vector store options with new sqlite-vec provider and improved Qdrant integration
- vLLM improvements for tool calling and logprobs
- Better handling of sporadic code_interpreter tool calls

### Agent Evals
Better benchmarking and Agent performance assessment
- Renamed eval API /eval-task to /benchmarks
- Improved documentation and notebooks for RAG and evals

### Deploy and Monitoring of Agents
Improved production readiness
- Added usage metrics collection for chat completions
- CLI improvements for provider information
- Improved error handling and system reliability
- Better model endpoint handling and accessibility
- Improved signal handling on distro server

### Better Engineering
Infrastructure and code quality improvements
- Faster text-based chat completion tests
- Improved testing for non-streaming agent apis
- Standardized import formatting with ruff linter
- Added conventional commits standard
- Fixed documentation parsing issues

---

# v0.1.2
Published on: 2025-02-07T22:06:49Z

# TL;DR
- Several stabilizations to development flows after the switch to `uv`
- Migrated CI workflows to new OSS repo - [llama-stack-ops](https://github.com/meta-llama/llama-stack-ops)
- Added automated rebuilds for ReadTheDocs
- Llama Stack server supports HTTPS
- Added system prompt overrides support
- Several bug fixes and improvements to documentation (check out Kubernetes deployment guide by @terrytangyuan )

---

# v0.1.1
Published on: 2025-02-02T02:29:24Z

A bunch of small / big improvements everywhere including support for Windows, switching to `uv` and many provider improvements.

---

# v0.1.0
Published on: 2025-01-24T17:47:47Z

We are excited to announce a stable API release of Llama Stack, which enables developers to build RAG applications and Agents using tools and safety shields, monitor those agents with telemetry, and evaluate the agent with scoring functions.

## Context
GenAI application developers need more than just an LLM - they need to integrate tools, connect with their data sources, establish guardrails, and ground the LLM responses effectively. Currently, developers must piece together various tools and APIs, complicating the development lifecycle and increasing costs. The result is that developers are spending more time on these integrations rather than focusing on the application logic itself. The bespoke coupling of components also makes it challenging to adopt state-of-the-art solutions in the rapidly evolving GenAI space. This is particularly difficult for open models like Llama, as best practices are not widely established in the open.

Llama Stack was created to provide developers with a comprehensive and coherent interface that simplifies AI application development and codifies best practices across the Llama ecosystem. Since our launch in September 2024, we have seen a huge uptick in interest in Llama Stack APIs by both AI developers and from partners building AI services with Llama models. Partners like Nvidia, Fireworks, and Ollama have collaborated with us to develop implementations across various APIs, including inference, memory, and safety.

With Llama Stack, you can easily build a RAG agent which can also search the web, do complex math, and custom tool calling. You can use telemetry to inspect those traces, and convert telemetry into evals datasets. And with Llama Stack's plugin architecture and prepackaged distributions, you choose to run your agent anywhere - in the cloud with our partners, deploy your own environment using virtualenv, conda, or Docker, operate locally with Ollama, or even run on mobile devices with our SDKs. Llama Stack offers unprecedented flexibility while also simplifying the developer experience.

## Release
After iterating on the APIs for the last 3 months, today we're launching a stable release (V1) of the Llama Stack APIs and the corresponding llama-stack server and client packages (v0.1.0). We now have automated tests for providers. These tests make sure that all provider implementations are verified. Developers can now easily and reliably select distributions or providers based on their specific requirements.

There are example standalone apps in llama-stack-apps.

## Key Features of this release

- **Unified API Layer**
  - Inference: Run LLM models
  - RAG: Store and retrieve knowledge for RAG
  - Agents: Build multi-step agentic workflows
  - Tools: Register tools that can be called by the agent
  - Safety: Apply content filtering and safety policies
  - Evaluation: Test model and agent quality
  - Telemetry: Collect and analyze usage data and complex agentic traces
  - Post Training ( Coming Soon ): Fine tune models for specific use cases

- **Rich Provider Ecosystem**
  - Local Development: Meta's Reference, Ollama
  - Cloud: Fireworks, Together, Nvidia, AWS Bedrock, Groq, Cerebras
  - On-premises: Nvidia NIM, vLLM, TGI, Dell-TGI
  - On-device: iOS and Android support

- **Built for Production**
  - Pre-packaged distributions for common deployment scenarios
  - Backwards compatibility across model versions
  - Comprehensive evaluation capabilities
  - Full observability and monitoring

- **Multiple developer interfaces**
  - CLI: Command line interface
  - Python SDK
  - Swift iOS SDK
  - Kotlin Android SDK

- **Sample llama stack applications**
  - Python
  - iOS
  - Android

---
@@ -365,8 +365,8 @@ Published on: 2025-01-22T22:24:01Z

# v0.0.63
Published on: 2024-12-18T07:17:43Z

A small but important bug-fix release to update the URL datatype for the client-SDKs. The issue affected multimodal agentic turns especially.

**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.0.62...v0.0.63

---
@@ -402,39 +402,39 @@ Published on: 2024-11-22T00:36:09Z

# v0.0.53
Published on: 2024-11-20T22:18:00Z

🚀 Initial Release Notes for Llama Stack!

### Added
- Resource-oriented design for models, shields, memory banks, datasets and eval tasks
- Persistence for registered objects with distribution
- Ability to persist memory banks created for FAISS
- PostgreSQL KVStore implementation
- Environment variable placeholder support in run.yaml files
- Comprehensive Zero-to-Hero notebooks and quickstart guides
- Support for quantized models in Ollama
- Vision models support for Together, Fireworks, Meta-Reference, and Ollama, and vLLM
- Bedrock distribution with safety shields support
- Evals API with task registration and scoring functions
- MMLU and SimpleQA benchmark scoring functions
- Huggingface dataset provider integration for benchmarks
- Support for custom dataset registration from local paths
- Benchmark evaluation CLI tools with visualization tables
- RAG evaluation scoring functions and metrics
- Local persistence for datasets and eval tasks

### Changed
- Split safety into distinct providers (llama-guard, prompt-guard, code-scanner)
- Changed provider naming convention (`impls` → `inline`, `adapters` → `remote`)
- Updated API signatures for dataset and eval task registration
- Restructured folder organization for providers
- Enhanced Docker build configuration
- Added version prefixing for REST API routes
- Enhanced evaluation task registration workflow
- Improved benchmark evaluation output formatting
- Restructured evals folder organization for better modularity

### Removed
- `llama stack configure` command

---
CONTRIBUTING.md

@@ -141,11 +141,18 @@ uv sync

 ## Coding Style

-* Comments should provide meaningful insights into the code. Avoid filler comments that simply describe the next step, as they create unnecessary clutter, same goes for docstrings.
-* Prefer comments to clarify surprising behavior and/or relationships between parts of the code rather than explain what the next line of code does.
-* Catching exceptions, prefer using a specific exception type rather than a broad catch-all like `Exception`.
+* Comments should provide meaningful insights into the code. Avoid filler comments that simply
+  describe the next step, as they create unnecessary clutter, same goes for docstrings.
+* Prefer comments to clarify surprising behavior and/or relationships between parts of the code
+  rather than explain what the next line of code does.
+* Catching exceptions, prefer using a specific exception type rather than a broad catch-all like
+  `Exception`.
 * Error messages should be prefixed with "Failed to ..."
-* 4 spaces for indentation rather than tabs
+* 4 spaces for indentation rather than tab
+* When using `# noqa` to suppress a style or linter warning, include a comment explaining the
+  justification for bypassing the check.
+* When using `# type: ignore` to suppress a mypy warning, include a comment explaining the
+  justification for bypassing the check.

 ## Common Tasks
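A short illustration of the exception-handling and error-message guidelines above; the function and config shape are hypothetical:

```python
def load_port(config: dict) -> int:
    try:
        return int(config["port"])
    except KeyError as exc:  # catch a specific exception type, not a broad `Exception`
        # Error message prefixed with "Failed to ..." per the guideline above.
        raise ValueError("Failed to load config: missing 'port' key") from exc
```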
6 docs/_static/css/my_theme.css (vendored)

@@ -27,3 +27,9 @@ pre {
     white-space: pre-wrap !important;
     word-break: break-all;
 }
+
+[data-theme="dark"] .mermaid {
+    background-color: #f4f4f6 !important;
+    border-radius: 6px;
+    padding: 0.5em;
+}
3 docs/_static/llama-stack-spec.html (vendored)

@@ -6462,6 +6462,9 @@
             "stream": {
               "type": "boolean"
             },
+            "temperature": {
+              "type": "number"
+            },
             "tools": {
               "type": "array",
               "items": {
2 docs/_static/llama-stack-spec.yaml (vendored)

@@ -4506,6 +4506,8 @@ components:
         type: boolean
       stream:
         type: boolean
+      temperature:
+        type: number
       tools:
         type: array
         items:
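Assuming the schema excerpt above belongs to the OpenAI-compatible chat-completion request (its `stream` and `tools` neighbors suggest so), the new `temperature` field could be exercised as follows; a sketch reusing the hypothetical client from the changelog example earlier:

```python
response = client.chat.completions.create(
    model="llama3.2:3b-instruct-fp16",
    messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
    temperature=0.2,  # the newly added field; lower values sample more deterministically
    stream=False,
)
```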
docs/make.bat

@@ -1,35 +1,35 @@

@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
@@ -6,6 +6,7 @@

 import hashlib
 import ipaddress
+import types
 import typing
 from dataclasses import make_dataclass
 from typing import Any, Dict, Set, Union

@@ -189,7 +190,7 @@ class ContentBuilder:
         else:
             return "application/json"

-        if typing.get_origin(payload_type) is typing.Union:
+        if typing.get_origin(payload_type) in (typing.Union, types.UnionType):
             media_types = []
             item_types = []
             for x in typing.get_args(payload_type):
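The reason for the widened check: on Python 3.10+, PEP 604 unions written with `|` report a different origin than `typing.Union`, so both must be accepted. A quick demonstration of the standard-library behavior (not project code):

```python
import types
import typing

assert typing.get_origin(typing.Union[int, str]) is typing.Union
assert typing.get_origin(int | str) is types.UnionType  # PEP 604 union
```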
docs/requirements.txt

@@ -1,16 +1,16 @@
-sphinx==8.1.3
-myst-parser
 linkify
+myst-parser
 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
-sphinx-rtd-theme>=1.0.0
-sphinx_autobuild
+sphinx==8.1.3
 sphinx-copybutton
 sphinx-design
 sphinx-pdj-theme
-sphinx_rtd_dark_mode
+sphinx-rtd-theme>=1.0.0
 sphinx-tabs
+sphinx_autobuild
+sphinx_rtd_dark_mode
+sphinxcontrib-mermaid
 sphinxcontrib-openapi
 sphinxcontrib-redoc
-sphinxcontrib-mermaid
 sphinxcontrib-video
 tomli
@@ -43,27 +43,6 @@ The tool requires an API key which can be provided either in the configuration o

 > **NOTE:** When using Tavily Search and Bing Search, the inference output will still display "Brave Search." This is because Llama models have been trained with Brave Search as a built-in tool. Tavily and bing is just being used in lieu of Brave search.

-#### Code Interpreter
-
-The Code Interpreter allows execution of Python code within a controlled environment.
-
-```python
-# Register Code Interpreter tool group
-client.toolgroups.register(
-    toolgroup_id="builtin::code_interpreter", provider_id="code_interpreter"
-)
-```
-
-Features:
-- Secure execution environment using `bwrap` sandboxing
-- Matplotlib support for generating plots
-- Disabled dangerous system operations
-- Configurable execution timeouts
-
-> ⚠️ Important: The code interpreter tool can operate in a controlled environment locally or on Podman containers. To ensure proper functionality in containerized environments:
-> - The container requires privileged access (e.g., --privileged).
-> - Users without sufficient permissions may encounter permission errors. (`bwrap: Can't mount devpts on /newroot/dev/pts: Permission denied`)
-> - 🔒 Security Warning: Privileged mode grants elevated access and bypasses security restrictions. Use only in local, isolated, or controlled environments.
-
 #### WolframAlpha

@@ -102,7 +81,7 @@ Features:
 - Context retrieval with token limits

-> **Note:** By default, llama stack run.yaml defines toolgroups for web search, code interpreter and rag, that are provided by tavily-search, code-interpreter and rag providers.
+> **Note:** By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers.

 ## Model Context Protocol (MCP) Tools
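With the code-interpreter tool group removed, the remaining built-ins follow the same registration pattern the deleted snippet used. A hypothetical sketch for WolframAlpha, with both ids assumed rather than documented here:

```python
# Modeled on the removed code-interpreter registration above; ids are assumptions.
client.toolgroups.register(
    toolgroup_id="builtin::wolfram_alpha", provider_id="wolfram-alpha"
)
```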
@ -22,7 +22,7 @@ The `llamastack/distribution-watsonx` distribution consists of the following pro
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||||
| vector_io | `inline::faiss` |
|
| vector_io | `inline::faiss` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
|
||||||
| safety | `remote::bedrock` |
|
| safety | `remote::bedrock` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ The `llamastack/distribution-groq` distribution consists of the following provid
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
|
||||||
| vector_io | `inline::faiss` |
|
| vector_io | `inline::faiss` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@@ -22,7 +22,7 @@ The `llamastack/distribution-passthrough` distribution consists of the following providers:
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::rag-runtime`, `remote::model-context-protocol` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -21,7 +21,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following providers:
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -19,7 +19,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following providers:
 | inference | `remote::sambanova` |
 | safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -23,7 +23,7 @@ The `llamastack/distribution-tgi` distribution consists of the following providers:
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -22,7 +22,7 @@ The `llamastack/distribution-together` distribution consists of the following providers:
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -53,7 +53,9 @@ Here's a list of known external providers that you can use with Llama Stack:
 | Name | Description | API | Type | Repository |
 |------|-------------|-----|------|------------|
 | KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
+| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
 | RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |
+| TrustyAI LM-Eval | Evaluate models with TrustyAI LM-Eval | Eval | Remote | [llama-stack-provider-lmeval](https://github.com/trustyai-explainability/llama-stack-provider-lmeval) |

 ### Remote Provider Specification
107 docs/source/providers/vector_io/milvus.md Normal file
@@ -0,0 +1,107 @@
---
orphan: true
---
# Milvus

[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database.
That means you're not limited to storing vectors in memory or in a separate service.

## Features

- Easy to use
- Fully integrated with Llama Stack

## Usage

To use Milvus in your Llama Stack project, follow these steps:

1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Milvus.
3. Start storing and querying vectors.

## Installation

You can install Milvus using pymilvus:

```bash
pip install pymilvus
```

## Configuration

In Llama Stack, Milvus can be configured in two ways:
- **Inline (Local) Configuration** - Uses Milvus-Lite for local storage
- **Remote Configuration** - Connects to a remote Milvus server

### Inline (Local) Configuration

The simplest method is local configuration, which requires setting `db_path`, a path for locally storing Milvus-Lite files:

```yaml
vector_io:
  - provider_id: milvus
    provider_type: inline::milvus
    config:
      db_path: ~/.llama/distributions/together/milvus_store.db
```

### Remote Configuration

Remote configuration is suitable for larger data storage requirements:

#### Standard Remote Connection

```yaml
vector_io:
  - provider_id: milvus
    provider_type: remote::milvus
    config:
      uri: "http://<host>:<port>"
      token: "<user>:<password>"
```

#### TLS-Enabled Remote Connection (One-way TLS)

For connections to Milvus instances with one-way TLS enabled:

```yaml
vector_io:
  - provider_id: milvus
    provider_type: remote::milvus
    config:
      uri: "https://<host>:<port>"
      token: "<user>:<password>"
      secure: True
      server_pem_path: "/path/to/server.pem"
```

#### Mutual TLS (mTLS) Remote Connection

For connections to Milvus instances with mutual TLS (mTLS) enabled:

```yaml
vector_io:
  - provider_id: milvus
    provider_type: remote::milvus
    config:
      uri: "https://<host>:<port>"
      token: "<user>:<password>"
      secure: True
      ca_pem_path: "/path/to/ca.pem"
      client_pem_path: "/path/to/client.pem"
      client_key_path: "/path/to/client.key"
```

#### Key Parameters for TLS Configuration

- **`secure`**: Enables TLS encryption when set to `true`. Defaults to `false`.
- **`server_pem_path`**: Path to the **server certificate** for verifying the server’s identity (used in one-way TLS).
- **`ca_pem_path`**: Path to the **Certificate Authority (CA) certificate** for validating the server certificate (required in mTLS).
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
- **`client_key_path`**: Path to the **client private key** file (required for mTLS).

## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.

For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
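For a quick smoke test of the inline (Milvus-Lite) path, a minimal sketch using `pymilvus` directly; the collection name, dimension, and toy vectors below are illustrative only and are not part of the Llama Stack configuration above:

```python
from pymilvus import MilvusClient

# Milvus-Lite: a plain file path yields a local, serverless instance.
client = MilvusClient("./milvus_demo.db")

client.create_collection(collection_name="demo", dimension=4)

# Insert a couple of toy vectors with integer ids.
client.insert(
    collection_name="demo",
    data=[
        {"id": 0, "vector": [0.1, 0.2, 0.3, 0.4]},
        {"id": 1, "vector": [0.4, 0.3, 0.2, 0.1]},
    ],
)

# Nearest-neighbor search for a single query vector.
hits = client.search(collection_name="demo", data=[[0.1, 0.2, 0.3, 0.4]], limit=1)
print(hits)
```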
@@ -1,31 +0,0 @@
----
-orphan: true
----
-# Milvus
-
-[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
-allows you to store and query vectors directly within a Milvus database.
-That means you're not limited to storing vectors in memory or in a separate service.
-
-## Features
-
-- Easy to use
-- Fully integrated with Llama Stack
-
-## Usage
-
-To use Milvus in your Llama Stack project, follow these steps:
-
-1. Install the necessary dependencies.
-2. Configure your Llama Stack project to use Milvus.
-3. Start storing and querying vectors.
-
-## Installation
-
-You can install Milvus using pymilvus:
-
-```bash
-pip install pymilvus
-```
-
-## Documentation
-See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
@@ -86,11 +86,11 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
 llama stack build --template ollama --image-type conda
 ```
 **Expected Output:**
-```
+```bash
 ...
-Build Successful! Next steps:
-1. Set the environment variables: LLAMA_STACK_PORT, OLLAMA_URL, INFERENCE_MODEL, SAFETY_MODEL
-2. `llama stack run /Users/<username>/.llama/distributions/llamastack-ollama/ollama-run.yaml
+Build Successful!
+You can find the newly-built template here: ~/.llama/distributions/ollama/ollama-run.yaml
+You can run the new Llama Stack Distro via: llama stack run ~/.llama/distributions/ollama/ollama-run.yaml --image-type conda
 ```

 3. **Set the ENV variables by exporting them to the terminal**:
107 install.sh
@@ -16,61 +16,120 @@ WAIT_TIMEOUT=300

 log(){ printf "\e[1;32m%s\e[0m\n" "$*"; }
 die(){ printf "\e[1;31m❌ %s\e[0m\n" "$*" >&2; exit 1; }

+wait_for_service() {
+  local url="$1"
+  local pattern="$2"
+  local timeout="$3"
+  local name="$4"
+  local start ts
+  log "⏳ Waiting for ${name}…"
+  start=$(date +%s)
+  while true; do
+    if curl --retry 5 --retry-delay 1 --retry-max-time "$timeout" --retry-all-errors --silent --fail "$url" 2>/dev/null | grep -q "$pattern"; then
+      break
+    fi
+    ts=$(date +%s)
+    if (( ts - start >= timeout )); then
+      return 1
+    fi
+    printf '.'
+    sleep 1
+  done
+  return 0
+}
+
 if command -v docker &> /dev/null; then
   ENGINE="docker"
-  HOST_DNS="host.docker.internal"
 elif command -v podman &> /dev/null; then
   ENGINE="podman"
-  HOST_DNS="host.containers.internal"
 else
   die "Docker or Podman is required. Install Docker: https://docs.docker.com/get-docker/ or Podman: https://podman.io/getting-started/installation"
 fi

+# Explicitly set the platform for the host architecture
+HOST_ARCH="$(uname -m)"
+if [ "$HOST_ARCH" = "arm64" ]; then
+  if [ "$ENGINE" = "docker" ]; then
+    PLATFORM_OPTS=( --platform linux/amd64 )
+  else
+    PLATFORM_OPTS=( --os linux --arch amd64 )
+  fi
+else
+  PLATFORM_OPTS=()
+fi
+
+# macOS + Podman: ensure VM is running before we try to launch containers
+# If you need GPU passthrough under Podman on macOS, init the VM with libkrun:
+#   CONTAINERS_MACHINE_PROVIDER=libkrun podman machine init
+if [ "$ENGINE" = "podman" ] && [ "$(uname -s)" = "Darwin" ]; then
+  if ! podman info &>/dev/null; then
+    log "⌛️ Initializing Podman VM…"
+    podman machine init &>/dev/null || true
+    podman machine start &>/dev/null || true
+
+    log "⌛️ Waiting for Podman API…"
+    until podman info &>/dev/null; do
+      sleep 1
+    done
+    log "✅ Podman VM is up"
+  fi
+fi
+
 # Clean up any leftovers from earlier runs
 for name in ollama-server llama-stack; do
   ids=$($ENGINE ps -aq --filter "name=^${name}$")
   if [ -n "$ids" ]; then
-    log "⚠️  Found existing container(s) for '${name}', removing..."
-    $ENGINE rm -f "$ids"
+    log "⚠️  Found existing container(s) for '${name}', removing…"
+    $ENGINE rm -f "$ids" > /dev/null 2>&1
   fi
 done

+###############################################################################
+# 0. Create a shared network
+###############################################################################
+if ! $ENGINE network inspect llama-net >/dev/null 2>&1; then
+  log "🌐 Creating network…"
+  $ENGINE network create llama-net >/dev/null 2>&1
+fi
+
 ###############################################################################
 # 1. Ollama
 ###############################################################################
 log "🦙 Starting Ollama…"
-$ENGINE run -d --name ollama-server \
-  -p "${OLLAMA_PORT}:11434" \
+$ENGINE run -d "${PLATFORM_OPTS[@]}" --name ollama-server \
+  --network llama-net \
+  -p "${OLLAMA_PORT}:${OLLAMA_PORT}" \
   ollama/ollama > /dev/null 2>&1

-log "⏳ Waiting for Ollama daemon…"
-if ! timeout "$WAIT_TIMEOUT" bash -c \
-  "until curl -fsS http://localhost:${OLLAMA_PORT}/ 2>/dev/null | grep -q 'Ollama'; do sleep 1; done"; then
+if ! wait_for_service "http://localhost:${OLLAMA_PORT}/" "Ollama" "$WAIT_TIMEOUT" "Ollama daemon"; then
   log "❌ Ollama daemon did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
-  $ENGINE logs ollama-server --tail=200
+  $ENGINE logs --tail 200 ollama-server
   die "Ollama startup failed"
 fi

-log "📦 Ensuring model is pulled: ${MODEL_ALIAS}..."
-$ENGINE exec ollama-server ollama pull "${MODEL_ALIAS}" > /dev/null 2>&1
+log "📦 Ensuring model is pulled: ${MODEL_ALIAS}…"
+if ! $ENGINE exec ollama-server ollama pull "${MODEL_ALIAS}" > /dev/null 2>&1; then
+  log "❌ Failed to pull model ${MODEL_ALIAS}; dumping container logs:"
+  $ENGINE logs --tail 200 ollama-server
+  die "Model pull failed"
+fi

 ###############################################################################
 # 2. Llama‑Stack
 ###############################################################################
-log "🦙📦 Starting Llama‑Stack…"
-$ENGINE run -d --name llama-stack \
-  -p "${PORT}:${PORT}" \
-  --add-host="${HOST_DNS}:host-gateway" \
-  "${SERVER_IMAGE}" \
-  --port "${PORT}" \
-  --env INFERENCE_MODEL="${MODEL_ALIAS}" \
-  --env OLLAMA_URL="http://${HOST_DNS}:${OLLAMA_PORT}" > /dev/null 2>&1
-
-log "⏳ Waiting for Llama-Stack API…"
-if ! timeout "$WAIT_TIMEOUT" bash -c \
-  "until curl -fsS http://localhost:${PORT}/v1/health 2>/dev/null | grep -q 'OK'; do sleep 1; done"; then
+cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
+  --network llama-net \
+  -p "${PORT}:${PORT}" \
+  "${SERVER_IMAGE}" --port "${PORT}" \
+  --env INFERENCE_MODEL="${MODEL_ALIAS}" \
+  --env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" )
+
+log "🦙 Starting Llama‑Stack…"
+$ENGINE "${cmd[@]}" > /dev/null 2>&1
+
+if ! wait_for_service "http://127.0.0.1:${PORT}/v1/health" "OK" "$WAIT_TIMEOUT" "Llama-Stack API"; then
   log "❌ Llama-Stack did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
-  $ENGINE logs llama-stack --tail=200
+  $ENGINE logs --tail 200 llama-stack
   die "Llama-Stack startup failed"
 fi
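The same readiness check that `wait_for_service` performs can be reproduced from Python. A minimal sketch; the port 8321 and the `OK` health-body pattern are assumptions mirroring the script's defaults, which are set earlier in the file and not shown in this hunk:

```python
import time
import urllib.request

def wait_for_service(url: str, pattern: str, timeout: int = 300) -> bool:
    """Poll `url` until its response body contains `pattern` or `timeout` expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=5) as resp:
                if pattern in resp.read().decode("utf-8", "replace"):
                    return True
        except OSError:
            pass  # service not accepting connections yet; retry
        time.sleep(1)
    return False

# Assumed defaults: Llama Stack serving on port 8321.
if wait_for_service("http://127.0.0.1:8321/v1/health", "OK"):
    print("Llama Stack is ready")
```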
@@ -4,20 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncIterator
 from datetime import datetime
 from enum import Enum
-from typing import (
-    Annotated,
-    Any,
-    AsyncIterator,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    Union,
-    runtime_checkable,
-)
+from typing import Annotated, Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel, ConfigDict, Field
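The shape of this migration on a toy model (the field names below are invented for illustration): `typing.Optional`, `Union`, `List`, and `Dict` give way to PEP 604 `X | Y` unions and PEP 585 builtin generics, which need no `typing` imports at all:

```python
from pydantic import BaseModel, Field

# Before (pre-3.10 style):
#   from typing import Dict, List, Optional, Union
#   tags: Optional[List[str]]
#   payload: Union[str, Dict[str, int]]

class Example(BaseModel):
    tags: list[str] | None = Field(default_factory=list)
    payload: str | dict[str, int] = ""

print(Example(payload={"tokens": 3}))  # payload validated against the union
```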
@@ -79,8 +69,8 @@ class StepCommon(BaseModel):

     turn_id: str
     step_id: str
-    started_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
+    started_at: datetime | None = None
+    completed_at: datetime | None = None


 class StepType(Enum):
@@ -120,8 +110,8 @@ class ToolExecutionStep(StepCommon):
     """

     step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
-    tool_calls: List[ToolCall]
-    tool_responses: List[ToolResponse]
+    tool_calls: list[ToolCall]
+    tool_responses: list[ToolResponse]


 @json_schema_type
@@ -132,7 +122,7 @@ class ShieldCallStep(StepCommon):
     """

     step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
-    violation: Optional[SafetyViolation]
+    violation: SafetyViolation | None


 @json_schema_type
@@ -150,12 +140,7 @@ class MemoryRetrievalStep(StepCommon):


 Step = Annotated[
-    Union[
-        InferenceStep,
-        ToolExecutionStep,
-        ShieldCallStep,
-        MemoryRetrievalStep,
-    ],
+    InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
     Field(discriminator="step_type"),
 ]
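How the `Field(discriminator=...)` annotation behaves at validation time, as a self-contained sketch with deliberately simplified step classes (plain string tags instead of the `StepType` enum values used above):

```python
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter

class InferenceStep(BaseModel):
    step_type: Literal["inference"] = "inference"
    model_response: str

class ShieldCallStep(BaseModel):
    step_type: Literal["shield_call"] = "shield_call"
    violation: str | None = None

Step = Annotated[InferenceStep | ShieldCallStep, Field(discriminator="step_type")]

# Pydantic selects the right model by inspecting only the "step_type" tag.
step = TypeAdapter(Step).validate_python(
    {"step_type": "inference", "model_response": "hello"}
)
assert isinstance(step, InferenceStep)
```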
@@ -166,18 +151,13 @@ class Turn(BaseModel):

     turn_id: str
     session_id: str
-    input_messages: List[
-        Union[
-            UserMessage,
-            ToolResponseMessage,
-        ]
-    ]
-    steps: List[Step]
+    input_messages: list[UserMessage | ToolResponseMessage]
+    steps: list[Step]
     output_message: CompletionMessage
-    output_attachments: Optional[List[Attachment]] = Field(default_factory=list)
+    output_attachments: list[Attachment] | None = Field(default_factory=list)

     started_at: datetime
-    completed_at: Optional[datetime] = None
+    completed_at: datetime | None = None


 @json_schema_type
@@ -186,34 +166,31 @@ class Session(BaseModel):

     session_id: str
     session_name: str
-    turns: List[Turn]
+    turns: list[Turn]
     started_at: datetime


 class AgentToolGroupWithArgs(BaseModel):
     name: str
-    args: Dict[str, Any]
+    args: dict[str, Any]


-AgentToolGroup = Union[
-    str,
-    AgentToolGroupWithArgs,
-]
+AgentToolGroup = str | AgentToolGroupWithArgs
 register_schema(AgentToolGroup, name="AgentTool")


 class AgentConfigCommon(BaseModel):
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

-    input_shields: Optional[List[str]] = Field(default_factory=list)
-    output_shields: Optional[List[str]] = Field(default_factory=list)
-    toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
-    client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
-    tool_config: Optional[ToolConfig] = Field(default=None)
+    input_shields: list[str] | None = Field(default_factory=list)
+    output_shields: list[str] | None = Field(default_factory=list)
+    toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
+    client_tools: list[ToolDef] | None = Field(default_factory=list)
+    tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
+    tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
+    tool_config: ToolConfig | None = Field(default=None)

-    max_infer_iters: Optional[int] = 10
+    max_infer_iters: int | None = 10

     def model_post_init(self, __context):
         if self.tool_config:
@@ -243,9 +220,9 @@ class AgentConfig(AgentConfigCommon):

     model: str
     instructions: str
-    name: Optional[str] = None
-    enable_session_persistence: Optional[bool] = False
-    response_format: Optional[ResponseFormat] = None
+    name: str | None = None
+    enable_session_persistence: bool | None = False
+    response_format: ResponseFormat | None = None


 @json_schema_type
@@ -257,16 +234,16 @@ class Agent(BaseModel):

 @json_schema_type
 class ListAgentsResponse(BaseModel):
-    data: List[Agent]
+    data: list[Agent]


 @json_schema_type
 class ListAgentSessionsResponse(BaseModel):
-    data: List[Session]
+    data: list[Session]


 class AgentConfigOverridablePerTurn(AgentConfigCommon):
-    instructions: Optional[str] = None
+    instructions: str | None = None


 class AgentTurnResponseEventType(Enum):
@@ -284,7 +261,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
     event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
     step_type: StepType
     step_id: str
-    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    metadata: dict[str, Any] | None = Field(default_factory=dict)


 @json_schema_type
@@ -327,14 +304,12 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):


 AgentTurnResponseEventPayload = Annotated[
-    Union[
-        AgentTurnResponseStepStartPayload,
-        AgentTurnResponseStepProgressPayload,
-        AgentTurnResponseStepCompletePayload,
-        AgentTurnResponseTurnStartPayload,
-        AgentTurnResponseTurnCompletePayload,
-        AgentTurnResponseTurnAwaitingInputPayload,
-    ],
+    AgentTurnResponseStepStartPayload
+    | AgentTurnResponseStepProgressPayload
+    | AgentTurnResponseStepCompletePayload
+    | AgentTurnResponseTurnStartPayload
+    | AgentTurnResponseTurnCompletePayload
+    | AgentTurnResponseTurnAwaitingInputPayload,
     Field(discriminator="event_type"),
 ]
 register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@@ -363,18 +338,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     # TODO: figure out how we can simplify this and make why
     # ToolResponseMessage needs to be here (it is function call
     # execution from outside the system)
-    messages: List[
-        Union[
-            UserMessage,
-            ToolResponseMessage,
-        ]
-    ]
+    messages: list[UserMessage | ToolResponseMessage]

-    documents: Optional[List[Document]] = None
-    toolgroups: Optional[List[AgentToolGroup]] = None
+    documents: list[Document] | None = None
+    toolgroups: list[AgentToolGroup] | None = None

-    stream: Optional[bool] = False
-    tool_config: Optional[ToolConfig] = None
+    stream: bool | None = False
+    tool_config: ToolConfig | None = None


 @json_schema_type
@@ -382,8 +352,8 @@ class AgentTurnResumeRequest(BaseModel):
     agent_id: str
     session_id: str
     turn_id: str
-    tool_responses: List[ToolResponse]
-    stream: Optional[bool] = False
+    tool_responses: list[ToolResponse]
+    stream: bool | None = False


 @json_schema_type
@@ -429,17 +399,12 @@ class Agents(Protocol):
         self,
         agent_id: str,
         session_id: str,
-        messages: List[
-            Union[
-                UserMessage,
-                ToolResponseMessage,
-            ]
-        ],
-        stream: Optional[bool] = False,
-        documents: Optional[List[Document]] = None,
-        toolgroups: Optional[List[AgentToolGroup]] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        messages: list[UserMessage | ToolResponseMessage],
+        stream: bool | None = False,
+        documents: list[Document] | None = None,
+        toolgroups: list[AgentToolGroup] | None = None,
+        tool_config: ToolConfig | None = None,
+    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
         """Create a new turn for an agent.

         :param agent_id: The ID of the agent to create the turn for.
@@ -463,9 +428,9 @@ class Agents(Protocol):
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: List[ToolResponse],
-        stream: Optional[bool] = False,
-    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        tool_responses: list[ToolResponse],
+        stream: bool | None = False,
+    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
         """Resume an agent turn with executed tool call responses.

         When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
@@ -538,7 +503,7 @@ class Agents(Protocol):
         self,
         session_id: str,
         agent_id: str,
-        turn_ids: Optional[List[str]] = None,
+        turn_ids: list[str] | None = None,
     ) -> Session:
         """Retrieve an agent session by its ID.
@@ -623,13 +588,14 @@ class Agents(Protocol):
     @webmethod(route="/openai/v1/responses", method="POST")
     async def create_openai_response(
         self,
-        input: Union[str, List[OpenAIResponseInputMessage]],
+        input: str | list[OpenAIResponseInputMessage],
         model: str,
-        previous_response_id: Optional[str] = None,
-        store: Optional[bool] = True,
-        stream: Optional[bool] = False,
-        tools: Optional[List[OpenAIResponseInputTool]] = None,
-    ) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]:
+        previous_response_id: str | None = None,
+        store: bool | None = True,
+        stream: bool | None = False,
+        temperature: float | None = None,
+        tools: list[OpenAIResponseInputTool] | None = None,
+    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
         """Create a new OpenAI response.

         :param input: Input message(s) to create the response.
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Literal, Optional, Union
+from typing import Annotated, Literal

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.schema_utils import json_schema_type, register_schema
@@ -25,7 +24,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):


 OpenAIResponseOutputMessageContent = Annotated[
-    Union[OpenAIResponseOutputMessageContentOutputText,],
+    OpenAIResponseOutputMessageContentOutputText,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
@@ -34,7 +33,7 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMe
 @json_schema_type
 class OpenAIResponseOutputMessage(BaseModel):
     id: str
-    content: List[OpenAIResponseOutputMessageContent]
+    content: list[OpenAIResponseOutputMessageContent]
     role: Literal["assistant"] = "assistant"
     status: str
     type: Literal["message"] = "message"
@@ -48,10 +47,7 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):


 OpenAIResponseOutput = Annotated[
-    Union[
-        OpenAIResponseOutputMessage,
-        OpenAIResponseOutputMessageWebSearchToolCall,
-    ],
+    OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@@ -60,18 +56,18 @@ register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
 @json_schema_type
 class OpenAIResponseObject(BaseModel):
     created_at: int
-    error: Optional[OpenAIResponseError] = None
+    error: OpenAIResponseError | None = None
     id: str
     model: str
     object: Literal["response"] = "response"
-    output: List[OpenAIResponseOutput]
+    output: list[OpenAIResponseOutput]
     parallel_tool_calls: bool = False
-    previous_response_id: Optional[str] = None
+    previous_response_id: str | None = None
     status: str
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    truncation: Optional[str] = None
-    user: Optional[str] = None
+    temperature: float | None = None
+    top_p: float | None = None
+    truncation: str | None = None
+    user: str | None = None


 @json_schema_type
@@ -87,10 +83,7 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):


 OpenAIResponseObjectStream = Annotated[
-    Union[
-        OpenAIResponseObjectStreamResponseCreated,
-        OpenAIResponseObjectStreamResponseCompleted,
-    ],
+    OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
@@ -107,12 +100,12 @@ class OpenAIResponseInputMessageContentImage(BaseModel):
     detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
     type: Literal["input_image"] = "input_image"
     # TODO: handle file_id
-    image_url: Optional[str] = None
+    image_url: str | None = None


 # TODO: handle file content types
 OpenAIResponseInputMessageContent = Annotated[
-    Union[OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentImage],
+    OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
@@ -120,21 +113,21 @@ register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMess

 @json_schema_type
 class OpenAIResponseInputMessage(BaseModel):
-    content: Union[str, List[OpenAIResponseInputMessageContent]]
+    content: str | list[OpenAIResponseInputMessageContent]
     role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
-    type: Optional[Literal["message"]] = "message"
+    type: Literal["message"] | None = "message"


 @json_schema_type
 class OpenAIResponseInputToolWebSearch(BaseModel):
     type: Literal["web_search"] | Literal["web_search_preview_2025_03_11"] = "web_search"
     # TODO: actually use search_context_size somewhere...
-    search_context_size: Optional[str] = Field(default="medium", pattern="^low|medium|high$")
+    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
     # TODO: add user_location


 OpenAIResponseInputTool = Annotated[
-    Union[OpenAIResponseInputToolWebSearch,],
+    OpenAIResponseInputToolWebSearch,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable

 from llama_stack.apis.common.job_types import Job
 from llama_stack.apis.inference import (
@@ -34,22 +34,22 @@ class BatchInference(Protocol):
     async def completion(
         self,
         model: str,
-        content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        content_batch: list[InterleavedContent],
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> Job: ...

     @webmethod(route="/batch-inference/chat-completion", method="POST")
     async def chat_completion(
         self,
         model: str,
-        messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = None,
+        messages_batch: list[list[Message]],
+        sampling_params: SamplingParams | None = None,
         # zero-shot tool definitions as input to the model
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> Job: ...
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel, Field

@@ -13,8 +13,8 @@ from llama_stack.schema_utils import json_schema_type, webmethod

 class CommonBenchmarkFields(BaseModel):
     dataset_id: str
-    scoring_functions: List[str]
-    metadata: Dict[str, Any] = Field(
+    scoring_functions: list[str]
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Metadata for this evaluation task",
     )
@@ -35,12 +35,12 @@ class Benchmark(CommonBenchmarkFields, Resource):

 class BenchmarkInput(CommonBenchmarkFields, BaseModel):
     benchmark_id: str
-    provider_id: Optional[str] = None
-    provider_benchmark_id: Optional[str] = None
+    provider_id: str | None = None
+    provider_benchmark_id: str | None = None


 class ListBenchmarksResponse(BaseModel):
-    data: List[Benchmark]
+    data: list[Benchmark]


 @runtime_checkable
@@ -59,8 +59,8 @@ class Benchmarks(Protocol):
         self,
         benchmark_id: str,
         dataset_id: str,
-        scoring_functions: List[str],
-        provider_benchmark_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        scoring_functions: list[str],
+        provider_benchmark_id: str | None = None,
+        provider_id: str | None = None,
+        metadata: dict[str, Any] | None = None,
     ) -> None: ...
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Annotated, List, Literal, Optional, Union
+from typing import Annotated, Literal

 from pydantic import BaseModel, Field, model_validator

@@ -26,9 +26,9 @@ class _URLOrData(BaseModel):
     :param data: base64 encoded image data as string
     """

-    url: Optional[URL] = None
+    url: URL | None = None
     # data is a base64 encoded string, hint with contentEncoding=base64
-    data: Optional[str] = Field(contentEncoding="base64", default=None)
+    data: str | None = Field(contentEncoding="base64", default=None)

     @model_validator(mode="before")
     @classmethod
@@ -64,13 +64,13 @@ class TextContentItem(BaseModel):

 # other modalities can be added here
 InterleavedContentItem = Annotated[
-    Union[ImageContentItem, TextContentItem],
+    ImageContentItem | TextContentItem,
     Field(discriminator="type"),
 ]
 register_schema(InterleavedContentItem, name="InterleavedContentItem")

 # accept a single "str" as a special case since it is common
-InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
+InterleavedContent = str | InterleavedContentItem | list[InterleavedContentItem]
 register_schema(InterleavedContent, name="InterleavedContent")

@@ -100,13 +100,13 @@ class ToolCallDelta(BaseModel):
     # you either send an in-progress tool call so the client can stream a long
     # code generation or you send the final parsed tool call at the end of the
     # stream
-    tool_call: Union[str, ToolCall]
+    tool_call: str | ToolCall
     parse_status: ToolCallParseStatus


 # streaming completions send a stream of ContentDeltas
 ContentDelta = Annotated[
-    Union[TextDelta, ImageDelta, ToolCallDelta],
+    TextDelta | ImageDelta | ToolCallDelta,
     Field(discriminator="type"),
 ]
 register_schema(ContentDelta, name="ContentDelta")
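The effect of the `str` special case in `InterleavedContent`, shown as a reduced sketch with a single item type in place of the full image/text union:

```python
from typing import Literal

from pydantic import BaseModel, TypeAdapter

class TextContentItem(BaseModel):
    type: Literal["text"] = "text"
    text: str

# Mirrors the alias above, minus ImageContentItem.
InterleavedContent = str | TextContentItem | list[TextContentItem]

adapter = TypeAdapter(InterleavedContent)
print(adapter.validate_python("plain strings pass through"))
print(adapter.validate_python({"type": "text", "text": "coerced to a model"}))
```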
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel

@@ -25,6 +25,6 @@ class RestAPIMethod(Enum):
 class RestAPIExecutionConfig(BaseModel):
     url: URL
     method: RestAPIMethod
-    params: Optional[Dict[str, Any]] = None
-    headers: Optional[Dict[str, Any]] = None
-    body: Optional[Dict[str, Any]] = None
+    params: dict[str, Any] | None = None
+    headers: dict[str, Any] | None = None
+    body: dict[str, Any] | None = None
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List
+from typing import Any

 from pydantic import BaseModel

@@ -19,5 +19,5 @@ class PaginatedResponse(BaseModel):
     :param has_more: Whether there are more items available after this set
     """

-    data: List[Dict[str, Any]]
+    data: list[dict[str, Any]]
     has_more: bool
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from datetime import datetime
-from typing import Optional

 from pydantic import BaseModel

@@ -27,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric] = None
+    training_metrics: PostTrainingMetric | None = None
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Literal, Union
+from typing import Annotated, Literal

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.schema_utils import json_schema_type, register_schema

@@ -73,18 +72,16 @@ class DialogType(BaseModel):


 ParamType = Annotated[
-    Union[
-        StringType,
-        NumberType,
-        BooleanType,
-        ArrayType,
-        ObjectType,
-        JsonType,
-        UnionType,
-        ChatCompletionInputType,
-        CompletionInputType,
-        AgentTurnInputType,
-    ],
+    StringType
+    | NumberType
+    | BooleanType
+    | ArrayType
+    | ObjectType
+    | JsonType
+    | UnionType
+    | ChatCompletionInputType
+    | CompletionInputType
+    | AgentTurnInputType,
     Field(discriminator="type"),
 ]
 register_schema(ParamType, name="ParamType")
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasets import Dataset

@@ -24,8 +24,8 @@ class DatasetIO(Protocol):
     async def iterrows(
         self,
         dataset_id: str,
-        start_index: Optional[int] = None,
-        limit: Optional[int] = None,
+        start_index: int | None = None,
+        limit: int | None = None,
     ) -> PaginatedResponse:
         """Get a paginated list of rows from a dataset.

@@ -44,4 +44,4 @@ class DatasetIO(Protocol):
         ...

     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol

 from pydantic import BaseModel, Field

@@ -81,11 +81,11 @@ class RowsDataSource(BaseModel):
     """

     type: Literal["rows"] = "rows"
-    rows: List[Dict[str, Any]]
+    rows: list[dict[str, Any]]


 DataSource = Annotated[
-    Union[URIDataSource, RowsDataSource],
+    URIDataSource | RowsDataSource,
     Field(discriminator="type"),
 ]
 register_schema(DataSource, name="DataSource")
@@ -98,7 +98,7 @@ class CommonDatasetFields(BaseModel):

     purpose: DatasetPurpose
     source: DataSource
-    metadata: Dict[str, Any] = Field(
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this dataset",
     )
@@ -122,7 +122,7 @@ class DatasetInput(CommonDatasetFields, BaseModel):


 class ListDatasetsResponse(BaseModel):
-    data: List[Dataset]
+    data: list[Dataset]


 class Datasets(Protocol):
@@ -131,8 +131,8 @@ class Datasets(Protocol):
         self,
         purpose: DatasetPurpose,
         source: DataSource,
-        metadata: Optional[Dict[str, Any]] = None,
-        dataset_id: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        dataset_id: str | None = None,
     ) -> Dataset:
         """
         Register a new dataset.
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Optional

 from pydantic import BaseModel

@@ -54,4 +53,4 @@ class Error(BaseModel):
     status: int
     title: str
     detail: str
-    instance: Optional[str] = None
+    instance: str | None = None
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job
@@ -29,7 +28,7 @@ class ModelCandidate(BaseModel):
     type: Literal["model"] = "model"
     model: str
     sampling_params: SamplingParams
-    system_message: Optional[SystemMessage] = None
+    system_message: SystemMessage | None = None


 @json_schema_type
@@ -43,7 +42,7 @@ class AgentCandidate(BaseModel):
     config: AgentConfig


-EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
+EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
 register_schema(EvalCandidate, name="EvalCandidate")


@@ -57,11 +56,11 @@ class BenchmarkConfig(BaseModel):
     """

     eval_candidate: EvalCandidate
-    scoring_params: Dict[str, ScoringFnParams] = Field(
+    scoring_params: dict[str, ScoringFnParams] = Field(
         description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
-    num_examples: Optional[int] = Field(
+    num_examples: int | None = Field(
         description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
         default=None,
     )
@@ -76,9 +75,9 @@ class EvaluateResponse(BaseModel):
     :param scores: The scores from the evaluation.
     """

-    generations: List[Dict[str, Any]]
+    generations: list[dict[str, Any]]
     # each key in the dict is a scoring function name
-    scores: Dict[str, ScoringResult]
+    scores: dict[str, ScoringResult]


 class Eval(Protocol):
@@ -101,8 +100,8 @@ class Eval(Protocol):
     async def evaluate_rows(
         self,
         benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
+        input_rows: list[dict[str, Any]],
+        scoring_functions: list[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         """Evaluate a list of rows on a benchmark.

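Fields such as num_examples above keep their Field() metadata through the rewrite; only the annotation changes. A sketch under that assumption (hypothetical model name, assuming pydantic v2):

    from pydantic import BaseModel, Field

    class BenchmarkConfigSketch(BaseModel):
        num_examples: int | None = Field(
            description="Number of examples to evaluate; None means the whole dataset",
            default=None,
        )

    assert BenchmarkConfigSketch().num_examples is None
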
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -42,7 +42,7 @@ class ListBucketResponse(BaseModel):
     :param data: List of FileResponse entries
     """

-    data: List[BucketResponse]
+    data: list[BucketResponse]


 @json_schema_type
@@ -74,7 +74,7 @@ class ListFileResponse(BaseModel):
     :param data: List of FileResponse entries
     """

-    data: List[FileResponse]
+    data: list[FileResponse]


 @runtime_checkable
@@ -102,7 +102,7 @@ class Files(Protocol):
     async def upload_content_to_session(
         self,
         upload_id: str,
-    ) -> Optional[FileResponse]:
+    ) -> FileResponse | None:
         """
         Upload file content to an existing upload session.
         On the server, request body will have the raw bytes that are uploaded.

@@ -4,21 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncIterator
 from enum import Enum
 from typing import (
+    Annotated,
     Any,
-    AsyncIterator,
-    Dict,
-    List,
     Literal,
-    Optional,
     Protocol,
-    Union,
     runtime_checkable,
 )

 from pydantic import BaseModel, Field, field_validator
-from typing_extensions import Annotated, TypedDict
+from typing_extensions import TypedDict

 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
@@ -47,8 +44,8 @@ class GreedySamplingStrategy(BaseModel):
 @json_schema_type
 class TopPSamplingStrategy(BaseModel):
     type: Literal["top_p"] = "top_p"
-    temperature: Optional[float] = Field(..., gt=0.0)
-    top_p: Optional[float] = 0.95
+    temperature: float | None = Field(..., gt=0.0)
+    top_p: float | None = 0.95


 @json_schema_type
@@ -58,7 +55,7 @@ class TopKSamplingStrategy(BaseModel):


 SamplingStrategy = Annotated[
-    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
+    GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy,
     Field(discriminator="type"),
 ]
 register_schema(SamplingStrategy, name="SamplingStrategy")
@@ -79,9 +76,9 @@ class SamplingParams(BaseModel):

     strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)

-    max_tokens: Optional[int] = 0
-    repetition_penalty: Optional[float] = 1.0
-    stop: Optional[List[str]] = None
+    max_tokens: int | None = 0
+    repetition_penalty: float | None = 1.0
+    stop: list[str] | None = None


 class LogProbConfig(BaseModel):
@@ -90,7 +87,7 @@ class LogProbConfig(BaseModel):
     :param top_k: How many tokens (for each position) to return log probabilities for.
     """

-    top_k: Optional[int] = 0
+    top_k: int | None = 0


 class QuantizationType(Enum):
@@ -125,11 +122,11 @@ class Int4QuantizationConfig(BaseModel):
     """

     type: Literal["int4_mixed"] = "int4_mixed"
-    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
+    scheme: str | None = "int4_weight_int8_dynamic_activation"


 QuantizationConfig = Annotated[
-    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
+    Bf16QuantizationConfig | Fp8QuantizationConfig | Int4QuantizationConfig,
     Field(discriminator="type"),
 ]

@@ -145,7 +142,7 @@ class UserMessage(BaseModel):

     role: Literal["user"] = "user"
     content: InterleavedContent
-    context: Optional[InterleavedContent] = None
+    context: InterleavedContent | None = None


 @json_schema_type
@@ -190,16 +187,11 @@ class CompletionMessage(BaseModel):
     role: Literal["assistant"] = "assistant"
     content: InterleavedContent
     stop_reason: StopReason
-    tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
+    tool_calls: list[ToolCall] | None = Field(default_factory=list)


 Message = Annotated[
-    Union[
-        UserMessage,
-        SystemMessage,
-        ToolResponseMessage,
-        CompletionMessage,
-    ],
+    UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
     Field(discriminator="role"),
 ]
 register_schema(Message, name="Message")
@@ -208,9 +200,9 @@ register_schema(Message, name="Message")
 @json_schema_type
 class ToolResponse(BaseModel):
     call_id: str
-    tool_name: Union[BuiltinTool, str]
+    tool_name: BuiltinTool | str
     content: InterleavedContent
-    metadata: Optional[Dict[str, Any]] = None
+    metadata: dict[str, Any] | None = None

     @field_validator("tool_name", mode="before")
     @classmethod
@@ -243,7 +235,7 @@ class TokenLogProbs(BaseModel):
     :param logprobs_by_token: Dictionary mapping tokens to their log probabilities
     """

-    logprobs_by_token: Dict[str, float]
+    logprobs_by_token: dict[str, float]


 class ChatCompletionResponseEventType(Enum):
@@ -271,8 +263,8 @@ class ChatCompletionResponseEvent(BaseModel):

     event_type: ChatCompletionResponseEventType
     delta: ContentDelta
-    logprobs: Optional[List[TokenLogProbs]] = None
-    stop_reason: Optional[StopReason] = None
+    logprobs: list[TokenLogProbs] | None = None
+    stop_reason: StopReason | None = None


 class ResponseFormatType(Enum):
@@ -295,7 +287,7 @@ class JsonSchemaResponseFormat(BaseModel):
     """

     type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
-    json_schema: Dict[str, Any]
+    json_schema: dict[str, Any]


 @json_schema_type
@@ -307,11 +299,11 @@ class GrammarResponseFormat(BaseModel):
     """

     type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
-    bnf: Dict[str, Any]
+    bnf: dict[str, Any]


 ResponseFormat = Annotated[
-    Union[JsonSchemaResponseFormat, GrammarResponseFormat],
+    JsonSchemaResponseFormat | GrammarResponseFormat,
     Field(discriminator="type"),
 ]
 register_schema(ResponseFormat, name="ResponseFormat")
@@ -321,10 +313,10 @@ register_schema(ResponseFormat, name="ResponseFormat")
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedContent
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
-    response_format: Optional[ResponseFormat] = None
-    stream: Optional[bool] = False
-    logprobs: Optional[LogProbConfig] = None
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
+    response_format: ResponseFormat | None = None
+    stream: bool | None = False
+    logprobs: LogProbConfig | None = None


 @json_schema_type
@@ -338,7 +330,7 @@ class CompletionResponse(MetricResponseMixin):

     content: str
     stop_reason: StopReason
-    logprobs: Optional[List[TokenLogProbs]] = None
+    logprobs: list[TokenLogProbs] | None = None


 @json_schema_type
@@ -351,8 +343,8 @@ class CompletionResponseStreamChunk(MetricResponseMixin):
     """

     delta: str
-    stop_reason: Optional[StopReason] = None
-    logprobs: Optional[List[TokenLogProbs]] = None
+    stop_reason: StopReason | None = None
+    logprobs: list[TokenLogProbs] | None = None


 class SystemMessageBehavior(Enum):
@@ -383,9 +375,9 @@ class ToolConfig(BaseModel):
         '{{function_definitions}}' to indicate where the function definitions should be inserted.
     """

-    tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
-    system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
+    tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
+    tool_prompt_format: ToolPromptFormat | None = Field(default=None)
+    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)

     def model_post_init(self, __context: Any) -> None:
         if isinstance(self.tool_choice, str):
@@ -399,15 +391,15 @@ class ToolConfig(BaseModel):
 @json_schema_type
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[Message]
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
+    messages: list[Message]
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

-    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
-    tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
+    tools: list[ToolDefinition] | None = Field(default_factory=list)
+    tool_config: ToolConfig | None = Field(default_factory=ToolConfig)

-    response_format: Optional[ResponseFormat] = None
-    stream: Optional[bool] = False
-    logprobs: Optional[LogProbConfig] = None
+    response_format: ResponseFormat | None = None
+    stream: bool | None = False
+    logprobs: LogProbConfig | None = None


 @json_schema_type
@@ -429,7 +421,7 @@ class ChatCompletionResponse(MetricResponseMixin):
     """

     completion_message: CompletionMessage
-    logprobs: Optional[List[TokenLogProbs]] = None
+    logprobs: list[TokenLogProbs] | None = None


 @json_schema_type
@@ -439,7 +431,7 @@ class EmbeddingsResponse(BaseModel):
     :param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
     """

-    embeddings: List[List[float]]
+    embeddings: list[list[float]]


 @json_schema_type
@@ -451,7 +443,7 @@ class OpenAIChatCompletionContentPartTextParam(BaseModel):
 @json_schema_type
 class OpenAIImageURL(BaseModel):
     url: str
-    detail: Optional[str] = None
+    detail: str | None = None


 @json_schema_type
@@ -461,16 +453,13 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):


 OpenAIChatCompletionContentPartParam = Annotated[
-    Union[
-        OpenAIChatCompletionContentPartTextParam,
-        OpenAIChatCompletionContentPartImageParam,
-    ],
+    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")


-OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
+OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]


 @json_schema_type
@@ -484,7 +473,7 @@ class OpenAIUserMessageParam(BaseModel):

     role: Literal["user"] = "user"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None


 @json_schema_type
@@ -498,21 +487,21 @@ class OpenAISystemMessageParam(BaseModel):

     role: Literal["system"] = "system"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None


 @json_schema_type
 class OpenAIChatCompletionToolCallFunction(BaseModel):
-    name: Optional[str] = None
-    arguments: Optional[str] = None
+    name: str | None = None
+    arguments: str | None = None


 @json_schema_type
 class OpenAIChatCompletionToolCall(BaseModel):
-    index: Optional[int] = None
-    id: Optional[str] = None
+    index: int | None = None
+    id: str | None = None
     type: Literal["function"] = "function"
-    function: Optional[OpenAIChatCompletionToolCallFunction] = None
+    function: OpenAIChatCompletionToolCallFunction | None = None


 @json_schema_type
@@ -526,9 +515,9 @@ class OpenAIAssistantMessageParam(BaseModel):
     """

     role: Literal["assistant"] = "assistant"
-    content: Optional[OpenAIChatCompletionMessageContent] = None
-    name: Optional[str] = None
-    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
+    content: OpenAIChatCompletionMessageContent | None = None
+    name: str | None = None
+    tool_calls: list[OpenAIChatCompletionToolCall] | None = None


 @json_schema_type
@@ -556,17 +545,15 @@ class OpenAIDeveloperMessageParam(BaseModel):

     role: Literal["developer"] = "developer"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None


 OpenAIMessageParam = Annotated[
-    Union[
-        OpenAIUserMessageParam,
-        OpenAISystemMessageParam,
-        OpenAIAssistantMessageParam,
-        OpenAIToolMessageParam,
-        OpenAIDeveloperMessageParam,
-    ],
+    OpenAIUserMessageParam
+    | OpenAISystemMessageParam
+    | OpenAIAssistantMessageParam
+    | OpenAIToolMessageParam
+    | OpenAIDeveloperMessageParam,
     Field(discriminator="role"),
 ]
 register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
@@ -580,14 +567,14 @@ class OpenAIResponseFormatText(BaseModel):
 @json_schema_type
 class OpenAIJSONSchema(TypedDict, total=False):
     name: str
-    description: Optional[str] = None
-    strict: Optional[bool] = None
+    description: str | None = None
+    strict: bool | None = None

     # Pydantic BaseModel cannot be used with a schema param, since it already
     # has one. And, we don't want to alias here because then have to handle
     # that alias when converting to OpenAI params. So, to support schema,
     # we use a TypedDict.
-    schema: Optional[Dict[str, Any]] = None
+    schema: dict[str, Any] | None = None


 @json_schema_type
@@ -602,11 +589,7 @@ class OpenAIResponseFormatJSONObject(BaseModel):


 OpenAIResponseFormatParam = Annotated[
-    Union[
-        OpenAIResponseFormatText,
-        OpenAIResponseFormatJSONSchema,
-        OpenAIResponseFormatJSONObject,
-    ],
+    OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@@ -622,7 +605,7 @@ class OpenAITopLogProb(BaseModel):
     """

     token: str
-    bytes: Optional[List[int]] = None
+    bytes: list[int] | None = None
     logprob: float

@@ -637,9 +620,9 @@ class OpenAITokenLogProb(BaseModel):
     """

     token: str
-    bytes: Optional[List[int]] = None
+    bytes: list[int] | None = None
     logprob: float
-    top_logprobs: List[OpenAITopLogProb]
+    top_logprobs: list[OpenAITopLogProb]


 @json_schema_type
@@ -650,8 +633,8 @@ class OpenAIChoiceLogprobs(BaseModel):
     :param refusal: (Optional) The log probabilities for the tokens in the message
     """

-    content: Optional[List[OpenAITokenLogProb]] = None
-    refusal: Optional[List[OpenAITokenLogProb]] = None
+    content: list[OpenAITokenLogProb] | None = None
+    refusal: list[OpenAITokenLogProb] | None = None


 @json_schema_type
@@ -664,10 +647,10 @@ class OpenAIChoiceDelta(BaseModel):
     :param tool_calls: (Optional) The tool calls of the delta
     """

-    content: Optional[str] = None
-    refusal: Optional[str] = None
-    role: Optional[str] = None
-    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
+    content: str | None = None
+    refusal: str | None = None
+    role: str | None = None
+    tool_calls: list[OpenAIChatCompletionToolCall] | None = None


 @json_schema_type
@@ -683,7 +666,7 @@ class OpenAIChunkChoice(BaseModel):
     delta: OpenAIChoiceDelta
     finish_reason: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None


 @json_schema_type
@@ -699,7 +682,7 @@ class OpenAIChoice(BaseModel):
     message: OpenAIMessageParam
     finish_reason: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None


 @json_schema_type
@@ -714,7 +697,7 @@ class OpenAIChatCompletion(BaseModel):
     """

     id: str
-    choices: List[OpenAIChoice]
+    choices: list[OpenAIChoice]
     object: Literal["chat.completion"] = "chat.completion"
     created: int
     model: str
@@ -732,7 +715,7 @@ class OpenAIChatCompletionChunk(BaseModel):
     """

     id: str
-    choices: List[OpenAIChunkChoice]
+    choices: list[OpenAIChunkChoice]
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int
     model: str
@@ -748,10 +731,10 @@ class OpenAICompletionLogprobs(BaseModel):
     :top_logprobs: (Optional) The top log probabilities for the tokens
     """

-    text_offset: Optional[List[int]] = None
-    token_logprobs: Optional[List[float]] = None
-    tokens: Optional[List[str]] = None
-    top_logprobs: Optional[List[Dict[str, float]]] = None
+    text_offset: list[int] | None = None
+    token_logprobs: list[float] | None = None
+    tokens: list[str] | None = None
+    top_logprobs: list[dict[str, float]] | None = None


 @json_schema_type
@@ -767,7 +750,7 @@ class OpenAICompletionChoice(BaseModel):
     finish_reason: str
     text: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None


 @json_schema_type
@@ -782,7 +765,7 @@ class OpenAICompletion(BaseModel):
     """

     id: str
-    choices: List[OpenAICompletionChoice]
+    choices: list[OpenAICompletionChoice]
     created: int
     model: str
     object: Literal["text_completion"] = "text_completion"
@@ -818,12 +801,12 @@ class EmbeddingTaskType(Enum):

 @json_schema_type
 class BatchCompletionResponse(BaseModel):
-    batch: List[CompletionResponse]
+    batch: list[CompletionResponse]


 @json_schema_type
 class BatchChatCompletionResponse(BaseModel):
-    batch: List[ChatCompletionResponse]
+    batch: list[ChatCompletionResponse]


 @runtime_checkable
@@ -843,11 +826,11 @@ class Inference(Protocol):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+    ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
         """Generate a completion for the given content using the specified model.

         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
@@ -865,10 +848,10 @@ class Inference(Protocol):
     async def batch_completion(
         self,
         model_id: str,
-        content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        content_batch: list[InterleavedContent],
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> BatchCompletionResponse:
         raise NotImplementedError("Batch completion is not implemented")

@@ -876,16 +859,16 @@ class Inference(Protocol):
     async def chat_completion(
         self,
         model_id: str,
-        messages: List[Message],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        messages: list[Message],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+        tool_config: ToolConfig | None = None,
+    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
         """Generate a chat completion for the given messages using the specified model.

         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
@@ -916,12 +899,12 @@ class Inference(Protocol):
     async def batch_chat_completion(
         self,
         model_id: str,
-        messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_config: Optional[ToolConfig] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        messages_batch: list[list[Message]],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_config: ToolConfig | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> BatchChatCompletionResponse:
         raise NotImplementedError("Batch chat completion is not implemented")

@@ -929,10 +912,10 @@ class Inference(Protocol):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
-        output_dimension: Optional[int] = None,
-        task_type: Optional[EmbeddingTaskType] = None,
+        contents: list[str] | list[InterleavedContentItem],
+        text_truncation: TextTruncation | None = TextTruncation.none,
+        output_dimension: int | None = None,
+        task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         """Generate embeddings for content pieces using the specified model.

@@ -950,25 +933,25 @@ class Inference(Protocol):
         self,
         # Standard OpenAI completion parameters
         model: str,
-        prompt: Union[str, List[str], List[int], List[List[int]]],
-        best_of: Optional[int] = None,
-        echo: Optional[bool] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
+        prompt: str | list[str] | list[int] | list[list[int]],
+        best_of: int | None = None,
+        echo: bool | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        presence_penalty: float | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
         # vLLM-specific parameters
-        guided_choice: Optional[List[str]] = None,
-        prompt_logprobs: Optional[int] = None,
+        guided_choice: list[str] | None = None,
+        prompt_logprobs: int | None = None,
     ) -> OpenAICompletion:
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.

@@ -996,29 +979,29 @@ class Inference(Protocol):
     async def openai_chat_completion(
         self,
         model: str,
-        messages: List[OpenAIMessageParam],
-        frequency_penalty: Optional[float] = None,
-        function_call: Optional[Union[str, Dict[str, Any]]] = None,
-        functions: Optional[List[Dict[str, Any]]] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        parallel_tool_calls: Optional[bool] = None,
-        presence_penalty: Optional[float] = None,
-        response_format: Optional[OpenAIResponseFormatParam] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
-        top_logprobs: Optional[int] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """Generate an OpenAI-compatible chat completion for the given messages using the specified model.

         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.

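The completion methods above now spell their return types as unions such as CompletionResponse | AsyncIterator[...], with AsyncIterator imported from collections.abc rather than typing (the typing aliases are deprecated). A runnable sketch of the same streaming-or-not shape, with stand-in types:

    import asyncio
    from collections.abc import AsyncIterator

    async def complete(stream: bool) -> str | AsyncIterator[str]:
        async def chunks() -> AsyncIterator[str]:
            # an async generator satisfies AsyncIterator[str]
            for piece in ("Hel", "lo"):
                yield piece
        return chunks() if stream else "Hello"

    async def main() -> None:
        result = await complete(stream=True)
        assert not isinstance(result, str)
        print("".join([c async for c in result]))  # -> Hello

    asyncio.run(main())
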
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -16,7 +16,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
 class RouteInfo(BaseModel):
     route: str
     method: str
-    provider_types: List[str]
+    provider_types: list[str]


 @json_schema_type
@@ -30,7 +30,7 @@ class VersionInfo(BaseModel):


 class ListRoutesResponse(BaseModel):
-    data: List[RouteInfo]
+    data: list[RouteInfo]


 @runtime_checkable

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel, ConfigDict, Field

@@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod


 class CommonModelFields(BaseModel):
-    metadata: Dict[str, Any] = Field(
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this model",
     )
@@ -46,14 +46,14 @@ class Model(CommonModelFields, Resource):

 class ModelInput(CommonModelFields):
     model_id: str
-    provider_id: Optional[str] = None
-    provider_model_id: Optional[str] = None
-    model_type: Optional[ModelType] = ModelType.llm
+    provider_id: str | None = None
+    provider_model_id: str | None = None
+    model_type: ModelType | None = ModelType.llm
     model_config = ConfigDict(protected_namespaces=())


 class ListModelsResponse(BaseModel):
-    data: List[Model]
+    data: list[Model]


 @json_schema_type
@@ -73,7 +73,7 @@ class OpenAIModel(BaseModel):


 class OpenAIListModelsResponse(BaseModel):
-    data: List[OpenAIModel]
+    data: list[OpenAIModel]


 @runtime_checkable
@@ -95,10 +95,10 @@ class Models(Protocol):
     async def register_model(
         self,
         model_id: str,
-        provider_model_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        model_type: Optional[ModelType] = None,
+        provider_model_id: str | None = None,
+        provider_id: str | None = None,
+        metadata: dict[str, Any] | None = None,
+        model_type: ModelType | None = None,
     ) -> Model: ...

     @webmethod(route="/models/{model_id:path}", method="DELETE")

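Annotations such as list[Model] and dict[str, Any] above are PEP 585 builtin generics, so the module no longer needs List/Dict imports. A sketch with a stand-in for the Model resource, assuming pydantic v2:

    from pydantic import BaseModel

    class ModelSketch(BaseModel):
        identifier: str

    class ListModelsResponseSketch(BaseModel):
        data: list[ModelSketch]

    # pydantic coerces the dicts into ModelSketch instances
    resp = ListModelsResponseSketch(data=[{"identifier": "llama3.2:3b-instruct-fp16"}])
    assert resp.data[0].identifier == "llama3.2:3b-instruct-fp16"
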
@@ -6,10 +6,9 @@

 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.job_types import JobStatus
@@ -36,9 +35,9 @@ class DataConfig(BaseModel):
     batch_size: int
     shuffle: bool
     data_format: DatasetFormat
-    validation_dataset_id: Optional[str] = None
-    packed: Optional[bool] = False
-    train_on_input: Optional[bool] = False
+    validation_dataset_id: str | None = None
+    packed: bool | None = False
+    train_on_input: bool | None = False


 @json_schema_type
@@ -51,10 +50,10 @@ class OptimizerConfig(BaseModel):

 @json_schema_type
 class EfficiencyConfig(BaseModel):
-    enable_activation_checkpointing: Optional[bool] = False
-    enable_activation_offloading: Optional[bool] = False
-    memory_efficient_fsdp_wrap: Optional[bool] = False
-    fsdp_cpu_offload: Optional[bool] = False
+    enable_activation_checkpointing: bool | None = False
+    enable_activation_offloading: bool | None = False
+    memory_efficient_fsdp_wrap: bool | None = False
+    fsdp_cpu_offload: bool | None = False


 @json_schema_type
@@ -62,23 +61,23 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int = 1
     gradient_accumulation_steps: int = 1
-    max_validation_steps: Optional[int] = 1
-    data_config: Optional[DataConfig] = None
-    optimizer_config: Optional[OptimizerConfig] = None
-    efficiency_config: Optional[EfficiencyConfig] = None
-    dtype: Optional[str] = "bf16"
+    max_validation_steps: int | None = 1
+    data_config: DataConfig | None = None
+    optimizer_config: OptimizerConfig | None = None
+    efficiency_config: EfficiencyConfig | None = None
+    dtype: str | None = "bf16"


 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
     type: Literal["LoRA"] = "LoRA"
-    lora_attn_modules: List[str]
+    lora_attn_modules: list[str]
     apply_lora_to_mlp: bool
     apply_lora_to_output: bool
     rank: int
     alpha: int
-    use_dora: Optional[bool] = False
-    quantize_base: Optional[bool] = False
+    use_dora: bool | None = False
+    quantize_base: bool | None = False


 @json_schema_type
@@ -88,7 +87,7 @@ class QATFinetuningConfig(BaseModel):
     group_size: int


-AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
+AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
 register_schema(AlgorithmConfig, name="AlgorithmConfig")


@@ -97,7 +96,7 @@ class PostTrainingJobLogStream(BaseModel):
    """Stream of logs from a finetuning job."""

     job_uuid: str
-    log_lines: List[str]
+    log_lines: list[str]


 @json_schema_type
@@ -131,8 +130,8 @@ class PostTrainingRLHFRequest(BaseModel):
     training_config: TrainingConfig

     # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
+    hyperparam_search_config: dict[str, Any]
+    logger_config: dict[str, Any]


 class PostTrainingJob(BaseModel):
@@ -146,17 +145,17 @@ class PostTrainingJobStatusResponse(BaseModel):
     job_uuid: str
     status: JobStatus

-    scheduled_at: Optional[datetime] = None
-    started_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
+    scheduled_at: datetime | None = None
+    started_at: datetime | None = None
+    completed_at: datetime | None = None

-    resources_allocated: Optional[Dict[str, Any]] = None
+    resources_allocated: dict[str, Any] | None = None

-    checkpoints: List[Checkpoint] = Field(default_factory=list)
+    checkpoints: list[Checkpoint] = Field(default_factory=list)


 class ListPostTrainingJobsResponse(BaseModel):
-    data: List[PostTrainingJob]
+    data: list[PostTrainingJob]


 @json_schema_type
@@ -164,7 +163,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
     """Artifacts of a finetuning job."""

     job_uuid: str
-    checkpoints: List[Checkpoint] = Field(default_factory=list)
+    checkpoints: list[Checkpoint] = Field(default_factory=list)

     # TODO(ashwin): metrics, evals

@@ -175,14 +174,14 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
-        model: Optional[str] = Field(
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
+        model: str | None = Field(
             default=None,
             description="Model descriptor for training if not in provider config`",
         ),
-        checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[AlgorithmConfig] = None,
+        checkpoint_dir: str | None = None,
+        algorithm_config: AlgorithmConfig | None = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize", method="POST")
@@ -192,8 +191,8 @@ class PostTraining(Protocol):
         finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/jobs", method="GET")

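TrainingConfig above nests optional sub-configs (data_config: DataConfig | None = None). A sketch of what that buys: a missing sub-config validates to None rather than an empty model, so a provider can distinguish "not supplied" from "supplied with defaults". Stand-in names only:

    from pydantic import BaseModel

    class DataConfigSketch(BaseModel):
        dataset_id: str
        batch_size: int

    class TrainingConfigSketch(BaseModel):
        n_epochs: int
        data_config: DataConfigSketch | None = None

    cfg = TrainingConfigSketch(n_epochs=1)
    assert cfg.data_config is None
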
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -17,12 +17,12 @@ class ProviderInfo(BaseModel):
     api: str
     provider_id: str
     provider_type: str
-    config: Dict[str, Any]
+    config: dict[str, Any]
     health: HealthResponse


 class ListProvidersResponse(BaseModel):
-    data: List[ProviderInfo]
+    data: list[ProviderInfo]


 @runtime_checkable

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from pydantic import BaseModel, Field

@@ -27,16 +27,16 @@ class SafetyViolation(BaseModel):
     violation_level: ViolationLevel

     # what message should you convey to the user
-    user_message: Optional[str] = None
+    user_message: str | None = None

     # additional metadata (including specific violation codes) more for
     # debugging, telemetry
-    metadata: Dict[str, Any] = Field(default_factory=dict)
+    metadata: dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
 class RunShieldResponse(BaseModel):
-    violation: Optional[SafetyViolation] = None
+    violation: SafetyViolation | None = None


 class ShieldStore(Protocol):
@@ -52,6 +52,6 @@ class Safety(Protocol):
     async def run_shield(
         self,
         shield_id: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
+        messages: list[Message],
+        params: dict[str, Any] = None,
     ) -> RunShieldResponse: ...

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.schema_utils import json_schema_type, webmethod

 # mapping of metric to value
-ScoringResultRow = Dict[str, Any]
+ScoringResultRow = dict[str, Any]


 @json_schema_type
@@ -24,15 +24,15 @@ class ScoringResult(BaseModel):
     :param aggregated_results: Map of metric name to aggregated value
     """

-    score_rows: List[ScoringResultRow]
+    score_rows: list[ScoringResultRow]
     # aggregated metrics to value
-    aggregated_results: Dict[str, Any]
+    aggregated_results: dict[str, Any]


 @json_schema_type
 class ScoreBatchResponse(BaseModel):
-    dataset_id: Optional[str] = None
-    results: Dict[str, ScoringResult]
+    dataset_id: str | None = None
+    results: dict[str, ScoringResult]


 @json_schema_type
@@ -44,7 +44,7 @@ class ScoreResponse(BaseModel):
     """

     # each key in the dict is a scoring function name
-    results: Dict[str, ScoringResult]
+    results: dict[str, ScoringResult]


 class ScoringFunctionStore(Protocol):
@@ -59,15 +59,15 @@ class Scoring(Protocol):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Dict[str, Optional[ScoringFnParams]],
+        scoring_functions: dict[str, ScoringFnParams | None],
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...

     @webmethod(route="/scoring/score", method="POST")
     async def score(
         self,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: Dict[str, Optional[ScoringFnParams]],
+        input_rows: list[dict[str, Any]],
+        scoring_functions: dict[str, ScoringFnParams | None],
     ) -> ScoreResponse:
         """Score a list of rows.

@ -6,18 +6,14 @@

 from enum import Enum
 from typing import (
+    Annotated,
     Any,
-    Dict,
-    List,
     Literal,
-    Optional,
     Protocol,
-    Union,
     runtime_checkable,
 )

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.resource import Resource, ResourceType
@ -46,12 +42,12 @@ class AggregationFunctionType(Enum):
 class LLMAsJudgeScoringFnParams(BaseModel):
     type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
     judge_model: str
-    prompt_template: Optional[str] = None
-    judge_score_regexes: Optional[List[str]] = Field(
+    prompt_template: str | None = None
+    judge_score_regexes: list[str] | None = Field(
         description="Regexes to extract the answer from generated response",
         default_factory=list,
     )
-    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+    aggregation_functions: list[AggregationFunctionType] | None = Field(
         description="Aggregation functions to apply to the scores of each row",
         default_factory=list,
     )
@ -60,11 +56,11 @@ class LLMAsJudgeScoringFnParams(BaseModel):
 @json_schema_type
 class RegexParserScoringFnParams(BaseModel):
     type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
-    parsing_regexes: Optional[List[str]] = Field(
+    parsing_regexes: list[str] | None = Field(
         description="Regex to extract the answer from generated response",
         default_factory=list,
     )
-    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+    aggregation_functions: list[AggregationFunctionType] | None = Field(
         description="Aggregation functions to apply to the scores of each row",
         default_factory=list,
     )
@ -73,33 +69,29 @@ class RegexParserScoringFnParams(BaseModel):
 @json_schema_type
 class BasicScoringFnParams(BaseModel):
     type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
-    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+    aggregation_functions: list[AggregationFunctionType] | None = Field(
         description="Aggregation functions to apply to the scores of each row",
         default_factory=list,
     )


 ScoringFnParams = Annotated[
-    Union[
-        LLMAsJudgeScoringFnParams,
-        RegexParserScoringFnParams,
-        BasicScoringFnParams,
-    ],
+    LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
     Field(discriminator="type"),
 ]
 register_schema(ScoringFnParams, name="ScoringFnParams")


 class CommonScoringFnFields(BaseModel):
-    description: Optional[str] = None
-    metadata: Dict[str, Any] = Field(
+    description: str | None = None
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this definition",
     )
     return_type: ParamType = Field(
         description="The return type of the deterministic function",
     )
-    params: Optional[ScoringFnParams] = Field(
+    params: ScoringFnParams | None = Field(
         description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
         default=None,
     )
@ -120,12 +112,12 @@ class ScoringFn(CommonScoringFnFields, Resource):

 class ScoringFnInput(CommonScoringFnFields, BaseModel):
     scoring_fn_id: str
-    provider_id: Optional[str] = None
-    provider_scoring_fn_id: Optional[str] = None
+    provider_id: str | None = None
+    provider_scoring_fn_id: str | None = None


 class ListScoringFunctionsResponse(BaseModel):
-    data: List[ScoringFn]
+    data: list[ScoringFn]


 @runtime_checkable
@ -142,7 +134,7 @@ class ScoringFunctions(Protocol):
         scoring_fn_id: str,
         description: str,
         return_type: ParamType,
-        provider_scoring_fn_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        params: Optional[ScoringFnParams] = None,
+        provider_scoring_fn_id: str | None = None,
+        provider_id: str | None = None,
+        params: ScoringFnParams | None = None,
     ) -> None: ...

@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel

@ -14,7 +14,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod


 class CommonShieldFields(BaseModel):
-    params: Optional[Dict[str, Any]] = None
+    params: dict[str, Any] | None = None


 @json_schema_type
@ -34,12 +34,12 @@ class Shield(CommonShieldFields, Resource):

 class ShieldInput(CommonShieldFields):
     shield_id: str
-    provider_id: Optional[str] = None
-    provider_shield_id: Optional[str] = None
+    provider_id: str | None = None
+    provider_shield_id: str | None = None


 class ListShieldsResponse(BaseModel):
-    data: List[Shield]
+    data: list[Shield]


 @runtime_checkable
@ -55,7 +55,7 @@ class Shields(Protocol):
     async def register_shield(
         self,
         shield_id: str,
-        provider_shield_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        params: Optional[Dict[str, Any]] = None,
+        provider_shield_id: str | None = None,
+        provider_id: str | None = None,
+        params: dict[str, Any] | None = None,
     ) -> Shield: ...

@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, Protocol

 from pydantic import BaseModel

@ -28,24 +28,24 @@ class FilteringFunction(Enum):
 class SyntheticDataGenerationRequest(BaseModel):
     """Request to generate synthetic data. A small batch of prompts and a filtering function"""

-    dialogs: List[Message]
+    dialogs: list[Message]
     filtering_function: FilteringFunction = FilteringFunction.none
-    model: Optional[str] = None
+    model: str | None = None


 @json_schema_type
 class SyntheticDataGenerationResponse(BaseModel):
     """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""

-    synthetic_data: List[Dict[str, Any]]
-    statistics: Optional[Dict[str, Any]] = None
+    synthetic_data: list[dict[str, Any]]
+    statistics: dict[str, Any] | None = None


 class SyntheticDataGeneration(Protocol):
     @webmethod(route="/synthetic-data-generation/generate")
     def synthetic_data_generate(
         self,
-        dialogs: List[Message],
+        dialogs: list[Message],
         filtering_function: FilteringFunction = FilteringFunction.none,
-        model: Optional[str] = None,
-    ) -> Union[SyntheticDataGenerationResponse]: ...
+        model: str | None = None,
+    ) -> SyntheticDataGenerationResponse: ...

@ -7,18 +7,14 @@
 from datetime import datetime
 from enum import Enum
 from typing import (
+    Annotated,
     Any,
-    Dict,
-    List,
     Literal,
-    Optional,
     Protocol,
-    Union,
     runtime_checkable,
 )

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.models.llama.datatypes import Primitive
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -37,11 +33,11 @@ class SpanStatus(Enum):
 class Span(BaseModel):
     span_id: str
     trace_id: str
-    parent_span_id: Optional[str] = None
+    parent_span_id: str | None = None
     name: str
     start_time: datetime
-    end_time: Optional[datetime] = None
-    attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    end_time: datetime | None = None
+    attributes: dict[str, Any] | None = Field(default_factory=dict)

     def set_attribute(self, key: str, value: Any):
         if self.attributes is None:
@ -54,7 +50,7 @@ class Trace(BaseModel):
     trace_id: str
     root_span_id: str
     start_time: datetime
-    end_time: Optional[datetime] = None
+    end_time: datetime | None = None


 @json_schema_type
@ -78,7 +74,7 @@ class EventCommon(BaseModel):
     trace_id: str
     span_id: str
     timestamp: datetime
-    attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict)
+    attributes: dict[str, Primitive] | None = Field(default_factory=dict)


 @json_schema_type
@ -92,15 +88,15 @@ class UnstructuredLogEvent(EventCommon):
 class MetricEvent(EventCommon):
     type: Literal[EventType.METRIC.value] = EventType.METRIC.value
     metric: str  # this would be an enum
-    value: Union[int, float]
+    value: int | float
     unit: str


 @json_schema_type
 class MetricInResponse(BaseModel):
     metric: str
-    value: Union[int, float]
-    unit: Optional[str] = None
+    value: int | float
+    unit: str | None = None


 # This is a short term solution to allow inference API to return metrics
@ -124,7 +120,7 @@ class MetricInResponse(BaseModel):


 class MetricResponseMixin(BaseModel):
-    metrics: Optional[List[MetricInResponse]] = None
+    metrics: list[MetricInResponse] | None = None


 @json_schema_type
@ -137,7 +133,7 @@ class StructuredLogType(Enum):
 class SpanStartPayload(BaseModel):
     type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
     name: str
-    parent_span_id: Optional[str] = None
+    parent_span_id: str | None = None


 @json_schema_type
@ -147,10 +143,7 @@ class SpanEndPayload(BaseModel):


 StructuredLogPayload = Annotated[
-    Union[
-        SpanStartPayload,
-        SpanEndPayload,
-    ],
+    SpanStartPayload | SpanEndPayload,
     Field(discriminator="type"),
 ]
 register_schema(StructuredLogPayload, name="StructuredLogPayload")
@ -163,11 +156,7 @@ class StructuredLogEvent(EventCommon):


 Event = Annotated[
-    Union[
-        UnstructuredLogEvent,
-        MetricEvent,
-        StructuredLogEvent,
-    ],
+    UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
     Field(discriminator="type"),
 ]
 register_schema(Event, name="Event")
@ -184,7 +173,7 @@ class EvalTrace(BaseModel):

 @json_schema_type
 class SpanWithStatus(Span):
-    status: Optional[SpanStatus] = None
+    status: SpanStatus | None = None


 @json_schema_type
@ -203,15 +192,15 @@ class QueryCondition(BaseModel):


 class QueryTracesResponse(BaseModel):
-    data: List[Trace]
+    data: list[Trace]


 class QuerySpansResponse(BaseModel):
-    data: List[Span]
+    data: list[Span]


 class QuerySpanTreeResponse(BaseModel):
-    data: Dict[str, SpanWithStatus]
+    data: dict[str, SpanWithStatus]


 @runtime_checkable
@ -222,10 +211,10 @@ class Telemetry(Protocol):
     @webmethod(route="/telemetry/traces", method="POST")
     async def query_traces(
         self,
-        attribute_filters: Optional[List[QueryCondition]] = None,
-        limit: Optional[int] = 100,
-        offset: Optional[int] = 0,
-        order_by: Optional[List[str]] = None,
+        attribute_filters: list[QueryCondition] | None = None,
+        limit: int | None = 100,
+        offset: int | None = 0,
+        order_by: list[str] | None = None,
     ) -> QueryTracesResponse: ...

     @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
@ -238,23 +227,23 @@ class Telemetry(Protocol):
     async def get_span_tree(
         self,
         span_id: str,
-        attributes_to_return: Optional[List[str]] = None,
-        max_depth: Optional[int] = None,
+        attributes_to_return: list[str] | None = None,
+        max_depth: int | None = None,
     ) -> QuerySpanTreeResponse: ...

     @webmethod(route="/telemetry/spans", method="POST")
     async def query_spans(
         self,
-        attribute_filters: List[QueryCondition],
-        attributes_to_return: List[str],
-        max_depth: Optional[int] = None,
+        attribute_filters: list[QueryCondition],
+        attributes_to_return: list[str],
+        max_depth: int | None = None,
     ) -> QuerySpansResponse: ...

     @webmethod(route="/telemetry/spans/export", method="POST")
     async def save_spans_to_dataset(
         self,
-        attribute_filters: List[QueryCondition],
-        attributes_to_save: List[str],
+        attribute_filters: list[QueryCondition],
+        attributes_to_save: list[str],
         dataset_id: str,
-        max_depth: Optional[int] = None,
+        max_depth: int | None = None,
     ) -> None: ...

@ -5,10 +5,10 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Annotated, Any, Literal

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated, Protocol, runtime_checkable
+from typing_extensions import Protocol, runtime_checkable

 from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -29,13 +29,13 @@ class RAGDocument(BaseModel):
     document_id: str
     content: InterleavedContent | URL
     mime_type: str | None = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
+    metadata: dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
 class RAGQueryResult(BaseModel):
-    content: Optional[InterleavedContent] = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
+    content: InterleavedContent | None = None
+    metadata: dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
@ -59,10 +59,7 @@ class LLMRAGQueryGeneratorConfig(BaseModel):


 RAGQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultRAGQueryGeneratorConfig,
-        LLMRAGQueryGeneratorConfig,
-    ],
+    DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
     Field(discriminator="type"),
 ]
 register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@ -83,7 +80,7 @@ class RAGToolRuntime(Protocol):
     @webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
     async def insert(
         self,
-        documents: List[RAGDocument],
+        documents: list[RAGDocument],
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
@ -94,8 +91,8 @@ class RAGToolRuntime(Protocol):
     async def query(
         self,
         content: InterleavedContent,
-        vector_db_ids: List[str],
-        query_config: Optional[RAGQueryConfig] = None,
+        vector_db_ids: list[str],
+        query_config: RAGQueryConfig | None = None,
     ) -> RAGQueryResult:
         """Query the RAG system for context; typically invoked by the agent"""
         ...

@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Literal

 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable
@ -24,7 +24,7 @@ class ToolParameter(BaseModel):
     parameter_type: str
     description: str
     required: bool = Field(default=True)
-    default: Optional[Any] = None
+    default: Any | None = None


 @json_schema_type
@ -40,39 +40,39 @@ class Tool(Resource):
     toolgroup_id: str
     tool_host: ToolHost
     description: str
-    parameters: List[ToolParameter]
-    metadata: Optional[Dict[str, Any]] = None
+    parameters: list[ToolParameter]
+    metadata: dict[str, Any] | None = None


 @json_schema_type
 class ToolDef(BaseModel):
     name: str
-    description: Optional[str] = None
-    parameters: Optional[List[ToolParameter]] = None
-    metadata: Optional[Dict[str, Any]] = None
+    description: str | None = None
+    parameters: list[ToolParameter] | None = None
+    metadata: dict[str, Any] | None = None


 @json_schema_type
 class ToolGroupInput(BaseModel):
     toolgroup_id: str
     provider_id: str
-    args: Optional[Dict[str, Any]] = None
-    mcp_endpoint: Optional[URL] = None
+    args: dict[str, Any] | None = None
+    mcp_endpoint: URL | None = None


 @json_schema_type
 class ToolGroup(Resource):
     type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
-    mcp_endpoint: Optional[URL] = None
-    args: Optional[Dict[str, Any]] = None
+    mcp_endpoint: URL | None = None
+    args: dict[str, Any] | None = None


 @json_schema_type
 class ToolInvocationResult(BaseModel):
-    content: Optional[InterleavedContent] = None
-    error_message: Optional[str] = None
-    error_code: Optional[int] = None
-    metadata: Optional[Dict[str, Any]] = None
+    content: InterleavedContent | None = None
+    error_message: str | None = None
+    error_code: int | None = None
+    metadata: dict[str, Any] | None = None


 class ToolStore(Protocol):
@ -81,11 +81,11 @@ class ToolStore(Protocol):


 class ListToolGroupsResponse(BaseModel):
-    data: List[ToolGroup]
+    data: list[ToolGroup]


 class ListToolsResponse(BaseModel):
-    data: List[Tool]
+    data: list[Tool]


 class ListToolDefsResponse(BaseModel):
@ -100,8 +100,8 @@ class ToolGroups(Protocol):
         self,
         toolgroup_id: str,
         provider_id: str,
-        mcp_endpoint: Optional[URL] = None,
-        args: Optional[Dict[str, Any]] = None,
+        mcp_endpoint: URL | None = None,
+        args: dict[str, Any] | None = None,
     ) -> None:
         """Register a tool group"""
         ...
@ -118,7 +118,7 @@ class ToolGroups(Protocol):
         ...

     @webmethod(route="/tools", method="GET")
-    async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
+    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
         """List tools with optional tool group"""
         ...

@ -151,10 +151,10 @@ class ToolRuntime(Protocol):
     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
     async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse: ...

     @webmethod(route="/tool-runtime/invoke", method="POST")
-    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         """Run a tool with the given arguments"""
         ...

|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Literal, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -33,11 +33,11 @@ class VectorDBInput(BaseModel):
|
||||||
vector_db_id: str
|
vector_db_id: str
|
||||||
embedding_model: str
|
embedding_model: str
|
||||||
embedding_dimension: int
|
embedding_dimension: int
|
||||||
provider_vector_db_id: Optional[str] = None
|
provider_vector_db_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListVectorDBsResponse(BaseModel):
|
class ListVectorDBsResponse(BaseModel):
|
||||||
data: List[VectorDB]
|
data: list[VectorDB]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -57,9 +57,9 @@ class VectorDBs(Protocol):
|
||||||
self,
|
self,
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
embedding_model: str,
|
embedding_model: str,
|
||||||
embedding_dimension: Optional[int] = 384,
|
embedding_dimension: int | None = 384,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: str | None = None,
|
||||||
provider_vector_db_id: Optional[str] = None,
|
provider_vector_db_id: str | None = None,
|
||||||
) -> VectorDB: ...
|
) -> VectorDB: ...
|
||||||
|
|
||||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
||||||
|
|
|
@ -8,7 +8,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from pydantic import BaseModel, Field

@ -20,17 +20,17 @@ from llama_stack.schema_utils import json_schema_type, webmethod

 class Chunk(BaseModel):
     content: InterleavedContent
-    metadata: Dict[str, Any] = Field(default_factory=dict)
+    metadata: dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
 class QueryChunksResponse(BaseModel):
-    chunks: List[Chunk]
-    scores: List[float]
+    chunks: list[Chunk]
+    scores: list[float]


 class VectorDBStore(Protocol):
-    def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
+    def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...


 @runtime_checkable
@ -44,8 +44,8 @@ class VectorIO(Protocol):
     async def insert_chunks(
         self,
         vector_db_id: str,
-        chunks: List[Chunk],
-        ttl_seconds: Optional[int] = None,
+        chunks: list[Chunk],
+        ttl_seconds: int | None = None,
     ) -> None: ...

     @webmethod(route="/vector-io/query", method="POST")
@ -53,5 +53,5 @@ class VectorIO(Protocol):
         self,
         vector_db_id: str,
         query: InterleavedContent,
-        params: Optional[Dict[str, Any]] = None,
+        params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse: ...

@ -13,7 +13,6 @@ from dataclasses import dataclass
 from datetime import datetime, timezone
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Optional

 import httpx
 from pydantic import BaseModel, ConfigDict
@ -102,7 +101,7 @@ class DownloadTask:
     output_file: str
     total_size: int = 0
    downloaded_size: int = 0
-    task_id: Optional[int] = None
+    task_id: int | None = None
     retries: int = 0
     max_retries: int = 3

@ -262,7 +261,7 @@ class ParallelDownloader:
             self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
             raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e

-    def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
+    def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
         try:
             total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
             dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
@ -282,7 +281,7 @@ class ParallelDownloader:
         except Exception as e:
             raise DownloadError(f"Failed to check disk space: {str(e)}") from e

-    async def download_all(self, tasks: List[DownloadTask]) -> None:
+    async def download_all(self, tasks: list[DownloadTask]) -> None:
         if not tasks:
             raise ValueError("No download tasks provided")

@ -391,20 +390,20 @@ def _meta_download(

 class ModelEntry(BaseModel):
     model_id: str
-    files: Dict[str, str]
+    files: dict[str, str]

     model_config = ConfigDict(protected_namespaces=())


 class Manifest(BaseModel):
-    models: List[ModelEntry]
+    models: list[ModelEntry]
     expires_on: datetime


 def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
     from llama_stack.distribution.utils.model_utils import model_local_dir

-    with open(manifest_file, "r") as f:
+    with open(manifest_file) as f:
         d = json.load(f)
         manifest = Manifest(**d)

@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel, ConfigDict, Field

@ -22,7 +22,7 @@ class PromptGuardModel(BaseModel):
     max_seq_length: int = 512
     is_instruct_model: bool = False
     quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
-    arch_args: Dict[str, Any] = Field(default_factory=dict)
+    arch_args: dict[str, Any] = Field(default_factory=dict)

     def descriptor(self) -> str:
         return self.model_id
@ -44,11 +44,11 @@ def prompt_guard_model_skus():
     ]


-def prompt_guard_model_sku_map() -> Dict[str, Any]:
+def prompt_guard_model_sku_map() -> dict[str, Any]:
     return {model.model_id: model for model in prompt_guard_model_skus()}


-def prompt_guard_download_info_map() -> Dict[str, LlamaDownloadInfo]:
+def prompt_guard_download_info_map() -> dict[str, LlamaDownloadInfo]:
     return {
         model.model_id: LlamaDownloadInfo(
             folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,

@ -13,13 +13,12 @@ import sys
 import textwrap
 from functools import lru_cache
 from pathlib import Path
-from typing import Dict, Optional

 import yaml
 from prompt_toolkit import prompt
 from prompt_toolkit.completion import WordCompleter
 from prompt_toolkit.validation import Validator
-from termcolor import cprint
+from termcolor import colored, cprint

 from llama_stack.cli.stack.utils import ImageType
 from llama_stack.cli.table import print_table
@ -46,14 +45,14 @@ from llama_stack.providers.datatypes import Api
 TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"


-@lru_cache()
-def available_templates_specs() -> Dict[str, BuildConfig]:
+@lru_cache
+def available_templates_specs() -> dict[str, BuildConfig]:
     import yaml

     template_specs = {}
     for p in TEMPLATES_PATH.rglob("*build.yaml"):
         template_name = p.parent.name
-        with open(p, "r") as f:
+        with open(p) as f:
             build_config = BuildConfig(**yaml.safe_load(f))
             template_specs[template_name] = build_config
     return template_specs
@ -178,7 +177,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
            if not available_providers:
                continue
            api_provider = prompt(
-                "> Enter provider for API {}: ".format(api.value),
+                f"> Enter provider for API {api.value}: ",
                completer=WordCompleter(available_providers),
                complete_while_typing=True,
                validator=Validator.from_callable(
@ -201,7 +200,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:

         build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
     else:
-        with open(args.config, "r") as f:
+        with open(args.config) as f:
             try:
                 build_config = BuildConfig(**yaml.safe_load(f))
             except Exception as e:
@ -332,9 +331,9 @@ def _generate_run_config(

 def _run_stack_build_command_from_build_config(
     build_config: BuildConfig,
-    image_name: Optional[str] = None,
-    template_name: Optional[str] = None,
-    config_path: Optional[str] = None,
+    image_name: str | None = None,
+    template_name: str | None = None,
+    config_path: str | None = None,
 ) -> str:
     image_name = image_name or build_config.image_name
     if build_config.image_type == LlamaStackImageType.CONTAINER.value:
@ -389,6 +388,11 @@ def _run_stack_build_command_from_build_config(
         shutil.copy(path, run_config_file)

         cprint("Build Successful!", color="green")
+        cprint("You can find the newly-built template here: " + colored(template_path, "light_blue"))
+        cprint(
+            "You can run the new Llama Stack distro via: "
+            + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue")
+        )
         return template_path
     else:
         return _generate_run_config(build_config, build_dir, image_name)

@ -119,7 +119,7 @@ class StackRun(Subcommand):

         if not config_file.is_file():
             self.parser.error(
-                f"Config file must be a valid file path, '{config_file}’ is not a file: type={type(config_file)}"
+                f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
             )

         logger.info(f"Using run configuration: {config_file}")

@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Iterable
+from collections.abc import Iterable

 from rich.console import Console
 from rich.table import Table

@ -9,7 +9,6 @@ import hashlib
 from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Optional

 from rich.console import Console
 from rich.progress import Progress, SpinnerColumn, TextColumn
@ -21,7 +20,7 @@ from llama_stack.cli.subcommand import Subcommand
 class VerificationResult:
     filename: str
     expected_hash: str
-    actual_hash: Optional[str]
+    actual_hash: str | None
     exists: bool
     matches: bool

@ -60,9 +59,9 @@ def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
     return md5_hash.hexdigest()


-def load_checksums(checklist_path: Path) -> Dict[str, str]:
+def load_checksums(checklist_path: Path) -> dict[str, str]:
     checksums = {}
-    with open(checklist_path, "r") as f:
+    with open(checklist_path) as f:
         for line in f:
             if line.strip():
                 md5sum, filepath = line.strip().split(" ", 1)
@ -72,7 +71,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]:
     return checksums


-def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]:
+def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -> list[VerificationResult]:
     results = []

     with Progress(

@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.distribution.datatypes import AccessAttributes
 from llama_stack.log import get_logger
@ -14,8 +14,8 @@ logger = get_logger(__name__, category="core")

 def check_access(
     obj_identifier: str,
-    obj_attributes: Optional[AccessAttributes],
-    user_attributes: Optional[Dict[str, Any]] = None,
+    obj_attributes: AccessAttributes | None,
+    user_attributes: dict[str, Any] | None = None,
 ) -> bool:
     """Check if the current user has access to the given object, based on access attributes.

@ -47,14 +47,13 @@ def get_provider_dependencies(
     providers = config.distribution_spec.providers
     deps = []
     registry = get_provider_registry(config)

     for api_str, provider_or_providers in providers.items():
         providers_for_api = registry[Api(api_str)]

         providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]

         for provider in providers:
-            # Providers from BuildConfig and RunConfig are subtly different – not great
+            # Providers from BuildConfig and RunConfig are subtly different - not great
             provider_type = provider if isinstance(provider, str) else provider.provider_type

             if provider_type not in providers_for_api:

@ -8,7 +8,7 @@ import inspect
 import json
 from collections.abc import AsyncIterator
 from enum import Enum
-from typing import Any, Type, Union, get_args, get_origin
+from typing import Any, Union, get_args, get_origin

 import httpx
 from pydantic import BaseModel, parse_obj_as
@ -27,7 +27,7 @@ async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
     return impl


-def create_api_client_class(protocol) -> Type:
+def create_api_client_class(protocol) -> type:
     if protocol in _CLIENT_CLASSES:
         return _CLIENT_CLASSES[protocol]

@ -1,3 +1,5 @@
+#!/usr/bin/env bash
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #

@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import logging
 import textwrap
-from typing import Any, Dict
+from typing import Any

 from llama_stack.distribution.datatypes import (
     LLAMA_STACK_RUN_CONFIG_VERSION,
@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
 logger = logging.getLogger(__name__)


-def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider:
+def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
     provider_spec = registry[provider.provider_type]
     config_type = instantiate_class_type(provider_spec.config_class)
     try:
@ -120,8 +120,8 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec


 def upgrade_from_routing_table(
-    config_dict: Dict[str, Any],
-) -> Dict[str, Any]:
+    config_dict: dict[str, Any],
+) -> dict[str, Any]:
     def get_providers(entries):
         return [
             Provider(
@ -163,7 +163,7 @@ def upgrade_from_routing_table(
     return config_dict


-def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
+def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
     version = config_dict.get("version", None)
     if version == LLAMA_STACK_RUN_CONFIG_VERSION:
         return StackRunConfig(**config_dict)

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Annotated, Any, Dict, List, Optional, Union
+from typing import Annotated, Any
 
 from pydantic import BaseModel, Field
 
@@ -30,7 +30,7 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
 LLAMA_STACK_RUN_CONFIG_VERSION = "2"
 
 
-RoutingKey = Union[str, List[str]]
+RoutingKey = str | list[str]
 
 
 class AccessAttributes(BaseModel):
@@ -47,17 +47,17 @@ class AccessAttributes(BaseModel):
     """
 
     # Standard attribute categories - the minimal set we need now
-    roles: Optional[List[str]] = Field(
+    roles: list[str] | None = Field(
         default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
     )
 
-    teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
+    teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
 
-    projects: Optional[List[str]] = Field(
+    projects: list[str] | None = Field(
         default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
     )
 
-    namespaces: Optional[List[str]] = Field(
+    namespaces: list[str] | None = Field(
         default=None, description="Namespace-based access control for resource isolation"
     )
 
@@ -106,7 +106,7 @@ class ResourceWithACL(Resource):
     # ^ User must have access to the customer-insights project AND have confidential namespace
     """
 
-    access_attributes: Optional[AccessAttributes] = None
+    access_attributes: AccessAttributes | None = None
 
 
 # Use the extended Resource for all routable objects
@@ -142,41 +142,21 @@ class ToolGroupWithACL(ToolGroup, ResourceWithACL):
     pass
 
 
-RoutableObject = Union[
-    Model,
-    Shield,
-    VectorDB,
-    Dataset,
-    ScoringFn,
-    Benchmark,
-    Tool,
-    ToolGroup,
-]
+RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
 
 
 RoutableObjectWithProvider = Annotated[
-    Union[
-        ModelWithACL,
-        ShieldWithACL,
-        VectorDBWithACL,
-        DatasetWithACL,
-        ScoringFnWithACL,
-        BenchmarkWithACL,
-        ToolWithACL,
-        ToolGroupWithACL,
-    ],
+    ModelWithACL
+    | ShieldWithACL
+    | VectorDBWithACL
+    | DatasetWithACL
+    | ScoringFnWithACL
+    | BenchmarkWithACL
+    | ToolWithACL
+    | ToolGroupWithACL,
     Field(discriminator="type"),
 ]
 
-RoutedProtocol = Union[
-    Inference,
-    Safety,
-    VectorIO,
-    DatasetIO,
-    Scoring,
-    Eval,
-    ToolRuntime,
-]
+RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime
 
 
 # Example: /inference, /safety
@@ -184,15 +164,15 @@ class AutoRoutedProviderSpec(ProviderSpec):
     provider_type: str = "router"
     config_class: str = ""
 
-    container_image: Optional[str] = None
+    container_image: str | None = None
     routing_table_api: Api
     module: str
-    provider_data_validator: Optional[str] = Field(
+    provider_data_validator: str | None = Field(
         default=None,
     )
 
     @property
-    def pip_packages(self) -> List[str]:
+    def pip_packages(self) -> list[str]:
         raise AssertionError("Should not be called on AutoRoutedProviderSpec")
 
 
@@ -200,20 +180,20 @@ class AutoRoutedProviderSpec(ProviderSpec):
 class RoutingTableProviderSpec(ProviderSpec):
     provider_type: str = "routing_table"
     config_class: str = ""
-    container_image: Optional[str] = None
+    container_image: str | None = None
 
     router_api: Api
     module: str
-    pip_packages: List[str] = Field(default_factory=list)
+    pip_packages: list[str] = Field(default_factory=list)
 
 
 class DistributionSpec(BaseModel):
-    description: Optional[str] = Field(
+    description: str | None = Field(
         default="",
         description="Description of the distribution",
     )
-    container_image: Optional[str] = None
+    container_image: str | None = None
-    providers: Dict[str, Union[str, List[str]]] = Field(
+    providers: dict[str, str | list[str]] = Field(
         default_factory=dict,
         description="""
 Provider Types for each of the APIs provided by this distribution. If you
@@ -225,12 +205,12 @@ in the runtime configuration to help route to the correct provider.""",
 class Provider(BaseModel):
     provider_id: str
     provider_type: str
-    config: Dict[str, Any]
+    config: dict[str, Any]
 
 
 class LoggingConfig(BaseModel):
-    category_levels: Dict[str, str] = Field(
+    category_levels: dict[str, str] = Field(
-        default_factory=Dict,
+        default_factory=dict,
         description="""
 Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
     )
@@ -248,7 +228,7 @@ class AuthenticationConfig(BaseModel):
         ...,
         description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
     )
-    config: Dict[str, str] = Field(
+    config: dict[str, str] = Field(
         ...,
         description="Provider-specific configuration",
     )
@@ -261,15 +241,15 @@ class ServerConfig(BaseModel):
         ge=1024,
         le=65535,
     )
-    tls_certfile: Optional[str] = Field(
+    tls_certfile: str | None = Field(
         default=None,
         description="Path to TLS certificate file for HTTPS",
     )
-    tls_keyfile: Optional[str] = Field(
+    tls_keyfile: str | None = Field(
         default=None,
         description="Path to TLS key file for HTTPS",
     )
-    auth: Optional[AuthenticationConfig] = Field(
+    auth: AuthenticationConfig | None = Field(
         default=None,
         description="Authentication configuration for the server",
     )
@@ -285,23 +265,23 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p
 this could be just a hash
 """,
     )
-    container_image: Optional[str] = Field(
+    container_image: str | None = Field(
         default=None,
         description="Reference to the container image if this package refers to a container",
     )
-    apis: List[str] = Field(
+    apis: list[str] = Field(
         default_factory=list,
         description="""
 The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
     )
 
-    providers: Dict[str, List[Provider]] = Field(
+    providers: dict[str, list[Provider]] = Field(
         description="""
 One or more providers to use for each API. The same provider_type (e.g., meta-reference)
 can be instantiated multiple times (with different configs) if necessary.
 """,
     )
-    metadata_store: Optional[KVStoreConfig] = Field(
+    metadata_store: KVStoreConfig | None = Field(
         default=None,
         description="""
 Configuration for the persistence store used by the distribution registry. If not specified,
@@ -309,22 +289,22 @@ a default SQLite store will be used.""",
     )
 
     # registry of "resources" in the distribution
-    models: List[ModelInput] = Field(default_factory=list)
+    models: list[ModelInput] = Field(default_factory=list)
-    shields: List[ShieldInput] = Field(default_factory=list)
+    shields: list[ShieldInput] = Field(default_factory=list)
-    vector_dbs: List[VectorDBInput] = Field(default_factory=list)
+    vector_dbs: list[VectorDBInput] = Field(default_factory=list)
-    datasets: List[DatasetInput] = Field(default_factory=list)
+    datasets: list[DatasetInput] = Field(default_factory=list)
-    scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
+    scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
-    benchmarks: List[BenchmarkInput] = Field(default_factory=list)
+    benchmarks: list[BenchmarkInput] = Field(default_factory=list)
-    tool_groups: List[ToolGroupInput] = Field(default_factory=list)
+    tool_groups: list[ToolGroupInput] = Field(default_factory=list)
 
-    logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
+    logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
 
     server: ServerConfig = Field(
         default_factory=ServerConfig,
         description="Configuration for the HTTP(S) server",
     )
 
-    external_providers_dir: Optional[str] = Field(
+    external_providers_dir: str | None = Field(
         default=None,
         description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
     )
@@ -338,11 +318,11 @@ class BuildConfig(BaseModel):
         default="conda",
         description="Type of package to build (conda | container | venv)",
     )
-    image_name: Optional[str] = Field(
+    image_name: str | None = Field(
         default=None,
         description="Name of the distribution to build",
     )
-    external_providers_dir: Optional[str] = Field(
+    external_providers_dir: str | None = Field(
         default=None,
         description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
         "pip_packages MUST contain the provider package name.",
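Nearly every hunk in this commit is the same mechanical migration: typing.Optional, Union, List, Dict, Set, and Tuple are replaced with PEP 604 pipe unions and PEP 585 builtin generics, which are only safe once the minimum supported Python is 3.10, since Pydantic evaluates these annotations at runtime. A minimal before/after sketch of the pattern, using hypothetical field names rather than anything from this diff:

from pydantic import BaseModel, Field

# Before: from typing import Dict, List, Optional, Union
#   tags: Optional[List[str]] = None
#   extra: Dict[str, Union[str, int]] = {}
class ExampleSpec(BaseModel):
    # After: builtin generics (PEP 585) and pipe unions (PEP 604)
    tags: list[str] | None = Field(default=None, description="hypothetical field")
    extra: dict[str, str | int] = Field(default_factory=dict)

Pydantic v2 accepts the pipe syntax inside Annotated as well, which is why RoutableObjectWithProvider above can drop its Union[...] wrapper while keeping Field(discriminator="type").
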
@@ -7,7 +7,7 @@
 import glob
 import importlib
 import os
-from typing import Any, Dict, List
+from typing import Any
 
 import yaml
 from pydantic import BaseModel
@@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import (
 logger = get_logger(name=__name__, category="core")
 
 
-def stack_apis() -> List[Api]:
+def stack_apis() -> list[Api]:
     return list(Api)
 
 
@@ -33,7 +33,7 @@ class AutoRoutedApiInfo(BaseModel):
     router_api: Api
 
 
-def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
+def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
     return [
         AutoRoutedApiInfo(
             routing_table_api=Api.models,
@@ -66,12 +66,12 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
     ]
 
 
-def providable_apis() -> List[Api]:
+def providable_apis() -> list[Api]:
     routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
     return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
 
 
-def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
+def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
     adapter = AdapterSpec(**spec_data["adapter"])
     spec = remote_provider_spec(
         api=api,
@@ -81,7 +81,7 @@ def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderS
     return spec
 
 
-def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
+def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
     spec = InlineProviderSpec(
         api=api,
         provider_type=f"inline::{provider_name}",
@@ -98,7 +98,7 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_nam
 
 def get_provider_registry(
     config=None,
-) -> Dict[Api, Dict[str, ProviderSpec]]:
+) -> dict[Api, dict[str, ProviderSpec]]:
     """Get the provider registry, optionally including external providers.
 
     This function loads both built-in providers and external providers from YAML files.
@@ -133,7 +133,7 @@ def get_provider_registry(
         ValueError: If any provider spec is invalid
     """
 
-    ret: Dict[Api, Dict[str, ProviderSpec]] = {}
+    ret: dict[Api, dict[str, ProviderSpec]] = {}
     for api in providable_apis():
         name = api.name.lower()
         logger.debug(f"Importing module {name}")
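For context, get_provider_registry above builds its nested dict[Api, dict[str, ProviderSpec]] by importing one registry module per providable API. A rough sketch of that loop; the module path and the available_providers() entry point are assumptions for illustration, not shown in this diff:

import importlib
from typing import Any

def build_registry(api_names: list[str]) -> dict[str, dict[str, Any]]:
    registry: dict[str, dict[str, Any]] = {}
    for name in api_names:
        # Assumed layout: one registry module per API exposing available_providers()
        module = importlib.import_module(f"llama_stack.providers.registry.{name}")
        registry[name] = {spec.provider_type: spec for spec in module.available_providers()}
    return registry
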
@@ -12,7 +12,7 @@ import os
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Optional, TypeVar, Union, get_args, get_origin
+from typing import Any, TypeVar, Union, get_args, get_origin
 
 import httpx
 import yaml
@@ -119,8 +119,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self,
         config_path_or_template_name: str,
         skip_logger_removal: bool = False,
-        custom_provider_registry: Optional[ProviderRegistry] = None,
+        custom_provider_registry: ProviderRegistry | None = None,
-        provider_data: Optional[dict[str, Any]] = None,
+        provider_data: dict[str, Any] | None = None,
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
@@ -181,8 +181,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
     def __init__(
         self,
         config_path_or_template_name: str,
-        custom_provider_registry: Optional[ProviderRegistry] = None,
+        custom_provider_registry: ProviderRegistry | None = None,
-        provider_data: Optional[dict[str, Any]] = None,
+        provider_data: dict[str, Any] | None = None,
     ):
         super().__init__()
         # when using the library client, we should not log to console since many
@@ -371,7 +371,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return await response.parse()
 
-    def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
+    def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict:
         if not body:
             return {}
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import asyncio
-from typing import Any, Dict
+from typing import Any
 
 from pydantic import BaseModel
 
@@ -73,14 +73,14 @@ class ProviderImpl(Providers):
 
         raise ValueError(f"Provider {provider_id} not found")
 
-    async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
+    async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
         """Get health status for all providers.
 
         Returns:
             Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
                 Each API maps to a dictionary of provider IDs to their health responses.
         """
-        providers_health: Dict[str, Dict[str, HealthResponse]] = {}
+        providers_health: dict[str, dict[str, HealthResponse]] = {}
         timeout = 1.0
 
         async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
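get_providers_health above probes every provider with a short per-call timeout so one hung provider cannot stall the whole report. A self-contained sketch of that fan-out pattern; the health() coroutine and the string statuses are hypothetical stand-ins, not this class's exact code:

import asyncio
from typing import Any

async def check_all(impls: dict[str, Any], timeout: float = 1.0) -> dict[str, str]:
    async def check_one(provider_id: str, impl: Any) -> tuple[str, str]:
        try:
            # Bound each probe individually; a slow provider only costs `timeout`.
            await asyncio.wait_for(impl.health(), timeout=timeout)
            return provider_id, "OK"
        except Exception:
            # asyncio.TimeoutError is an Exception subclass, so it lands here too.
            return provider_id, "Error"

    results = await asyncio.gather(*(check_one(pid, impl) for pid, impl in impls.items()))
    return dict(results)
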
@@ -7,7 +7,8 @@
 import contextvars
 import json
 import logging
-from typing import Any, ContextManager, Dict, List, Optional
+from contextlib import AbstractContextManager
+from typing import Any
 
 from .utils.dynamic import instantiate_class_type
 
@@ -17,11 +18,11 @@ log = logging.getLogger(__name__)
 PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
 
 
-class RequestProviderDataContext(ContextManager):
+class RequestProviderDataContext(AbstractContextManager):
     """Context manager for request provider data"""
 
     def __init__(
-        self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
+        self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
     ):
         self.provider_data = provider_data or {}
         if auth_attributes:
@@ -63,7 +64,7 @@ class NeedsRequestProviderData:
             return None
 
 
-def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
+def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | None:
     """Parse provider data from request headers"""
     keys = [
         "X-LlamaStack-Provider-Data",
@@ -86,14 +87,14 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
 
 
 def request_provider_data_context(
-    headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
+    headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
-) -> ContextManager:
+) -> AbstractContextManager:
     """Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
     provider_data = parse_request_provider_data(headers)
     return RequestProviderDataContext(provider_data, auth_attributes)
 
 
-def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
+def get_auth_attributes() -> dict[str, list[str]] | None:
     """Helper to retrieve auth attributes from the provider data context"""
     provider_data = PROVIDER_DATA_VAR.get()
     if not provider_data:
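The one substantive change in this file is swapping typing.ContextManager, deprecated since Python 3.9, for contextlib.AbstractContextManager. A minimal sketch of the contextvars pattern the class implements, with hypothetical names:

import contextvars
from contextlib import AbstractContextManager

# A ContextVar holds per-request state; the context manager sets it on entry
# and restores the previous value on exit via the token from set().
_REQUEST_STATE: contextvars.ContextVar[dict | None] = contextvars.ContextVar("request_state", default=None)

class RequestStateContext(AbstractContextManager):
    def __init__(self, state: dict | None = None):
        self.state = state or {}
        self._token = None

    def __enter__(self):
        self._token = _REQUEST_STATE.set(self.state)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _REQUEST_STATE.reset(self._token)
        return False  # do not swallow exceptions

with RequestStateContext({"user": "alice"}):
    assert _REQUEST_STATE.get() == {"user": "alice"}
assert _REQUEST_STATE.get() is None
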
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import importlib
 import inspect
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any
 
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
@@ -58,7 +58,7 @@ class InvalidProviderError(Exception):
     pass
 
 
-def api_protocol_map() -> Dict[Api, Any]:
+def api_protocol_map() -> dict[Api, Any]:
     return {
         Api.providers: ProvidersAPI,
         Api.agents: Agents,
@@ -83,7 +83,7 @@ def api_protocol_map() -> Dict[Api, Any]:
     }
 
 
-def additional_protocols_map() -> Dict[Api, Any]:
+def additional_protocols_map() -> dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
         Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
@@ -104,14 +104,14 @@ class ProviderWithSpec(Provider):
     spec: ProviderSpec
 
 
-ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
+ProviderRegistry = dict[Api, dict[str, ProviderSpec]]
 
 
 async def resolve_impls(
     run_config: StackRunConfig,
     provider_registry: ProviderRegistry,
     dist_registry: DistributionRegistry,
-) -> Dict[Api, Any]:
+) -> dict[Api, Any]:
     """
     Resolves provider implementations by:
     1. Validating and organizing providers.
@@ -136,7 +136,7 @@ async def resolve_impls(
     return await instantiate_providers(sorted_providers, router_apis, dist_registry)
 
 
-def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
+def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
     """Generates specifications for automatically routed APIs."""
     specs = {}
     for info in builtin_automatically_routed_apis():
@@ -178,10 +178,10 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
 
 
 def validate_and_prepare_providers(
-    run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
+    run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
-) -> Dict[str, Dict[str, ProviderWithSpec]]:
+) -> dict[str, dict[str, ProviderWithSpec]]:
     """Validates providers, handles deprecations, and organizes them into a spec dictionary."""
-    providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
+    providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}
 
     for api_str, providers in run_config.providers.items():
         api = Api(api_str)
@@ -222,10 +222,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
 
 
 def sort_providers_by_deps(
-    providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
+    providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
-) -> List[Tuple[str, ProviderWithSpec]]:
+) -> list[tuple[str, ProviderWithSpec]]:
     """Sorts providers based on their dependencies."""
-    sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
+    sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
         {k: list(v.values()) for k, v in providers_with_specs.items()}
     )
 
@@ -236,11 +236,11 @@ def sort_providers_by_deps(
 
 
 async def instantiate_providers(
-    sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
+    sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
-) -> Dict:
+) -> dict:
     """Instantiates providers asynchronously while managing dependencies."""
-    impls: Dict[Api, Any] = {}
+    impls: dict[Api, Any] = {}
-    inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
+    inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
     for api_str, provider in sorted_providers:
         deps = {a: impls[a] for a in provider.spec.api_dependencies}
         for a in provider.spec.optional_api_dependencies:
@@ -263,9 +263,9 @@ async def instantiate_providers(
 
 
 def topological_sort(
-    providers_with_specs: Dict[str, List[ProviderWithSpec]],
+    providers_with_specs: dict[str, list[ProviderWithSpec]],
-) -> List[Tuple[str, ProviderWithSpec]]:
+) -> list[tuple[str, ProviderWithSpec]]:
-    def dfs(kv, visited: Set[str], stack: List[str]):
+    def dfs(kv, visited: set[str], stack: list[str]):
         api_str, providers = kv
         visited.add(api_str)
 
@@ -280,8 +280,8 @@ def topological_sort(
 
         stack.append(api_str)
 
-    visited: Set[str] = set()
+    visited: set[str] = set()
-    stack: List[str] = []
+    stack: list[str] = []
 
     for api_str, providers in providers_with_specs.items():
         if api_str not in visited:
@@ -298,8 +298,8 @@ def topological_sort(
 # returns a class implementing the protocol corresponding to the Api
 async def instantiate_provider(
     provider: ProviderWithSpec,
-    deps: Dict[Api, Any],
+    deps: dict[Api, Any],
-    inner_impls: Dict[str, Any],
+    inner_impls: dict[str, Any],
     dist_registry: DistributionRegistry,
 ):
     protocols = api_protocol_map()
@@ -391,8 +391,8 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
 
 async def resolve_remote_stack_impls(
     config: RemoteProviderConfig,
-    apis: List[str],
+    apis: list[str],
-) -> Dict[Api, Any]:
+) -> dict[Api, Any]:
     protocols = api_protocol_map()
     additional_protocols = additional_protocols_map()
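The resolver's topological_sort, visible in the hunks above, is a plain depth-first post-order over the API dependency graph: a node is emitted only after everything it depends on. A self-contained sketch of the same technique over a hypothetical dependency graph:

def topological_sort(graph: dict[str, list[str]]) -> list[str]:
    """Return nodes ordered so every node appears after the nodes it depends on."""
    visited: set[str] = set()
    order: list[str] = []

    def dfs(node: str) -> None:
        visited.add(node)
        for dep in graph.get(node, []):
            if dep not in visited:
                dfs(dep)
        order.append(node)  # post-order: dependencies are emitted first

    for node in graph:
        if node not in visited:
            dfs(node)
    return order

# Hypothetical graph: inference depends on models, agents on both.
print(topological_sort({"agents": ["inference", "models"], "inference": ["models"], "models": []}))
# -> ['models', 'inference', 'agents']
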
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict
+from typing import Any
 
 from llama_stack.distribution.datatypes import RoutedProtocol
 from llama_stack.distribution.store import DistributionRegistry
@@ -23,7 +23,7 @@ from .routing_tables import (
 
 async def get_routing_table_impl(
     api: Api,
-    impls_by_provider_id: Dict[str, RoutedProtocol],
+    impls_by_provider_id: dict[str, RoutedProtocol],
     _deps,
     dist_registry: DistributionRegistry,
 ) -> Any:
@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl
 
 
-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -6,12 +6,12 @@
 
 import asyncio
 import time
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import Annotated, Any
 
 from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
 from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
 from pydantic import Field, TypeAdapter
-from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import (
     URL,
@@ -100,9 +100,9 @@ class VectorIORouter(VectorIO):
         self,
         vector_db_id: str,
         embedding_model: str,
-        embedding_dimension: Optional[int] = 384,
+        embedding_dimension: int | None = 384,
-        provider_id: Optional[str] = None,
+        provider_id: str | None = None,
-        provider_vector_db_id: Optional[str] = None,
+        provider_vector_db_id: str | None = None,
     ) -> None:
         logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
@@ -116,8 +116,8 @@ class VectorIORouter(VectorIO):
     async def insert_chunks(
         self,
         vector_db_id: str,
-        chunks: List[Chunk],
+        chunks: list[Chunk],
-        ttl_seconds: Optional[int] = None,
+        ttl_seconds: int | None = None,
     ) -> None:
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
@@ -128,7 +128,7 @@ class VectorIORouter(VectorIO):
         self,
         vector_db_id: str,
         query: InterleavedContent,
-        params: Optional[Dict[str, Any]] = None,
+        params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
         return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@@ -140,7 +140,7 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
-        telemetry: Optional[Telemetry] = None,
+        telemetry: Telemetry | None = None,
     ) -> None:
         logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
@@ -160,10 +160,10 @@ class InferenceRouter(Inference):
     async def register_model(
         self,
         model_id: str,
-        provider_model_id: Optional[str] = None,
+        provider_model_id: str | None = None,
-        provider_id: Optional[str] = None,
+        provider_id: str | None = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: dict[str, Any] | None = None,
-        model_type: Optional[ModelType] = None,
+        model_type: ModelType | None = None,
     ) -> None:
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
@@ -176,7 +176,7 @@ class InferenceRouter(Inference):
         completion_tokens: int,
         total_tokens: int,
         model: Model,
-    ) -> List[MetricEvent]:
+    ) -> list[MetricEvent]:
         """Constructs a list of MetricEvent objects containing token usage metrics.
 
         Args:
@@ -221,7 +221,7 @@ class InferenceRouter(Inference):
         completion_tokens: int,
         total_tokens: int,
         model: Model,
-    ) -> List[MetricInResponse]:
+    ) -> list[MetricInResponse]:
         metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
@@ -230,9 +230,9 @@ class InferenceRouter(Inference):
 
     async def _count_tokens(
         self,
-        messages: List[Message] | InterleavedContent,
+        messages: list[Message] | InterleavedContent,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
+        tool_prompt_format: ToolPromptFormat | None = None,
-    ) -> Optional[int]:
+    ) -> int | None:
         if isinstance(messages, list):
             encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
         else:
@@ -242,16 +242,16 @@ class InferenceRouter(Inference):
     async def chat_completion(
         self,
         model_id: str,
-        messages: List[Message],
+        messages: list[Message],
-        sampling_params: Optional[SamplingParams] = None,
+        sampling_params: SamplingParams | None = None,
-        response_format: Optional[ResponseFormat] = None,
+        response_format: ResponseFormat | None = None,
-        tools: Optional[List[ToolDefinition]] = None,
+        tools: list[ToolDefinition] | None = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
+        tool_prompt_format: ToolPromptFormat | None = None,
-        stream: Optional[bool] = False,
+        stream: bool | None = False,
-        logprobs: Optional[LogProbConfig] = None,
+        logprobs: LogProbConfig | None = None,
-        tool_config: Optional[ToolConfig] = None,
+        tool_config: ToolConfig | None = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -351,12 +351,12 @@ class InferenceRouter(Inference):
     async def batch_chat_completion(
         self,
         model_id: str,
-        messages_batch: List[List[Message]],
+        messages_batch: list[list[Message]],
-        tools: Optional[List[ToolDefinition]] = None,
+        tools: list[ToolDefinition] | None = None,
-        tool_config: Optional[ToolConfig] = None,
+        tool_config: ToolConfig | None = None,
-        sampling_params: Optional[SamplingParams] = None,
+        sampling_params: SamplingParams | None = None,
-        response_format: Optional[ResponseFormat] = None,
+        response_format: ResponseFormat | None = None,
-        logprobs: Optional[LogProbConfig] = None,
+        logprobs: LogProbConfig | None = None,
     ) -> BatchChatCompletionResponse:
         logger.debug(
             f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
@@ -376,10 +376,10 @@ class InferenceRouter(Inference):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = None,
+        sampling_params: SamplingParams | None = None,
-        response_format: Optional[ResponseFormat] = None,
+        response_format: ResponseFormat | None = None,
-        stream: Optional[bool] = False,
+        stream: bool | None = False,
-        logprobs: Optional[LogProbConfig] = None,
+        logprobs: LogProbConfig | None = None,
     ) -> AsyncGenerator:
         if sampling_params is None:
             sampling_params = SamplingParams()
@@ -439,10 +439,10 @@ class InferenceRouter(Inference):
    async def batch_completion(
         self,
         model_id: str,
-        content_batch: List[InterleavedContent],
+        content_batch: list[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = None,
+        sampling_params: SamplingParams | None = None,
-        response_format: Optional[ResponseFormat] = None,
+        response_format: ResponseFormat | None = None,
-        logprobs: Optional[LogProbConfig] = None,
+        logprobs: LogProbConfig | None = None,
     ) -> BatchCompletionResponse:
         logger.debug(
             f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
@@ -453,10 +453,10 @@ class InferenceRouter(Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
+        contents: list[str] | list[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        text_truncation: TextTruncation | None = TextTruncation.none,
-        output_dimension: Optional[int] = None,
+        output_dimension: int | None = None,
-        task_type: Optional[EmbeddingTaskType] = None,
+        task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         logger.debug(f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
@@ -475,24 +475,24 @@ class InferenceRouter(Inference):
     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str], List[int], List[List[int]]],
+        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: Optional[int] = None,
+        best_of: int | None = None,
-        echo: Optional[bool] = None,
+        echo: bool | None = None,
-        frequency_penalty: Optional[float] = None,
+        frequency_penalty: float | None = None,
-        logit_bias: Optional[Dict[str, float]] = None,
+        logit_bias: dict[str, float] | None = None,
-        logprobs: Optional[bool] = None,
+        logprobs: bool | None = None,
-        max_tokens: Optional[int] = None,
+        max_tokens: int | None = None,
-        n: Optional[int] = None,
+        n: int | None = None,
-        presence_penalty: Optional[float] = None,
+        presence_penalty: float | None = None,
-        seed: Optional[int] = None,
+        seed: int | None = None,
-        stop: Optional[Union[str, List[str]]] = None,
+        stop: str | list[str] | None = None,
-        stream: Optional[bool] = None,
+        stream: bool | None = None,
-        stream_options: Optional[Dict[str, Any]] = None,
+        stream_options: dict[str, Any] | None = None,
-        temperature: Optional[float] = None,
+        temperature: float | None = None,
-        top_p: Optional[float] = None,
+        top_p: float | None = None,
-        user: Optional[str] = None,
+        user: str | None = None,
-        guided_choice: Optional[List[str]] = None,
+        guided_choice: list[str] | None = None,
-        prompt_logprobs: Optional[int] = None,
+        prompt_logprobs: int | None = None,
     ) -> OpenAICompletion:
         logger.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
@@ -531,29 +531,29 @@ class InferenceRouter(Inference):
     async def openai_chat_completion(
         self,
         model: str,
-        messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)],
+        messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
-        frequency_penalty: Optional[float] = None,
+        frequency_penalty: float | None = None,
-        function_call: Optional[Union[str, Dict[str, Any]]] = None,
+        function_call: str | dict[str, Any] | None = None,
-        functions: Optional[List[Dict[str, Any]]] = None,
+        functions: list[dict[str, Any]] | None = None,
-        logit_bias: Optional[Dict[str, float]] = None,
+        logit_bias: dict[str, float] | None = None,
-        logprobs: Optional[bool] = None,
+        logprobs: bool | None = None,
-        max_completion_tokens: Optional[int] = None,
+        max_completion_tokens: int | None = None,
-        max_tokens: Optional[int] = None,
+        max_tokens: int | None = None,
-        n: Optional[int] = None,
+        n: int | None = None,
-        parallel_tool_calls: Optional[bool] = None,
+        parallel_tool_calls: bool | None = None,
-        presence_penalty: Optional[float] = None,
+        presence_penalty: float | None = None,
-        response_format: Optional[OpenAIResponseFormatParam] = None,
+        response_format: OpenAIResponseFormatParam | None = None,
-        seed: Optional[int] = None,
+        seed: int | None = None,
-        stop: Optional[Union[str, List[str]]] = None,
+        stop: str | list[str] | None = None,
-        stream: Optional[bool] = None,
+        stream: bool | None = None,
-        stream_options: Optional[Dict[str, Any]] = None,
+        stream_options: dict[str, Any] | None = None,
-        temperature: Optional[float] = None,
+        temperature: float | None = None,
-        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tool_choice: str | dict[str, Any] | None = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
+        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: Optional[int] = None,
+        top_logprobs: int | None = None,
-        top_p: Optional[float] = None,
+        top_p: float | None = None,
-        user: Optional[str] = None,
+        user: str | None = None,
-    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         logger.debug(
             f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
         )
@@ -602,7 +602,7 @@ class InferenceRouter(Inference):
         provider = self.routing_table.get_provider_impl(model_obj.identifier)
         return await provider.openai_chat_completion(**params)
 
-    async def health(self) -> Dict[str, HealthResponse]:
+    async def health(self) -> dict[str, HealthResponse]:
         health_statuses = {}
         timeout = 0.5
         for provider_id, impl in self.routing_table.impls_by_provider_id.items():
@@ -645,9 +645,9 @@ class SafetyRouter(Safety):
     async def register_shield(
         self,
         shield_id: str,
-        provider_shield_id: Optional[str] = None,
+        provider_shield_id: str | None = None,
-        provider_id: Optional[str] = None,
+        provider_id: str | None = None,
-        params: Optional[Dict[str, Any]] = None,
+        params: dict[str, Any] | None = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
         return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
@@ -655,8 +655,8 @@ class SafetyRouter(Safety):
     async def run_shield(
         self,
         shield_id: str,
-        messages: List[Message],
+        messages: list[Message],
-        params: Dict[str, Any] = None,
+        params: dict[str, Any] = None,
     ) -> RunShieldResponse:
         logger.debug(f"SafetyRouter.run_shield: {shield_id}")
         return await self.routing_table.get_provider_impl(shield_id).run_shield(
@@ -686,8 +686,8 @@ class DatasetIORouter(DatasetIO):
         self,
         purpose: DatasetPurpose,
         source: DataSource,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: dict[str, Any] | None = None,
-        dataset_id: Optional[str] = None,
+        dataset_id: str | None = None,
     ) -> None:
         logger.debug(
             f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
@@ -702,8 +702,8 @@ class DatasetIORouter(DatasetIO):
     async def iterrows(
         self,
         dataset_id: str,
-        start_index: Optional[int] = None,
+        start_index: int | None = None,
-        limit: Optional[int] = None,
+        limit: int | None = None,
     ) -> PaginatedResponse:
         logger.debug(
             f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
@@ -714,7 +714,7 @@ class DatasetIORouter(DatasetIO):
             limit=limit,
         )
 
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
         logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
         return await self.routing_table.get_provider_impl(dataset_id).append_rows(
             dataset_id=dataset_id,
@@ -741,7 +741,7 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
@@ -762,8 +762,8 @@ class ScoringRouter(Scoring):
 
     async def score(
         self,
-        input_rows: List[Dict[str, Any]],
+        input_rows: list[dict[str, Any]],
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
     ) -> ScoreResponse:
         logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
@@ -808,8 +808,8 @@ class EvalRouter(Eval):
     async def evaluate_rows(
         self,
         benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
+        input_rows: list[dict[str, Any]],
-        scoring_functions: List[str],
+        scoring_functions: list[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
@@ -863,8 +863,8 @@ class ToolRuntimeRouter(ToolRuntime):
     async def query(
         self,
         content: InterleavedContent,
-        vector_db_ids: List[str],
+        vector_db_ids: list[str],
-        query_config: Optional[RAGQueryConfig] = None,
+        query_config: RAGQueryConfig | None = None,
     ) -> RAGQueryResult:
         logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
         return await self.routing_table.get_provider_impl("knowledge_search").query(
@@ -873,7 +873,7 @@ class ToolRuntimeRouter(ToolRuntime):
 
     async def insert(
         self,
-        documents: List[RAGDocument],
+        documents: list[RAGDocument],
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
@@ -904,7 +904,7 @@ class ToolRuntimeRouter(ToolRuntime):
         logger.debug("ToolRuntimeRouter.shutdown")
         pass
 
-    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
         logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
@@ -912,7 +912,7 @@ class ToolRuntimeRouter(ToolRuntime):
         )
 
     async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
         return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
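Every router method above follows the same shape: resolve the provider implementation for a routing key through the routing table, then forward the call unchanged. A stripped-down sketch of that dispatch pattern; the Protocol and the invoke() method are hypothetical, not the actual RoutingTable API:

from typing import Any, Protocol

class RoutingTable(Protocol):
    def get_provider_impl(self, routing_key: str) -> Any: ...

class ExampleRouter:
    """Forwards each call to whichever provider the routing table resolves."""

    def __init__(self, routing_table: RoutingTable) -> None:
        self.routing_table = routing_table

    async def invoke(self, routing_key: str, **kwargs: Any) -> Any:
        provider = self.routing_table.get_provider_impl(routing_key)
        return await provider.invoke(**kwargs)
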
@@ -7,7 +7,7 @@
 import logging
 import time
 import uuid
-from typing import Any, Dict, List, Optional
+from typing import Any

 from pydantic import TypeAdapter

@@ -106,20 +106,20 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
         raise ValueError(f"Unregister not supported for {api}")


-Registry = Dict[str, List[RoutableObjectWithProvider]]
+Registry = dict[str, list[RoutableObjectWithProvider]]


 class CommonRoutingTableImpl(RoutingTable):
     def __init__(
         self,
-        impls_by_provider_id: Dict[str, RoutedProtocol],
+        impls_by_provider_id: dict[str, RoutedProtocol],
         dist_registry: DistributionRegistry,
     ) -> None:
         self.impls_by_provider_id = impls_by_provider_id
         self.dist_registry = dist_registry

     async def initialize(self) -> None:
-        async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
+        async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
             for obj in objs:
                 if cls is None:
                     obj.provider_id = provider_id
@@ -154,7 +154,7 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()

-    def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
+    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         def apiname_object():
             if isinstance(self, ModelsRoutingTable):
                 return ("Inference", "model")
@@ -192,7 +192,7 @@ class CommonRoutingTableImpl(RoutingTable):

         raise ValueError(f"Provider not found for `{routing_key}`")

-    async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+    async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
         # Get from disk registry
         obj = await self.dist_registry.get(type, identifier)
         if not obj:
@@ -236,7 +236,7 @@ class CommonRoutingTableImpl(RoutingTable):
         await self.dist_registry.register(obj)
         return obj

-    async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
+    async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
         objs = await self.dist_registry.get_all()
         filtered_objs = [obj for obj in objs if obj.type == type]

@@ -277,10 +277,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
     async def register_model(
         self,
         model_id: str,
-        provider_model_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        model_type: Optional[ModelType] = None,
+        provider_model_id: str | None = None,
+        provider_id: str | None = None,
+        metadata: dict[str, Any] | None = None,
+        model_type: ModelType | None = None,
     ) -> Model:
         if provider_model_id is None:
             provider_model_id = model_id
@@ -328,9 +328,9 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
     async def register_shield(
         self,
         shield_id: str,
-        provider_shield_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        params: Optional[Dict[str, Any]] = None,
+        provider_shield_id: str | None = None,
+        provider_id: str | None = None,
+        params: dict[str, Any] | None = None,
     ) -> Shield:
         if provider_shield_id is None:
             provider_shield_id = shield_id
@@ -368,9 +368,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         self,
         vector_db_id: str,
         embedding_model: str,
-        embedding_dimension: Optional[int] = 384,
-        provider_id: Optional[str] = None,
-        provider_vector_db_id: Optional[str] = None,
+        embedding_dimension: int | None = 384,
+        provider_id: str | None = None,
+        provider_vector_db_id: str | None = None,
     ) -> VectorDB:
         if provider_vector_db_id is None:
             provider_vector_db_id = vector_db_id
@@ -423,8 +423,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
         self,
         purpose: DatasetPurpose,
         source: DataSource,
-        metadata: Optional[Dict[str, Any]] = None,
-        dataset_id: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        dataset_id: str | None = None,
     ) -> Dataset:
         if isinstance(source, dict):
             if source["type"] == "uri":
@@ -489,9 +489,9 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
         scoring_fn_id: str,
         description: str,
         return_type: ParamType,
-        provider_scoring_fn_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        params: Optional[ScoringFnParams] = None,
+        provider_scoring_fn_id: str | None = None,
+        provider_id: str | None = None,
+        params: ScoringFnParams | None = None,
     ) -> None:
         if provider_scoring_fn_id is None:
             provider_scoring_fn_id = scoring_fn_id
@@ -528,10 +528,10 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
         self,
         benchmark_id: str,
         dataset_id: str,
-        scoring_functions: List[str],
-        metadata: Optional[Dict[str, Any]] = None,
-        provider_benchmark_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
+        scoring_functions: list[str],
+        metadata: dict[str, Any] | None = None,
+        provider_benchmark_id: str | None = None,
+        provider_id: str | None = None,
     ) -> None:
         if metadata is None:
             metadata = {}
@@ -556,7 +556,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):


 class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
-    async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
+    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
         tools = await self.get_all_with_type("tool")
         if toolgroup_id:
             tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
@@ -578,8 +578,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         self,
         toolgroup_id: str,
         provider_id: str,
-        mcp_endpoint: Optional[URL] = None,
-        args: Optional[Dict[str, Any]] = None,
+        mcp_endpoint: URL | None = None,
+        args: dict[str, Any] | None = None,
     ) -> None:
         tools = []
         tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
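Every register_* method in the routing tables above follows the same convention: the provider-scoped identifier defaults to the public one when the caller omits it, and optional metadata defaults to an empty dict. A stripped-down sketch of that fallback, with a toy Resource standing in for the real Model type:

from dataclasses import dataclass, field
from typing import Any


@dataclass
class Resource:
    identifier: str
    provider_resource_id: str
    metadata: dict[str, Any] = field(default_factory=dict)


def register(
    resource_id: str,
    provider_resource_id: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> Resource:
    # Same fallback the routing tables use: provider id defaults to the public id.
    if provider_resource_id is None:
        provider_resource_id = resource_id
    return Resource(resource_id, provider_resource_id, metadata or {})


print(register("llama3.2:3b").provider_resource_id)  # -> llama3.2:3b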
@@ -7,7 +7,6 @@
 import json
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Dict, List, Optional
 from urllib.parse import parse_qs

 import httpx
@@ -22,7 +21,7 @@ logger = get_logger(name=__name__, category="auth")
 class AuthResponse(BaseModel):
     """The format of the authentication response from the auth endpoint."""

-    access_attributes: Optional[AccessAttributes] = Field(
+    access_attributes: AccessAttributes | None = Field(
         default=None,
         description="""
         Structured user attributes for attribute-based access control.
@@ -44,7 +43,7 @@ class AuthResponse(BaseModel):
         """,
     )

-    message: Optional[str] = Field(
+    message: str | None = Field(
         default=None, description="Optional message providing additional context about the authentication result."
     )

@@ -52,9 +51,9 @@ class AuthResponse(BaseModel):
 class AuthRequestContext(BaseModel):
     path: str = Field(description="The path of the request being authenticated")

-    headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
+    headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")

-    params: Dict[str, List[str]] = Field(
+    params: dict[str, list[str]] = Field(
         description="Query parameters from the original request, parsed as dictionary of lists"
     )

@@ -76,14 +75,14 @@ class AuthProviderConfig(BaseModel):
     """Base configuration for authentication providers."""

     provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
-    config: Dict[str, str] = Field(..., description="Provider-specific configuration")
+    config: dict[str, str] = Field(..., description="Provider-specific configuration")


 class AuthProvider(ABC):
     """Abstract base class for authentication providers."""

     @abstractmethod
-    async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
+    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
         """Validate a token and return access attributes."""
         pass

@@ -96,7 +95,7 @@ class AuthProvider(ABC):
 class KubernetesAuthProvider(AuthProvider):
     """Kubernetes authentication provider that validates tokens against the Kubernetes API server."""

-    def __init__(self, config: Dict[str, str]):
+    def __init__(self, config: dict[str, str]):
         self.api_server_url = config["api_server_url"]
         self.ca_cert_path = config.get("ca_cert_path")
         self._client = None
@@ -120,7 +119,7 @@ class KubernetesAuthProvider(AuthProvider):
         self._client = ApiClient(configuration)
         return self._client

-    async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
+    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
         """Validate a Kubernetes token and return access attributes."""
         try:
             client = await self._get_client()
@@ -166,11 +165,11 @@ class KubernetesAuthProvider(AuthProvider):
 class CustomAuthProvider(AuthProvider):
     """Custom authentication provider that uses an external endpoint."""

-    def __init__(self, config: Dict[str, str]):
+    def __init__(self, config: dict[str, str]):
         self.endpoint = config["endpoint"]
         self._client = None

-    async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
+    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
         """Validate a token using the custom authentication endpoint."""
         if not self.endpoint:
             raise ValueError("Authentication endpoint not configured")
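The AuthProvider ABC above pins a single async entry point, validate_token, now spelled with dict | None and AccessAttributes | None. A minimal provider against that interface might look like the following; the fixed token and the slimmed-down AccessAttributes are invented for illustration, while the real providers call the Kubernetes API or an external endpoint:

from abc import ABC, abstractmethod


class AccessAttributes:
    # Stand-in for the real pydantic model; only what the sketch needs.
    def __init__(self, roles: list[str]):
        self.roles = roles


class AuthProvider(ABC):
    @abstractmethod
    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
        """Validate a token and return access attributes."""


class StaticAuthProvider(AuthProvider):
    """Toy provider: accepts one fixed token, denies everything else by returning None."""

    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
        if token == "letmein":  # hypothetical credential, for illustration only
            return AccessAttributes(roles=["admin"])
        return None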
@@ -6,7 +6,6 @@

 import inspect
 import re
-from typing import Dict, List

 from pydantic import BaseModel

@@ -29,7 +28,7 @@ def toolgroup_protocol_map():
     }


-def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
+def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
     apis = {}

     protocols = api_protocol_map()

@@ -15,7 +15,7 @@ import warnings
 from contextlib import asynccontextmanager
 from importlib.metadata import version as parse_version
 from pathlib import Path
-from typing import Any, List, Optional, Union
+from typing import Annotated, Any

 import yaml
 from fastapi import Body, FastAPI, HTTPException, Request
@@ -24,7 +24,6 @@ from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError
-from typing_extensions import Annotated

 from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
@@ -91,7 +90,7 @@ async def global_exception_handler(request: Request, exc: Exception):
     return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})


-def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
+def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
     if isinstance(exc, ValidationError):
         exc = RequestValidationError(exc.errors())

@@ -315,7 +314,7 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)


-def main(args: Optional[argparse.Namespace] = None):
+def main(args: argparse.Namespace | None = None):
     """Start the LlamaStack server."""
     parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
     parser.add_argument(
@@ -385,7 +384,7 @@ def main(args: Optional[argparse.Namespace] = None):
         raise ValueError("Either --yaml-config or --template must be provided")

     logger_config = None
-    with open(config_file, "r") as fp:
+    with open(config_file) as fp:
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
@@ -517,7 +516,7 @@ def main(args: Optional[argparse.Namespace] = None):
     uvicorn.run(**uvicorn_config)


-def extract_path_params(route: str) -> List[str]:
+def extract_path_params(route: str) -> list[str]:
     segments = route.split("/")
     params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
     # to handle path params like {param:path}
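The with open(config_file) change in the hunk above is safe because "r" (text-mode read) is already open()'s default mode, so dropping it changes nothing at runtime. A quick self-contained check, using a scratch file created just for the sketch:

from pathlib import Path

Path("run.yaml").write_text("version: 2\n")  # scratch file so the sketch is self-contained

with open("run.yaml", "r") as fp:
    a = fp.read()
with open("run.yaml") as fp:
    b = fp.read()
assert a == b  # identical: the explicit "r" was redundant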
@@ -8,7 +8,7 @@ import importlib.resources
 import os
 import re
 import tempfile
-from typing import Any, Dict, Optional
+from typing import Any

 import yaml

@@ -90,7 +90,7 @@ RESOURCES = [
 ]


-async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
+async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
     for rsrc, api, register_method, list_method in RESOURCES:
         objects = getattr(run_config, rsrc)
         if api not in impls:
@@ -197,7 +197,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
     ) from e


-def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
+def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
     """Add internal implementations (inspect and providers) to the implementations dictionary.

     Args:
@@ -220,8 +220,8 @@ def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConf
 # Produces a stack of providers for the given run config. Not all APIs may be
 # asked for in the run config.
 async def construct_stack(
-    run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
-) -> Dict[Api, Any]:
+    run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
+) -> dict[Api, Any]:
     dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
     impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)

@@ -244,7 +244,7 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:


 def run_config_from_adhoc_config_spec(
-    adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
+    adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
 ) -> StackRunConfig:
     """
     Create an adhoc distribution from a list of API providers.

@@ -6,7 +6,7 @@

 import asyncio
 from contextlib import asynccontextmanager
-from typing import Dict, List, Optional, Protocol, Tuple
+from typing import Protocol

 import pydantic

@@ -20,13 +20,13 @@ logger = get_logger(__name__, category="core")


 class DistributionRegistry(Protocol):
-    async def get_all(self) -> List[RoutableObjectWithProvider]: ...
+    async def get_all(self) -> list[RoutableObjectWithProvider]: ...

     async def initialize(self) -> None: ...

-    async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
+    async def get(self, identifier: str) -> RoutableObjectWithProvider | None: ...

-    def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
+    def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ...

     async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...

@@ -40,13 +40,13 @@ KEY_VERSION = "v8"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


-def _get_registry_key_range() -> Tuple[str, str]:
+def _get_registry_key_range() -> tuple[str, str]:
     """Returns the start and end keys for the registry range query."""
     start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
     return start_key, f"{start_key}\xff"


-def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
+def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider]:
     """Utility function to parse registry values into RoutableObjectWithProvider objects."""
     all_objects = []
     for value in values:
@@ -67,16 +67,16 @@ class DiskDistributionRegistry(DistributionRegistry):
     async def initialize(self) -> None:
         pass

-    def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+    def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
         # Disk registry does not have a cache
         raise NotImplementedError("Disk registry does not have a cache")

-    async def get_all(self) -> List[RoutableObjectWithProvider]:
+    async def get_all(self) -> list[RoutableObjectWithProvider]:
         start_key, end_key = _get_registry_key_range()
         values = await self.kvstore.range(start_key, end_key)
         return _parse_registry_values(values)

-    async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+    async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
         json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
         if not json_str:
             return None
@@ -113,7 +113,7 @@ class DiskDistributionRegistry(DistributionRegistry):
 class CachedDiskDistributionRegistry(DiskDistributionRegistry):
     def __init__(self, kvstore: KVStore):
         super().__init__(kvstore)
-        self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {}
+        self.cache: dict[tuple[str, str], RoutableObjectWithProvider] = {}
         self._initialized = False
         self._initialize_lock = asyncio.Lock()
         self._cache_lock = asyncio.Lock()
@@ -147,15 +147,15 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
     async def initialize(self) -> None:
         await self._ensure_initialized()

-    def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+    def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
         return self.cache.get((type, identifier), None)

-    async def get_all(self) -> List[RoutableObjectWithProvider]:
+    async def get_all(self) -> list[RoutableObjectWithProvider]:
         await self._ensure_initialized()
         async with self._locked_cache() as cache:
             return list(cache.values())

-    async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+    async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
         await self._ensure_initialized()
         cache_key = (type, identifier)

@@ -189,7 +189,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):


 async def create_dist_registry(
-    metadata_store: Optional[KVStoreConfig],
+    metadata_store: KVStoreConfig | None,
     image_name: str,
 ) -> tuple[CachedDiskDistributionRegistry, KVStore]:
     # instantiate kvstore for storing and retrieving distribution metadata
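CachedDiskDistributionRegistry above layers a process-local dict[tuple[str, str], ...] cache, keyed by (type, identifier), over the disk-backed store. A stripped-down sketch of that read-through shape; the _disk dict stands in for the real kvstore and the locking is omitted:

import asyncio


class CachedRegistry:
    def __init__(self) -> None:
        self._disk: dict[tuple[str, str], str] = {}  # stand-in for the kvstore
        self.cache: dict[tuple[str, str], str] = {}

    def get_cached(self, type: str, identifier: str) -> str | None:
        # Synchronous, cache-only lookup, as in the registry above.
        return self.cache.get((type, identifier), None)

    async def get(self, type: str, identifier: str) -> str | None:
        key = (type, identifier)
        if key in self.cache:
            return self.cache[key]
        obj = self._disk.get(key)  # the real code awaits kvstore.get(...) here
        if obj is not None:
            self.cache[key] = obj
        return obj


async def demo() -> None:
    reg = CachedRegistry()
    reg._disk[("model", "llama3.2:3b")] = "{...}"
    print(await reg.get("model", "llama3.2:3b"))   # read-through populates the cache
    print(reg.get_cached("model", "llama3.2:3b"))  # now served from memory


asyncio.run(demo())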
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import os
-from typing import Optional

 from llama_stack_client import LlamaStackClient

@@ -23,7 +22,7 @@ class LlamaStackApi:
             },
         )

-    def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
+    def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
         """Run scoring on a single row"""
         if not scoring_params:
             scoring_params = {fn_id: None for fn_id in scoring_function_ids}
@@ -1,5 +1,5 @@
-streamlit
-pandas
-llama-stack-client>=0.2.1
-streamlit-option-menu
 llama-stack>=0.2.1
+llama-stack-client>=0.2.1
+pandas
+streamlit
+streamlit-option-menu

@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any


-def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
+def redact_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]:
     """Redact sensitive information from config before printing."""
     sensitive_patterns = ["api_key", "api_token", "password", "secret"]

@@ -18,7 +18,7 @@ def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
             return [_redact_value(i) for i in v]
         return v

-    def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+    def _redact_dict(d: dict[str, Any]) -> dict[str, Any]:
         result = {}
         for k, v in d.items():
             if any(pattern in k.lower() for pattern in sensitive_patterns):
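redact_sensitive_fields pairs a substring match on key names with a recursive walk over nested dicts and lists. A self-contained approximation of that shape; the mask string and the exact helper structure are simplifications of the hunk above:

from typing import Any

SENSITIVE_PATTERNS = ["api_key", "api_token", "password", "secret"]


def redact(data: dict[str, Any]) -> dict[str, Any]:
    """Return a copy with values under sensitive-looking keys masked."""

    def _value(v: Any) -> Any:
        if isinstance(v, dict):
            return _redact_dict(v)
        if isinstance(v, list):
            return [_value(i) for i in v]
        return v

    def _redact_dict(d: dict[str, Any]) -> dict[str, Any]:
        return {
            k: "********" if any(p in k.lower() for p in SENSITIVE_PATTERNS) else _value(v)
            for k, v in d.items()
        }

    return _redact_dict(data)


print(redact({"provider": {"api_key": "sk-123", "url": "http://localhost"}}))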
@@ -4,14 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncGenerator
 from contextvars import ContextVar
-from typing import AsyncGenerator, List, TypeVar
+from typing import TypeVar

 T = TypeVar("T")


 def preserve_contexts_async_generator(
-    gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
+    gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
 ) -> AsyncGenerator[T, None]:
     """
     Wraps an async generator to preserve context variables across iterations.
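The context.py hunk swaps typing.AsyncGenerator for collections.abc.AsyncGenerator. The typing aliases for container ABCs are deprecated since Python 3.9 (PEP 585), and the collections.abc versions are subscriptable directly. A quick sketch:

import asyncio
from collections.abc import AsyncGenerator


async def countdown(n: int) -> AsyncGenerator[int, None]:
    # collections.abc.AsyncGenerator supports [...] at runtime on 3.9+.
    while n > 0:
        yield n
        n -= 1


async def main() -> None:
    print([i async for i in countdown(3)])  # -> [3, 2, 1]


asyncio.run(main())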
@@ -8,12 +8,11 @@ import inspect
 import json
 import logging
 from enum import Enum
-from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin
+from typing import Annotated, Any, Literal, Union, get_args, get_origin

 from pydantic import BaseModel
 from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefinedType
-from typing_extensions import Annotated

 log = logging.getLogger(__name__)

@@ -21,7 +20,7 @@ log = logging.getLogger(__name__)
 def is_list_of_primitives(field_type):
     """Check if a field type is a List of primitive types."""
     origin = get_origin(field_type)
-    if origin is List or origin is list:
+    if origin is list or origin is list:
         args = get_args(field_type)
         if len(args) == 1 and args[0] in (int, float, str, bool):
             return True
@@ -53,7 +52,7 @@ def get_non_none_type(field_type):
     return next(arg for arg in get_args(field_type) if arg is not type(None))


-def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any):
+def manually_validate_field(model: type[BaseModel], field_name: str, value: Any):
     validators = model.__pydantic_decorators__.field_validators
     for _name, validator in validators.items():
         if field_name in validator.info.fields:
@@ -126,7 +125,7 @@ def prompt_for_discriminated_union(
 #
 # doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
 # unit tests for coverage.
-def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel:
+def prompt_for_config(config_type: type[BaseModel], existing_config: BaseModel | None = None) -> BaseModel:
     """
     Recursively prompt the user for configuration values based on a Pydantic BaseModel.
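One wrinkle in the is_list_of_primitives hunk above: the mechanical List-to-list rewrite produces "if origin is list or origin is list:", a harmless but redundant double check, since with builtin generics get_origin(list[int]) is simply list. A sketch of the same check with the duplicate dropped:

from typing import get_args, get_origin


def is_list_of_primitives(field_type) -> bool:
    # With PEP 585 generics, get_origin(list[int]) is the builtin list itself.
    if get_origin(field_type) is list:
        args = get_args(field_type)
        return len(args) == 1 and args[0] in (int, float, str, bool)
    return False


print(is_list_of_primitives(list[int]))       # True
print(is_list_of_primitives(dict[str, int]))  # False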
@@ -7,7 +7,6 @@
 import logging
 import os
 from logging.config import dictConfig
-from typing import Dict, Optional

 from rich.console import Console
 from rich.errors import MarkupError
@@ -33,7 +32,7 @@ CATEGORIES = [
 ]

 # Initialize category levels with default level
-_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
+_category_levels: dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}


 def config_to_category_levels(category: str, level: str):
@@ -49,7 +48,7 @@ def config_to_category_levels(category: str, level: str):
         Dict[str, int]: A dictionary mapping categories to their log levels.
     """

-    category_levels: Dict[str, int] = {}
+    category_levels: dict[str, int] = {}
     level_value = logging._nameToLevel.get(str(level).upper())
     if level_value is None:
         logging.warning(f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'.")
@@ -69,7 +68,7 @@ def config_to_category_levels(category: str, level: str):
     return category_levels


-def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
+def parse_yaml_config(yaml_config: LoggingConfig) -> dict[str, int]:
     """
     Helper function to parse a yaml logging configuration found in the run.yaml

@@ -86,7 +85,7 @@ def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
     return category_levels


-def parse_environment_config(env_config: str) -> Dict[str, int]:
+def parse_environment_config(env_config: str) -> dict[str, int]:
     """
     Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.

@@ -131,7 +130,7 @@ class CustomRichHandler(RichHandler):
         self.markup = original_markup


-def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
+def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None:
     """
     Configure logging based on the provided category log levels and an optional log file.

@@ -211,7 +210,7 @@ def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None


 def get_logger(
-    name: str, category: str = "uncategorized", config: Optional[LoggingConfig] | None = None
+    name: str, category: str = "uncategorized", config: LoggingConfig | None | None = None
 ) -> logging.LoggerAdapter:
     """
     Returns a logger with the specified name and category.
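A similar artifact survives in get_logger: the old annotation Optional[LoggingConfig] | None was already doubly optional, so the rewrite yields LoggingConfig | None | None. Union construction flattens and deduplicates members on Python 3.10+, so the doubled None is equivalent to a single one, which a quick check confirms:

from typing import Optional


class LoggingConfig:
    # Stand-in so the sketch is self-contained.
    pass


assert (LoggingConfig | None | None) == (LoggingConfig | None)
assert Optional[LoggingConfig] == (LoggingConfig | None)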
@@ -7,14 +7,14 @@
 import concurrent.futures
 import re
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any

 import numpy as np
 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size


-def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
+def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> list[int]:
     """Map a new MP rank to a list of old MP ranks given a change in MP size."""
     if new_mp_size % old_mp_size == 0:
         # Read old MP shard and split it into smaller ones
@@ -31,12 +31,12 @@ def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[in


 def maybe_reshard_state_dict(
-    ckpt_paths: List[Path],
+    ckpt_paths: list[Path],
     n_kv_heads: int,
-    moe_num_experts: Optional[int] = None,
-    map_location: Union[str, torch.device] = "cpu",
+    moe_num_experts: int | None = None,
+    map_location: str | torch.device = "cpu",
     mmap: bool = True,
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
     if str(map_location) == "cpu":
         torch.set_default_tensor_type(torch.BFloat16Tensor)
     else:
@@ -97,18 +97,18 @@ _MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}


 def reshard_mp(
-    state_dicts: List[Dict[str, torch.Tensor]],
+    state_dicts: list[dict[str, torch.Tensor]],
     size: int,
     rank: int,
     repeat_qk_qv: int = 1,
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
     """
     Reshard a list of state dicts into a single state dict given a change in MP size.
     If the list has more than one state dict, we concatenate the values of the same
     key across all state dicts. Otherwise, we just slice it for the current MP rank.
     """

-    def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
+    def concat_or_chunk(tensors: list[torch.Tensor], dim: int) -> torch.Tensor:
         if len(tensors) > 1:
             return torch.cat(tensors, dim=dim)
         return tensors[0].chunk(size, dim=dim)[rank].clone()
@@ -144,7 +144,7 @@ def reshard_mp(
     column_regex = re.compile("|".join(column_keys))
     row_regex = re.compile("|".join(row_keys))

-    output: Dict[str, torch.Tensor] = {}
+    output: dict[str, torch.Tensor] = {}
     with concurrent.futures.ThreadPoolExecutor() as executor:
         # Note: only processes keys in the first state dict.
         # Assumes keys are the same across all state dicts.
@@ -154,7 +154,7 @@ def reshard_mp(
     return output


-def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
+def convert_moe_weights(state_dict: dict[str, Any], num_experts: int) -> dict[str, Any]:
     routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
     routed_regex = re.compile("|".join(routed_keys))
     keys = list(state_dict.keys())
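concat_or_chunk above is the heart of the resharding logic: given several shards it merges by concatenation along the parallel dimension, and given a single shard it slices out the current rank's piece. A toy demonstration with size and rank passed explicitly instead of closed over, and shapes shrunk for readability:

import torch


def concat_or_chunk(tensors: list[torch.Tensor], size: int, rank: int, dim: int) -> torch.Tensor:
    if len(tensors) > 1:
        # Several input shards: MP size is shrinking, merge them.
        return torch.cat(tensors, dim=dim)
    # One input shard: MP size is growing, keep only this rank's slice.
    return tensors[0].chunk(size, dim=dim)[rank].clone()


full = torch.arange(8).reshape(2, 4).float()
halves = list(full.chunk(2, dim=1))

merged = concat_or_chunk(halves, size=1, rank=0, dim=1)  # 2 shards -> 1
assert torch.equal(merged, full)

shard1 = concat_or_chunk([full], size=2, rank=1, dim=1)  # 1 shard -> rank 1 of 2
assert torch.equal(shard1, halves[1])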
@@ -7,10 +7,9 @@
 import base64
 from enum import Enum
 from io import BytesIO
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Annotated, Any, Literal

 from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
-from typing_extensions import Annotated

 # The goal is that these set of types are relevant for all Llama models.
 # That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
@@ -31,21 +30,21 @@ class BuiltinTool(Enum):
     code_interpreter = "code_interpreter"


-Primitive = Union[str, int, float, bool, None]
-RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
+Primitive = str | int | float | bool | None
+RecursiveType = Primitive | list[Primitive] | dict[str, Primitive]


 class ToolCall(BaseModel):
     call_id: str
-    tool_name: Union[BuiltinTool, str]
+    tool_name: BuiltinTool | str
     # Plan is to deprecate the Dict in favor of a JSON string
     # that is parsed on the client side instead of trying to manage
     # the recursive type here.
     # Making this a union so that client side can start prepping for this change.
     # Eventually, we will remove both the Dict and arguments_json field,
     # and arguments will just be a str
-    arguments: Union[str, Dict[str, RecursiveType]]
-    arguments_json: Optional[str] = None
+    arguments: str | dict[str, RecursiveType]
+    arguments_json: str | None = None

     @field_validator("tool_name", mode="before")
     @classmethod
@@ -91,15 +90,15 @@ class StopReason(Enum):

 class ToolParamDefinition(BaseModel):
     param_type: str
-    description: Optional[str] = None
-    required: Optional[bool] = True
-    default: Optional[Any] = None
+    description: str | None = None
+    required: bool | None = True
+    default: Any | None = None


 class ToolDefinition(BaseModel):
-    tool_name: Union[BuiltinTool, str]
-    description: Optional[str] = None
-    parameters: Optional[Dict[str, ToolParamDefinition]] = None
+    tool_name: BuiltinTool | str
+    description: str | None = None
+    parameters: dict[str, ToolParamDefinition] | None = None

     @field_validator("tool_name", mode="before")
     @classmethod
@@ -119,7 +118,7 @@ class RawMediaItem(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)

     @field_serializer("data")
-    def serialize_data(self, data: Optional[bytes], _info):
+    def serialize_data(self, data: bytes | None, _info):
         if data is None:
             return None
         return base64.b64encode(data).decode("utf-8")
@@ -137,9 +136,9 @@ class RawTextItem(BaseModel):
     text: str


-RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
+RawContentItem = Annotated[RawTextItem | RawMediaItem, Field(discriminator="type")]

-RawContent = str | RawContentItem | List[RawContentItem]
+RawContent = str | RawContentItem | list[RawContentItem]


 class RawMessage(BaseModel):
@@ -147,17 +146,17 @@ class RawMessage(BaseModel):
     content: RawContent

     # This is for RAG but likely should be absorbed into content
-    context: Optional[RawContent] = None
+    context: RawContent | None = None

     # These are for the output message coming from the assistant
-    stop_reason: Optional[StopReason] = None
-    tool_calls: List[ToolCall] = Field(default_factory=list)
+    stop_reason: StopReason | None = None
+    tool_calls: list[ToolCall] = Field(default_factory=list)


 class GenerationResult(BaseModel):
     token: int
     text: str
-    logprobs: Optional[List[float]] = None
+    logprobs: list[float] | None = None

     source: Literal["input"] | Literal["output"]
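The RawContentItem alias above shows that pydantic discriminated unions carry over to the new syntax unchanged: Annotated[A | B, Field(discriminator="type")] behaves like the old Annotated[Union[A, B], ...]. A self-contained miniature, with two toy models standing in for the fuller RawTextItem/RawMediaItem:

from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class TextItem(BaseModel):
    type: Literal["text"] = "text"
    text: str


class MediaItem(BaseModel):
    type: Literal["media"] = "media"
    url: str


# A PEP 604 union works as a pydantic discriminated union just like typing.Union.
ContentItem = Annotated[TextItem | MediaItem, Field(discriminator="type")]

item = TypeAdapter(ContentItem).validate_python({"type": "text", "text": "hi"})
print(type(item).__name__)  # -> TextItem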
@@ -6,7 +6,6 @@

 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional


 class QuantizationScheme(Enum):
@@ -15,8 +14,8 @@ class QuantizationScheme(Enum):

 @dataclass
 class QuantizationArgs:
-    scheme: Optional[QuantizationScheme] = None
-    group_size: Optional[int] = None
+    scheme: QuantizationScheme | None = None
+    group_size: int | None = None
     spinquant: bool = False

     def __init__(self, **kwargs):
@@ -39,10 +38,10 @@ class ModelArgs:
     dim: int = 4096
     n_layers: int = 32
     n_heads: int = 32
-    n_kv_heads: Optional[int] = None
+    n_kv_heads: int | None = None
     vocab_size: int = -1
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
-    ffn_dim_multiplier: Optional[float] = None
+    ffn_dim_multiplier: float | None = None
     norm_eps: float = 1e-5
     rope_theta: float = 500000
     use_scaled_rope: bool = False
@@ -55,8 +54,8 @@ class ModelArgs:
     vision_max_num_chunks: int = 4
     vision_num_cross_attention_layers: int = -1

-    quantization_args: Optional[QuantizationArgs] = None
-    lora_args: Optional[LoRAArgs] = None
+    quantization_args: QuantizationArgs | None = None
+    lora_args: LoRAArgs | None = None

     def __init__(self, **kwargs):
         for k, v in kwargs.items():

@ -8,7 +8,6 @@ import io
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
from PIL import Image as PIL_Image
|
from PIL import Image as PIL_Image
|
||||||
|
|
||||||
|
@ -29,14 +28,14 @@ from .tool_utils import ToolUtils
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VisionInput:
|
class VisionInput:
|
||||||
mask: List[List[int]]
|
mask: list[list[int]]
|
||||||
images: List[PIL_Image.Image]
|
images: list[PIL_Image.Image]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMInput:
|
class LLMInput:
|
||||||
tokens: List[int]
|
tokens: list[int]
|
||||||
vision: Optional[VisionInput] = None
|
vision: VisionInput | None = None
|
||||||
|
|
||||||
|
|
||||||
def role_str(role: Role) -> str:
|
def role_str(role: Role) -> str:
|
||||||
|
@ -50,7 +49,7 @@ def role_str(role: Role) -> str:
|
||||||
|
|
||||||
|
|
||||||
class ChatFormat:
|
class ChatFormat:
|
||||||
possible_headers: Dict[Role, str]
|
possible_headers: dict[Role, str]
|
||||||
|
|
||||||
def __init__(self, tokenizer: Tokenizer):
|
def __init__(self, tokenizer: Tokenizer):
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
@ -58,7 +57,7 @@ class ChatFormat:
|
||||||
         self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
         self.vision_token = self.tokenizer.special_tokens["<|image|>"]

-    def _encode_header(self, role: str) -> List[int]:
+    def _encode_header(self, role: str) -> list[int]:
         tokens = []
         tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
         tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))

@@ -70,7 +69,7 @@ class ChatFormat:
         tokens, images = self._encode_content(content, bos=True)
         return self._model_input_from_tokens_images(tokens, images)

-    def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
+    def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[PIL_Image.Image]]:
         tokens = []
         images = []

@@ -107,7 +106,7 @@ class ChatFormat:

     def encode_message(
         self, message: RawMessage, tool_prompt_format: ToolPromptFormat
-    ) -> Tuple[List[int], List[PIL_Image.Image]]:
+    ) -> tuple[list[int], list[PIL_Image.Image]]:
         tokens = self._encode_header(message.role)
         images = []

@@ -145,8 +144,8 @@ class ChatFormat:

     def encode_dialog_prompt(
         self,
-        messages: List[RawMessage],
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
+        messages: list[RawMessage],
+        tool_prompt_format: ToolPromptFormat | None = None,
     ) -> LLMInput:
         tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
         tokens = []

@@ -163,7 +162,7 @@ class ChatFormat:
         return self._model_input_from_tokens_images(tokens, images)

     # TODO(this should be generic, not only for assistant messages)
-    def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
+    def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
         content = self.tokenizer.decode(tokens)

         return self.decode_assistant_message_from_content(content, stop_reason)

@@ -234,7 +233,7 @@ class ChatFormat:
             tool_calls=tool_calls,
         )

-    def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
+    def _model_input_from_tokens_images(self, tokens: list[int], images: list[PIL_Image.Image]) -> LLMInput:
         vision_input = None
         if len(images) > 0:
             vision_input = VisionInput(

@@ -249,9 +248,9 @@ class ChatFormat:


 def create_vision_mask(
-    tokens: List[int],
+    tokens: list[int],
     vision_token: int,
-) -> List[List[int]]:
+) -> list[list[int]]:
     vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
     if len(vision_token_locations) == 0:
         return []
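These hunks mechanically swap typing.List/Tuple/Optional for the builtin generics and | unions of PEP 585/604. A minimal sketch of the convention (illustrative only, not code from this diff):

# list[int] needs Python >= 3.9 and int | None needs Python >= 3.10 when
# evaluated at runtime, as annotations are by default.
def take_tokens(tokens: list[int], limit: int | None = None) -> list[int]:
    """Return at most `limit` tokens; all of them when limit is None."""
    return tokens if limit is None else tokens[:limit]

assert take_tokens([1, 2, 3]) == [1, 2, 3]
assert take_tokens([1, 2, 3], limit=2) == [1, 2]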
@@ -15,8 +15,8 @@ import json
 import os
 import sys
 import time
+from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Callable, Generator, List, Optional

 import torch
 import torch.nn.functional as F

@@ -41,8 +41,8 @@ class Llama3:
         ckpt_dir: str,
         max_seq_len: int,
         max_batch_size: int,
-        world_size: Optional[int] = None,
-        quantization_mode: Optional[QuantizationMode] = None,
+        world_size: int | None = None,
+        quantization_mode: QuantizationMode | None = None,
         seed: int = 1,
         device: str = "cuda",
     ):

@@ -82,7 +82,7 @@ class Llama3:
         ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
         print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
-        with open(Path(ckpt_dir) / "params.json", "r") as f:
+        with open(Path(ckpt_dir) / "params.json") as f:
             params = json.loads(f.read())

         model_args: ModelArgs = ModelArgs(

@@ -154,15 +154,15 @@ class Llama3:
     @torch.inference_mode()
     def generate(
         self,
-        llm_inputs: List[LLMInput],
+        llm_inputs: list[LLMInput],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         echo: bool = False,
         print_model_input: bool = False,
-        logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-    ) -> Generator[List[GenerationResult], None, None]:
+        logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+    ) -> Generator[list[GenerationResult], None, None]:
         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
             max_gen_len = self.args.max_seq_len - 1
         params = self.model.params

@@ -302,13 +302,13 @@ class Llama3:

     def completion(
         self,
-        contents: List[RawContent],
+        contents: list[RawContent],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         echo: bool = False,
-    ) -> Generator[List[GenerationResult], None, None]:
+    ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_content(c) for c in contents]
         for result in self.generate(
             model_inputs=model_inputs,

@@ -324,14 +324,14 @@ class Llama3:

     def chat_completion(
         self,
-        messages_batch: List[List[RawMessage]],
+        messages_batch: list[list[RawMessage]],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
         echo: bool = False,
-    ) -> Generator[List[GenerationResult], None, None]:
+    ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
         for result in self.generate(
             model_inputs=model_inputs,
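Two further mechanical changes show up in this file: Callable and Generator now come from collections.abc (the typing aliases have been deprecated since Python 3.9), and the redundant "r" mode is dropped from open(), since text reading is its default. A dependency-free sketch of both (names are hypothetical, not from the repo):

import json
from collections.abc import Callable, Generator
from pathlib import Path

def read_params(ckpt_dir: str) -> dict:
    # "r" (read, text mode) is open()'s default, so it can be omitted.
    with open(Path(ckpt_dir) / "params.json") as f:
        return json.load(f)

def countdown(n: int) -> Generator[int, None, None]:
    # Same Generator[yield, send, return] shape, just a different import.
    yield from range(n, 0, -1)

def double(x: int) -> int:
    return x * 2

apply_fn: Callable[[int], int] = double
assert apply_fn(21) == 42 and list(countdown(3)) == [3, 2, 1]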
@@ -12,7 +12,6 @@
 # the top-level of this source tree.

 from pathlib import Path
-from typing import List, Optional

 from termcolor import colored

@@ -131,7 +130,7 @@ class LLama31Interface:
         self.formatter = ChatFormat(self.tokenizer)
         self.tool_prompt_format = tool_prompt_format

-    def get_tokens(self, messages: List[RawMessage]) -> List[int]:
+    def get_tokens(self, messages: list[RawMessage]) -> list[int]:
         model_input = self.formatter.encode_dialog_prompt(
             messages,
             self.tool_prompt_format,

@@ -149,10 +148,10 @@ class LLama31Interface:

     def system_messages(
         self,
-        builtin_tools: List[BuiltinTool],
-        custom_tools: List[ToolDefinition],
-        instruction: Optional[str] = None,
-    ) -> List[RawMessage]:
+        builtin_tools: list[BuiltinTool],
+        custom_tools: list[ToolDefinition],
+        instruction: str | None = None,
+    ) -> list[RawMessage]:
         messages = []

         default_gen = SystemDefaultGenerator()

@@ -194,8 +193,8 @@ class LLama31Interface:
         self,
         content: str,
         stop_reason: StopReason,
-        tool_call: Optional[ToolCall] = None,
-    ) -> List[RawMessage]:
+        tool_call: ToolCall | None = None,
+    ) -> list[RawMessage]:
         tool_calls = []
         if tool_call:
             tool_calls.append(tool_call)

@@ -208,7 +207,7 @@ class LLama31Interface:
             )
         ]

-    def user_message(self, content: str) -> List[RawMessage]:
+    def user_message(self, content: str) -> list[RawMessage]:
         return [RawMessage(role="user", content=content)]

     def display_message_as_tokens(self, message: RawMessage) -> None:

@@ -228,7 +227,7 @@ class LLama31Interface:
         print("\n", end="")


-def list_jinja_templates() -> List[Template]:
+def list_jinja_templates() -> list[Template]:
     return TEMPLATES
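Optional[X] = None and X | None = None denote the same type; only the spelling changes. On Python 3.10+ the two forms even compare equal, which is a quick way to confirm these hunks are behavior-preserving (a standalone check, not repo code):

from typing import Optional, Union

# PEP 604 unions interoperate with typing.Union/Optional on Python >= 3.10.
assert (str | None) == Optional[str]
assert (int | str) == Union[int, str]

def system_instruction(instruction: str | None = None) -> str:
    # Identical runtime behavior to `instruction: Optional[str] = None`.
    return instruction or "You are a helpful assistant."

assert system_instruction() == "You are a helpful assistant."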
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import math
-from typing import Optional, Tuple

 import fairscale.nn.model_parallel.initialize as fs_init
 import torch

@@ -80,7 +79,7 @@ def apply_rotary_emb(
     xq: torch.Tensor,
     xk: torch.Tensor,
     freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)

@@ -162,7 +161,7 @@ class Attention(nn.Module):
         x: torch.Tensor,
         start_pos: int,
         freqs_cis: torch.Tensor,
-        mask: Optional[torch.Tensor],
+        mask: torch.Tensor | None,
     ):
         bsz, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

@@ -204,7 +203,7 @@ class FeedForward(nn.Module):
         dim: int,
         hidden_dim: int,
         multiple_of: int,
-        ffn_dim_multiplier: Optional[float],
+        ffn_dim_multiplier: float | None,
     ):
         super().__init__()
         hidden_dim = int(2 * hidden_dim / 3)

@@ -243,7 +242,7 @@ class TransformerBlock(nn.Module):
         x: torch.Tensor,
         start_pos: int,
         freqs_cis: torch.Tensor,
-        mask: Optional[torch.Tensor],
+        mask: torch.Tensor | None,
     ):
         h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
         out = h + self.feed_forward(self.ffn_norm(h))
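One caveat: annotations such as torch.Tensor | None are evaluated eagerly at import time, so modules written this way require Python 3.10 unless they opt into postponed evaluation. A dependency-free sketch (the Tensor stand-in is hypothetical):

from __future__ import annotations  # annotations become strings, never evaluated

class Tensor:
    """Stand-in for torch.Tensor so the sketch needs no dependencies."""

def attend(x: Tensor, mask: Tensor | None = None) -> Tensor:
    # With the future import this module would also load on Python 3.8/3.9,
    # because `Tensor | None` is never executed at runtime.
    return x

attend(Tensor())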
@@ -14,7 +14,7 @@
 import math
 from collections import defaultdict
 from logging import getLogger
-from typing import Any, Optional, Set, Tuple
+from typing import Any

 import torch
 import torchvision.transforms as tv

@@ -26,7 +26,7 @@ IMAGE_RES = 224
 logger = getLogger()


-class VariableSizeImageTransform(object):
+class VariableSizeImageTransform:
     """
     This class accepts images of any size and dynamically resize, pads and chunks it
     based on the image aspect ratio and the number of image chunks we allow.

@@ -75,7 +75,7 @@ class VariableSizeImageTransform(object):
         self.resample = tv.InterpolationMode.BILINEAR

     @staticmethod
-    def get_factors(n: int) -> Set[int]:
+    def get_factors(n: int) -> set[int]:
         """
         Calculate all factors of a given number, i.e. a dividor that leaves
         no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.

@@ -145,9 +145,9 @@ class VariableSizeImageTransform(object):

     @staticmethod
     def get_max_res_without_distortion(
-        image_size: Tuple[int, int],
-        target_size: Tuple[int, int],
-    ) -> Tuple[int, int]:
+        image_size: tuple[int, int],
+        target_size: tuple[int, int],
+    ) -> tuple[int, int]:
         """
         Determines the maximum resolution to which an image can be resized to without distorting its
         aspect ratio, based on the target resolution.

@@ -198,8 +198,8 @@ class VariableSizeImageTransform(object):
     def resize_without_distortion(
         self,
         image: torch.Tensor,
-        target_size: Tuple[int, int],
-        max_upscaling_size: Optional[int],
+        target_size: tuple[int, int],
+        max_upscaling_size: int | None,
     ) -> torch.Tensor:
         """
         Used to resize an image to target_resolution, without distortion.

@@ -261,10 +261,10 @@ class VariableSizeImageTransform(object):

     def get_best_fit(
         self,
-        image_size: Tuple[int, int],
+        image_size: tuple[int, int],
         possible_resolutions: torch.Tensor,
         resize_to_max_canvas: bool = False,
-    ) -> Tuple[int, int]:
+    ) -> tuple[int, int]:
         """
         Determines the best canvas possible from a list of possible resolutions to, without distortion,
         resize an image to.

@@ -364,7 +364,7 @@ class VariableSizeImageTransform(object):
         max_num_chunks: int,
         normalize_img: bool = True,
         resize_to_max_canvas: bool = False,
-    ) -> Tuple[Any, Any]:
+    ) -> tuple[Any, Any]:
         """
         Args:
             image (PIL.Image): Image to be resized.
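Dropping the explicit (object) base is purely cosmetic: every Python 3 class inherits from object implicitly. A two-line check (illustrative only):

class Old(object): ...
class New: ...

# Both classes have exactly the same MRO in Python 3.
assert Old.__mro__ == (Old, object) and New.__mro__ == (New, object)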
@@ -6,8 +6,9 @@

 import logging
 import math
+from collections.abc import Callable
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any

 import fairscale.nn.model_parallel.initialize as fs_init
 import torch

@@ -104,9 +105,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: Union[int, Tuple[int, int]],
-        stride: Union[int, Tuple[int, int]],
-        bias: Optional[bool] = False,
+        kernel_size: int | tuple[int, int],
+        stride: int | tuple[int, int],
+        bias: bool | None = False,
     ) -> None:
         super().__init__()
         if isinstance(kernel_size, int):

@@ -390,13 +391,13 @@ class VisionEncoder(nn.Module):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool = True,
-        missing_keys: List[str] = None,
-        unexpected_keys: List[str] = None,
-        error_msgs: List[str] = None,
+        missing_keys: list[str] = None,
+        unexpected_keys: list[str] = None,
+        error_msgs: list[str] = None,
         return_state_dict: bool = False,
     ) -> None:
         orig_pos_embed = state_dict.get(prefix + "positional_embedding")

@@ -641,7 +642,7 @@ class FeedForward(nn.Module):
         dim: int,
         hidden_dim: int,
         multiple_of: int,
-        ffn_dim_multiplier: Optional[float],
+        ffn_dim_multiplier: float | None,
     ):
         """
         Initialize the FeedForward module.

@@ -983,7 +984,7 @@ class CrossAttentionTransformerBlock(torch.nn.Module):
         self,
         x: torch.Tensor,
         xattn_mask: torch.Tensor,
-        full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
+        full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor],
         xattn_cache: torch.Tensor,
     ) -> torch.Tensor:
         _attn_out = self.attention(

@@ -1144,7 +1145,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
     def _init_fusion_schedule(
         self,
         num_layers: int,
-    ) -> List[int]:
+    ) -> list[int]:
         llama_layers = list(range(self.n_llama_layers))

         # uniformly spread the layers

@@ -1231,7 +1232,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
         text_dtype,
         vision_tokens,
         cross_attention_masks,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor]:
         assert vision_tokens is not None, "Vision tokens must be provided"
         vision_seqlen = vision_tokens.shape[3]
         assert vision_tokens.shape[1] == cross_attention_masks.shape[2], (

@@ -1280,11 +1281,11 @@ class CrossAttentionTransformer(torch.nn.Module):

     def compute_vision_tokens_masks(
         self,
-        batch_images: List[List[PIL_Image.Image]],
-        batch_masks: List[List[List[int]]],
+        batch_images: list[list[PIL_Image.Image]],
+        batch_masks: list[list[list[int]]],
         total_len: int,
         device: torch.device,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         skip_vision_encoder = False

         assert len(batch_images) == len(batch_masks), "Images and masks must have the same length"

@@ -1371,11 +1372,11 @@ class CrossAttentionTransformer(torch.nn.Module):


 def _stack_images(
-    images: List[List[PIL_Image.Image]],
+    images: list[list[PIL_Image.Image]],
     max_num_chunks: int,
     image_res: int,
     max_num_images: int,
-) -> Tuple[torch.Tensor, List[int]]:
+) -> tuple[torch.Tensor, list[int]]:
     """
     Takes a list of list of images and stacks them into a tensor.
     This function is needed since images can be of completely

@@ -1400,8 +1401,8 @@ def _stack_images(


 def _pad_masks(
-    all_masks: List[List[List[int]]],
-    all_num_chunks: List[List[int]],
+    all_masks: list[list[list[int]]],
+    all_num_chunks: list[list[int]],
     total_len: int,
     max_num_chunks: int,
 ) -> torch.Tensor:
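Union[int, Tuple[int, int]] becomes int | tuple[int, int], and on Python 3.10+ a | union can be passed straight to isinstance. Note one leftover smell the diff does not fix: missing_keys: list[str] = None still declares a None default without | None in the type, which a strict checker would flag. A small sketch of the union handling (hypothetical helper, not repo code):

def normalize_kernel(kernel_size: int | tuple[int, int]) -> tuple[int, int]:
    # Mirrors the isinstance dispatch in ColumnParallelConv2dPatch.__init__.
    if isinstance(kernel_size, int):
        return (kernel_size, kernel_size)
    return kernel_size

assert normalize_kernel(3) == (3, 3)
assert normalize_kernel((3, 5)) == (3, 5)
assert isinstance(3, int | tuple)  # PEP 604 unions work with isinstance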
Some files were not shown because too many files have changed in this diff.