Merge remote-tracking branch 'upstream/main' into add_nvidia_safety_provider

commit ca6a12e362
Author: Chantal D Gama Rose
Date: 2025-02-24 18:55:08 +00:00

114 changed files with 2100 additions and 685 deletions


@ -11,17 +11,42 @@ on:
branches:
- main
paths:
- 'docs/source/**'
- 'docs/resources/**'
- 'docs/**'
- '.github/workflows/update-readthedocs.yml'
pull_request:
branches:
- main
paths:
- 'docs/**'
- '.github/workflows/update-readthedocs.yml'
jobs:
update-readthedocs:
runs-on: ubuntu-latest
runs-on: ubuntu-latest
env:
TOKEN: ${{ secrets.READTHEDOCS_TOKEN }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v5
- name: Sync with uv
run: uv sync --extra docs
- name: Build HTML
run: |
cd docs
uv run make html
- name: Trigger ReadTheDocs build
if: github.event_name != 'pull_request'
run: |
if [ -z "$TOKEN" ]; then
echo "READTHEDOCS_TOKEN is not set"


@ -30,6 +30,7 @@ repos:
rev: v0.9.4
hooks:
- id: ruff
args: [ --fix ]
exclude: ^llama_stack/strong_typing/.*$
- id: ruff-format
@ -45,23 +46,26 @@ repos:
hooks:
- id: uv-export
args: [
"--frozen",
"--no-hashes",
"--no-emit-project",
"--frozen",
"--no-hashes",
"--no-emit-project",
"--output-file=requirements.txt"
]
files: ^pyproject\.toml$
- id: uv-sync
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.14.0
# hooks:
# - id: mypy
# additional_dependencies:
# - types-requests
# - types-setuptools
# - pydantic
# args: [--ignore-missing-imports]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
hooks:
- id: mypy
additional_dependencies:
- uv==0.6.2
- mypy
- pytest
- rich
- types-requests
- pydantic
pass_filenames: false
# - repo: https://github.com/jsh9/pydoclint
# rev: d88180a8632bb1602a4d81344085cf320f288c5a


@ -134,9 +134,11 @@ If you are making changes to the documentation at [https://llama-stack.readthedo
$ cd llama-stack/docs
$ uv sync --extra docs
# This rebuilds the documentation pages.
$ uv run make html
# This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
$ make html
$ uv run sphinx-autobuild source build/html
$ uv run sphinx-autobuild source build/html --write-all
```
### Update API Documentation
@ -145,7 +147,7 @@ If you modify or add new API endpoints, update the API documentation accordingly
```bash
$ uv sync --extra dev
$ ./docs/openapi_generator/run_openapi_generator.sh
$ uv run ./docs/openapi_generator/run_openapi_generator.sh
```
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.


@ -3,3 +3,4 @@ include distributions/dependencies.json
include llama_stack/distribution/*.sh
include llama_stack/cli/scripts/*.sh
include llama_stack/templates/*/*.yaml
include llama_stack/providers/tests/test_cases/*.json


@ -78,18 +78,14 @@ You have two ways to install this repository:
```
* **Install from source**:
If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable).
If you prefer to install from the source code, we recommend using [uv](https://github.com/astral-sh/uv).
Then, run the following commands:
```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git
conda create -n stack python=3.10
conda activate stack
cd llama-stack
pip install -e .
uv sync
uv pip install -e .
```
### Documentation


@ -30,9 +30,7 @@
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"uvicorn"
],
"cerebras": [
"aiosqlite",
@ -170,9 +168,7 @@
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"uvicorn"
],
"hf-serverless": [
"aiohttp",
@ -247,9 +243,7 @@
"tqdm",
"transformers",
"uvicorn",
"zmq",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"zmq"
],
"meta-reference-quantized-gpu": [
"accelerate",
@ -290,9 +284,7 @@
"tqdm",
"transformers",
"uvicorn",
"zmq",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"zmq"
],
"nvidia": [
"aiosqlite",
@ -323,9 +315,7 @@
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"uvicorn"
],
"ollama": [
"aiohttp",
@ -335,7 +325,6 @@
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
@ -356,11 +345,10 @@
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"uvicorn"
],
"remote-vllm": [
"aiosqlite",
@ -423,9 +411,7 @@
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"uvicorn"
],
"tgi": [
"aiohttp",


@ -2315,6 +2315,70 @@
}
}
},
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
"post": {
"responses": {
"200": {
"description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Turn"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
}
}
}
}
},
"tags": [
"Agents"
],
"description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.",
"parameters": [
{
"name": "agent_id",
"in": "path",
"description": "The ID of the agent to resume.",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "session_id",
"in": "path",
"description": "The ID of the session to resume.",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "turn_id",
"in": "path",
"description": "The ID of the turn to resume.",
"required": true,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ResumeAgentTurnRequest"
}
}
},
"required": true
}
}
},
"/v1/eval/benchmarks/{benchmark_id}/jobs": {
"post": {
"responses": {
@ -4226,6 +4290,9 @@
},
"tool_config": {
"$ref": "#/components/schemas/ToolConfig"
},
"allow_turn_resume": {
"type": "boolean"
}
},
"additionalProperties": false,
@ -4454,6 +4521,31 @@
},
"content": {
"$ref": "#/components/schemas/InterleavedContent"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
@ -4612,6 +4704,9 @@
},
{
"$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
},
{
"$ref": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
}
],
"discriminator": {
@ -4621,7 +4716,8 @@
"step_progress": "#/components/schemas/AgentTurnResponseStepProgressPayload",
"step_complete": "#/components/schemas/AgentTurnResponseStepCompletePayload",
"turn_start": "#/components/schemas/AgentTurnResponseTurnStartPayload",
"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload",
"turn_awaiting_input": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
}
}
},
@ -4784,6 +4880,25 @@
"title": "AgentTurnResponseStreamChunk",
"description": "streamed agent turn completion response."
},
"AgentTurnResponseTurnAwaitingInputPayload": {
"type": "object",
"properties": {
"event_type": {
"type": "string",
"const": "turn_awaiting_input",
"default": "turn_awaiting_input"
},
"turn": {
"$ref": "#/components/schemas/Turn"
}
},
"additionalProperties": false,
"required": [
"event_type",
"turn"
],
"title": "AgentTurnResponseTurnAwaitingInputPayload"
},
"AgentTurnResponseTurnCompletePayload": {
"type": "object",
"properties": {
@ -4929,11 +5044,42 @@
"description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
},
"contents": {
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContent"
},
"description": "List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text."
"oneOf": [
{
"type": "array",
"items": {
"type": "string"
}
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContentItem"
}
}
],
"description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
},
"text_truncation": {
"type": "string",
"enum": [
"none",
"start",
"end"
],
"description": "(Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length."
},
"output_dimension": {
"type": "integer",
"description": "(Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models."
},
"task_type": {
"type": "string",
"enum": [
"query",
"document"
],
"description": "(Optional) How is the embedding being used? This is only supported by asymmetric embedding models."
}
},
"additionalProperties": false,
@ -6625,6 +6771,31 @@
},
"error_code": {
"type": "integer"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
@ -7474,9 +7645,37 @@
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"metadata"
],
"title": "RAGQueryResult"
},
"QueryChunksRequest": {
@ -8015,6 +8214,27 @@
],
"title": "RegisterVectorDbRequest"
},
"ResumeAgentTurnRequest": {
"type": "object",
"properties": {
"tool_responses": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolResponseMessage"
},
"description": "The tool call responses to resume the turn with."
},
"stream": {
"type": "boolean",
"description": "Whether to stream the response."
}
},
"additionalProperties": false,
"required": [
"tool_responses"
],
"title": "ResumeAgentTurnRequest"
},
"RunEvalRequest": {
"type": "object",
"properties": {


@ -1401,6 +1401,53 @@ paths:
schema:
$ref: '#/components/schemas/QueryTracesRequest'
required: true
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
post:
responses:
'200':
description: >-
A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk
objects.
content:
application/json:
schema:
$ref: '#/components/schemas/Turn'
text/event-stream:
schema:
$ref: '#/components/schemas/AgentTurnResponseStreamChunk'
tags:
- Agents
description: >-
Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client
side tool calls, this endpoint can be used to submit the outputs from the
tool calls once they are ready.
parameters:
- name: agent_id
in: path
description: The ID of the agent to resume.
required: true
schema:
type: string
- name: session_id
in: path
description: The ID of the session to resume.
required: true
schema:
type: string
- name: turn_id
in: path
description: The ID of the turn to resume.
required: true
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/ResumeAgentTurnRequest'
required: true
/v1/eval/benchmarks/{benchmark_id}/jobs:
post:
responses:
@ -2740,6 +2787,8 @@ components:
$ref: '#/components/schemas/AgentTool'
tool_config:
$ref: '#/components/schemas/ToolConfig'
allow_turn_resume:
type: boolean
additionalProperties: false
required:
- messages
@ -2896,6 +2945,16 @@ components:
- type: string
content:
$ref: '#/components/schemas/InterleavedContent'
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- call_id
@ -2992,6 +3051,7 @@ components:
- $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload'
- $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload'
- $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
- $ref: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
discriminator:
propertyName: event_type
mapping:
@ -3000,6 +3060,7 @@ components:
step_complete: '#/components/schemas/AgentTurnResponseStepCompletePayload'
turn_start: '#/components/schemas/AgentTurnResponseTurnStartPayload'
turn_complete: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
turn_awaiting_input: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
AgentTurnResponseStepCompletePayload:
type: object
properties:
@ -3106,6 +3167,21 @@ components:
- event
title: AgentTurnResponseStreamChunk
description: streamed agent turn completion response.
"AgentTurnResponseTurnAwaitingInputPayload":
type: object
properties:
event_type:
type: string
const: turn_awaiting_input
default: turn_awaiting_input
turn:
$ref: '#/components/schemas/Turn'
additionalProperties: false
required:
- event_type
- turn
title: >-
AgentTurnResponseTurnAwaitingInputPayload
AgentTurnResponseTurnCompletePayload:
type: object
properties:
@ -3224,13 +3300,39 @@ components:
The identifier of the model to use. The model must be an embedding model
registered with Llama Stack and available via the /models endpoint.
contents:
type: array
items:
$ref: '#/components/schemas/InterleavedContent'
oneOf:
- type: array
items:
type: string
- type: array
items:
$ref: '#/components/schemas/InterleavedContentItem'
description: >-
List of contents to generate embeddings for. Note that content can be
multimodal. The behavior depends on the model and provider. Some models
may only support text.
List of contents to generate embeddings for. Each content can be a string
or an InterleavedContentItem (and hence can be multimodal). The behavior
depends on the model and provider. Some models may only support text.
text_truncation:
type: string
enum:
- none
- start
- end
description: >-
(Optional) Config for how to truncate text for embedding when text is
longer than the model's max sequence length.
output_dimension:
type: integer
description: >-
(Optional) Output dimensionality for the embeddings. Only supported by
Matryoshka models.
task_type:
type: string
enum:
- query
- document
description: >-
(Optional) How is the embedding being used? This is only supported by
asymmetric embedding models.
additionalProperties: false
required:
- model_id
@ -4289,6 +4391,16 @@ components:
type: string
error_code:
type: integer
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- content
@ -4862,7 +4974,19 @@ components:
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- metadata
title: RAGQueryResult
QueryChunksRequest:
type: object
@ -5179,6 +5303,22 @@ components:
- vector_db_id
- embedding_model
title: RegisterVectorDbRequest
ResumeAgentTurnRequest:
type: object
properties:
tool_responses:
type: array
items:
$ref: '#/components/schemas/ToolResponseMessage'
description: >-
The tool call responses to resume the turn with.
stream:
type: boolean
description: Whether to stream the response.
additionalProperties: false
required:
- tool_responses
title: ResumeAgentTurnRequest
RunEvalRequest:
type: object
properties:


@ -86,8 +86,6 @@
"# NBVAL_SKIP\n",
"\n",
"!apt-get install -y bubblewrap\n",
"import os\n",
"os.environ[\"UV_SYSTEM_PYTHON\"] = \"1\"\n",
"!pip install uv\n",
"!uv pip install llama-stack"
]
@ -3632,7 +3630,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "master",
"display_name": "toolchain",
"language": "python",
"name": "python3"
},


@ -25,7 +25,7 @@ We are working on adding a few more APIs to complete the application lifecycle.
## API Providers
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, etc.),
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
@ -33,7 +33,7 @@ Providers come in two flavors:
- **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
- **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.
Most importantly, Llama Stack always strives to provide at least one fully "local" provider for each API so you can iterate on a fully featured environment locally.
Most importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
## Resources
Some of these APIs are associated with a set of **Resources**. Here is the mapping of APIs to resources:


@ -15,7 +15,7 @@
from docutils import nodes
project = "llama-stack"
copyright = "2024, Meta"
copyright = "2025, Meta"
author = "Meta"
# -- General configuration ---------------------------------------------------


@ -38,6 +38,7 @@ The following models are available by default:
- `meta-llama/Llama-3.2-3B-Instruct (meta/llama-3.2-3b-instruct)`
- `meta-llama/Llama-3.2-11B-Vision-Instruct (meta/llama-3.2-11b-vision-instruct)`
- `meta-llama/Llama-3.2-90B-Vision-Instruct (meta/llama-3.2-90b-vision-instruct)`
- `baai/bge-m3 (baai/bge-m3)`
### Prerequisite: API Keys


@ -17,7 +17,7 @@ Which templates / distributions to choose depends on the hardware you have for r
- {dockerhub}`distribution-nvidia` ([Guide](self_hosted_distro/nvidia))
- **Are you running on a "regular" desktop or laptop?** We suggest using the ollama template for quick prototyping and getting started without having to worry about needing GPUs.
- {dockerhub}`distribution-ollama` ([link](self_hosted_distro/ollama))
- {dockerhub}`distribution-ollama` ([Guide](self_hosted_distro/ollama))
- **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest:
- {dockerhub}`distribution-together` ([Guide](self_hosted_distro/together))
@ -28,7 +28,7 @@ Which templates / distributions to choose depends on the hardware you have for r
- [Android](ondevice_distro/android_sdk)
- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro).**
- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro.md).**
### Distribution Details


@ -8,7 +8,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::cerebras` |
| inference | `remote::cerebras`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |


@ -19,7 +19,7 @@ The `llamastack/distribution-dell` distribution consists of the following provid
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::tgi` |
| inference | `remote::tgi`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |


@ -18,7 +18,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::fireworks` |
| inference | `remote::fireworks`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |


@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.


@ -17,7 +17,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::vllm` |
| inference | `remote::vllm`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |


@ -19,7 +19,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::tgi` |
| inference | `remote::tgi`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |


@ -18,7 +18,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::together` |
| inference | `remote::together`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |


@ -67,6 +67,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
| **Provider** | **Environments** |
| :----: | :----: |
| FAISS | Single Node |
| SQLite-Vec| Single Node |
| Chroma | Hosted and Single Node |
| Postgres (PGVector) | Hosted and Single Node |
| Weaviate | Hosted |
@ -88,6 +89,7 @@ self
introduction/index
getting_started/index
concepts/index
providers/index
distributions/index
distributions/selection
building_applications/index


@ -0,0 +1,59 @@
# Providers Overview
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
Providers come in two flavors:
- **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
- **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.
Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
## Agents
Runs multi-step agentic workflows with LLMs, including tool usage and memory (RAG).
## DatasetIO
Interfaces with datasets and data loaders.
## Eval
Generates outputs (via Inference or Agents) and performs scoring.
## Inference
Runs inference with an LLM.
## Post Training
Fine-tunes a model.
## Safety
Applies safety policies to outputs at a system (not just model) level.
## Scoring
Evaluates the outputs of the system.
## Telemetry
Collects telemetry data from the system.
## Tool Runtime
Is associated with ToolGroup resources.
## Vector IO
Vector IO refers to operations on vector databases, such as adding documents, searching, and deleting documents.
Vector IO plays a crucial role in [Retrieval Augmented Generation (RAG)](../../building_applications/rag), where the vector
database is used to store and retrieve documents for retrieval.
#### Vector IO Providers
The following providers (i.e., databases) are available for Vector IO:
```{toctree}
:maxdepth: 1
vector_io/faiss
vector_io/sqlite-vec
vector_io/chromadb
vector_io/pgvector
vector_io/qdrant
vector_io/weaviate
```


@ -0,0 +1,36 @@
---
orphan: true
---
# Chroma
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
That means you're not limited to storing vectors in memory or in a separate service.
## Features
Chroma supports:
- Store embeddings and their metadata
- Vector search
- Full-text search
- Document storage
- Metadata filtering
- Multi-modal retrieval
## Usage
To use Chroma in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Chroma.
3. Start storing and querying vectors.
## Installation
You can install Chroma using pip:
```bash
pip install chromadb
```
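As a quick smoke test that the dependency works, here is a minimal sketch using the standalone Chroma client (the collection name and documents are placeholders, not part of Llama Stack itself):
```python
import chromadb

# In-memory client; use chromadb.HttpClient(...) against a remote server.
client = chromadb.Client()
collection = client.create_collection(name="docs")
collection.add(
    ids=["doc1", "doc2"],
    documents=["llama stack documentation", "vector io providers"],
)
# Query for the single nearest document.
results = collection.query(query_texts=["providers"], n_results=1)
print(results["ids"])
```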
## Documentation
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.


@ -0,0 +1,33 @@
---
orphan: true
---
# Faiss
[Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.
That means you'll get fast and efficient vector retrieval.
## Features
- Lightweight and easy to use
- Fully integrated with Llama Stack
- GPU support
## Usage
To use Faiss in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Faiss.
3. Start storing and querying vectors.
## Installation
You can install Faiss using pip:
```bash
pip install faiss-cpu
```
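To illustrate the in-memory model described above, a minimal sketch with random vectors (the dimensionality and data are placeholders):
```python
import faiss
import numpy as np

dim = 64
vectors = np.random.random((1000, dim)).astype("float32")

index = faiss.IndexFlatL2(dim)  # exact L2 search, entirely in memory
index.add(vectors)

# 5 nearest neighbours for the first 3 vectors used as queries.
distances, ids = index.search(vectors[:3], 5)
print(ids.shape)  # (3, 5)
```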
## Documentation
See [Faiss' documentation](https://faiss.ai/) or the [Faiss Wiki](https://github.com/facebookresearch/faiss/wiki) for
more details about Faiss in general.


@ -0,0 +1,31 @@
---
orphan: true
---
# Postgres PGVector
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Postgres database.
That means your vectors can live alongside the rest of your relational data.
## Features
- Easy to use
- Fully integrated with Llama Stack
## Usage
To use PGVector in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use PGVector.
3. Start storing and querying vectors.
## Installation
You can install PGVector using docker:
```bash
docker pull pgvector/pgvector:pg17
```
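Once the container is running, a minimal sketch of storing and querying vectors with plain SQL via `psycopg2` (connection details and table layout are placeholders, assuming default credentials):
```python
import psycopg2

conn = psycopg2.connect("host=localhost dbname=postgres user=postgres password=postgres")
cur = conn.cursor()
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
cur.execute("CREATE TABLE IF NOT EXISTS items (id bigserial PRIMARY KEY, embedding vector(3))")
cur.execute("INSERT INTO items (embedding) VALUES ('[1,2,3]')")
# `<->` is pgvector's L2 distance operator; order by it for nearest neighbours.
cur.execute("SELECT id FROM items ORDER BY embedding <-> '[2,3,4]' LIMIT 1")
print(cur.fetchone())
conn.commit()
```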
## Documentation
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.


@ -0,0 +1,31 @@
---
orphan: true
---
# Qdrant
[Qdrant](https://qdrant.tech/documentation/) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors in a dedicated Qdrant instance.
That means you get fast and efficient vector retrieval backed by a purpose-built vector database.
## Features
- Easy to use
- Fully integrated with Llama Stack
## Usage
To use Qdrant in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Qdrant.
3. Start storing and querying vectors.
## Installation
You can install Qdrant using docker:
```bash
docker pull qdrant/qdrant
```
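With the container listening on the default port, a minimal sketch using the `qdrant-client` package (collection name and vectors are placeholders):
```python
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

client = QdrantClient(url="http://localhost:6333")
client.create_collection(
    collection_name="docs",
    vectors_config=VectorParams(size=4, distance=Distance.COSINE),
)
client.upsert(
    collection_name="docs",
    points=[PointStruct(id=1, vector=[0.1, 0.2, 0.3, 0.4], payload={"text": "hello"})],
)
hits = client.search(collection_name="docs", query_vector=[0.1, 0.2, 0.3, 0.4], limit=1)
print(hits[0].payload)
```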
## Documentation
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.


@ -0,0 +1,33 @@
---
orphan: true
---
# SQLite-Vec
[SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly within an SQLite database.
That means you're not limited to storing vectors in memory or in a separate service.
## Features
- Lightweight and easy to use
- Fully integrated with Llama Stack
## Usage
To use SQLite-Vec in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use SQLite-Vec.
3. Start storing and querying vectors.
## Installation
You can install SQLite-Vec using pip:
```bash
pip install sqlite-vec
```
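A minimal sketch of the extension in action, based on sqlite-vec's documented Python bindings (table name and vectors are placeholders):
```python
import sqlite3

import sqlite_vec

db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
sqlite_vec.load(db)  # load the sqlite-vec extension into this connection
db.enable_load_extension(False)

db.execute("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4])")
db.execute(
    "INSERT INTO vec_items(rowid, embedding) VALUES (1, ?)",
    (sqlite_vec.serialize_float32([0.1, 0.2, 0.3, 0.4]),),
)
row = db.execute(
    "SELECT rowid, distance FROM vec_items WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
    (sqlite_vec.serialize_float32([0.1, 0.2, 0.3, 0.4]),),
).fetchone()
print(row)
```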
## Documentation
See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.


@ -0,0 +1,33 @@
---
orphan: true
---
# Weaviate
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
It allows you to store and query vectors directly within a Weaviate database.
That means you're not limited to storing vectors in memory or in a separate service.
## Features
Weaviate supports:
- Store embeddings and their metadata
- Vector search
- Full-text search
- Hybrid search
- Document storage
- Metadata filtering
- Multi-modal retrieval
## Usage
To use Weaviate in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Weaviate.
3. Start storing and querying vectors.
## Installation
To install Weaviate see the [Weaviate quickstart documentation](https://weaviate.io/developers/weaviate/quickstart).
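Once an instance is running, a minimal connectivity sketch with the v4 `weaviate-client` package (assumes a local, unauthenticated instance):
```python
import weaviate  # the v4 weaviate-client package

client = weaviate.connect_to_local()
print(client.is_ready())  # True once the instance is reachable
client.close()
```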
## Documentation
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.


@ -171,7 +171,7 @@ The `llama model` command helps you explore the models interface.
llama model --help
```
```
usage: llama model [-h] {download,list,prompt-format,describe} ...
usage: llama model [-h] {download,list,prompt-format,describe,verify-download,remove} ...
Work with llama models
@ -179,15 +179,15 @@ options:
-h, --help show this help message and exit
model_subcommands:
{download,list,prompt-format,describe}
{download,list,prompt-format,describe,verify-download,remove}
```
### Describe
You can use the describe command to learn more about a model:
```
llama model describe -m Llama3.2-3B-Instruct
```
### Describe
```
+-----------------------------+----------------------------------+
| Model | Llama3.2-3B-Instruct |
@ -234,3 +234,10 @@ llama model prompt-format -m Llama3.2-3B-Instruct
You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
**NOTE**: Outputs in terminal are color printed to show special tokens.
### Remove model
You can run `llama model remove` to remove a model you no longer need:
```
llama model remove -m Llama-Guard-3-8B-int8
```

View file

@ -194,6 +194,7 @@ class AgentTurnResponseEventType(Enum):
turn_start = "turn_start"
turn_complete = "turn_complete"
turn_awaiting_input = "turn_awaiting_input"
@json_schema_type
@ -235,6 +236,14 @@ class AgentTurnResponseTurnCompletePayload(BaseModel):
turn: Turn
@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
AgentTurnResponseEventType.turn_awaiting_input.value
)
turn: Turn
AgentTurnResponseEventPayload = register_schema(
Annotated[
Union[
@ -243,6 +252,7 @@ AgentTurnResponseEventPayload = register_schema(
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnAwaitingInputPayload,
],
Field(discriminator="event_type"),
],
@ -286,6 +296,18 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None
# TODO (xiyan): temporary flag, will remove for 0.1.5
allow_turn_resume: Optional[bool] = False
@json_schema_type
class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: List[ToolResponseMessage]
stream: Optional[bool] = False
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
@ -333,8 +355,34 @@ class Agents(Protocol):
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
)
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponseMessage],
stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
:param agent_id: The ID of the agent to resume.
:param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with.
:param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
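
For illustration, a minimal sketch of calling the new resume endpoint over plain HTTP, assuming a Llama Stack server on the default port; the IDs are placeholders, the payload shape follows the `ResumeAgentTurnRequest` schema above, and the tool response fields are assumed from `ToolResponseMessage`:
```python
import requests

# Placeholder IDs; in practice they come from the create-turn response that
# ended with a `turn_awaiting_input` event.
agent_id, session_id, turn_id = "agent-123", "sess-456", "turn-789"

resp = requests.post(
    f"http://localhost:8321/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
    json={
        "tool_responses": [
            # Hypothetical client-side tool result for a pending call.
            {"role": "tool", "call_id": "call-1", "tool_name": "get_answer", "content": "42"},
        ],
        "stream": False,
    },
)
resp.raise_for_status()
turn = resp.json()  # a Turn object, since stream is False
```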


@ -91,15 +91,18 @@ ParamType = register_schema(
name="ParamType",
)
"""
# TODO: recursive definition of ParamType in these containers
# will cause infinite recursion in OpenAPI generation script
# since we are going with ChatCompletionInputType and CompletionInputType
# we don't need to worry about ArrayType/ObjectType/UnionType for now
# ArrayType.model_rebuild()
# ObjectType.model_rebuild()
# UnionType.model_rebuild()
ArrayType.model_rebuild()
ObjectType.model_rebuild()
UnionType.model_rebuild()
# class CustomType(BaseModel):
# type: Literal["custom"] = "custom"
# validator_class: str
class CustomType(BaseModel):
pylint: disable=syntax-error
type: Literal["custom"] = "custom"
validator_class: str
"""


@ -20,7 +20,7 @@ from typing import (
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
@ -165,6 +165,7 @@ class ToolResponse(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
content: InterleavedContent
metadata: Optional[Dict[str, Any]] = None
@field_validator("tool_name", mode="before")
@classmethod
@ -402,6 +403,30 @@ class ModelStore(Protocol):
def get_model(self, identifier: str) -> Model: ...
class TextTruncation(Enum):
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
:cvar none: No truncation (default). If the text is longer than the model's max sequence length, you will get an error.
:cvar start: Truncate from the start
:cvar end: Truncate from the end
"""
none = "none"
start = "start"
end = "end"
class EmbeddingTaskType(Enum):
"""How is the embedding being used? This is only supported by asymmetric embedding models.
:cvar query: Used for a query for semantic search.
:cvar document: Used at indexing time when ingesting documents.
"""
query = "query"
document = "document"
@runtime_checkable
@trace_protocol
class Inference(Protocol):
@ -481,12 +506,18 @@ class Inference(Protocol):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
"""Generate embeddings for content pieces using the specified model.
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
"""
...
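
To make the new signature concrete, a hypothetical call against any implementation of this protocol (the import path and model ID are assumptions, not taken from this diff):
```python
from llama_stack.apis.inference import EmbeddingTaskType, Inference, TextTruncation


async def embed_documents(inference: Inference) -> list[list[float]]:
    response = await inference.embeddings(
        model_id="all-MiniLM-L6-v2",           # placeholder embedding model
        contents=["first document", "second document"],  # plain strings now allowed
        text_truncation=TextTruncation.end,    # truncate from the end when too long
        task_type=EmbeddingTaskType.document,  # indexing-time embeddings (asymmetric models)
    )
    return response.embeddings  # one embedding (list of floats) per content
```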


@ -26,6 +26,7 @@ class RAGDocument(BaseModel):
@json_schema_type
class RAGQueryResult(BaseModel):
content: Optional[InterleavedContent] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type


@ -72,6 +72,7 @@ class ToolInvocationResult(BaseModel):
content: InterleavedContent
error_message: Optional[str] = None
error_code: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
class ToolStore(Protocol):


@ -19,6 +19,13 @@ def _get_model_size(model_dir):
return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())
def _convert_to_model_descriptor(model):
for m in all_registered_models():
if model == m.descriptor().replace(":", "-"):
return str(m.descriptor())
return str(model)
def _run_model_list_downloaded_cmd() -> None:
headers = ["Model", "Size", "Modified Time"]
@ -30,7 +37,7 @@ def _run_model_list_downloaded_cmd() -> None:
modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path)))
rows.append(
[
model,
_convert_to_model_descriptor(model),
model_size,
modified_time,
]
@ -68,6 +75,13 @@ class ModelList(Subcommand):
action="store_true",
help="List the downloaded models",
)
self.parser.add_argument(
"-s",
"--search",
type=str,
required=False,
help="Search for the input string as a substring in the model descriptor(ID)",
)
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku
@ -87,15 +101,19 @@ class ModelList(Subcommand):
continue
descriptor = model.descriptor()
rows.append(
[
descriptor,
model.huggingface_repo,
f"{model.max_seq_length // 1024}K",
]
if not args.search or args.search.lower() in descriptor.lower():
rows.append(
[
descriptor,
model.huggingface_repo,
f"{model.max_seq_length // 1024}K",
]
)
if len(rows) == 0:
print(f"Did not find any model matching `{args.search}`.")
else:
print_table(
rows,
headers,
separate_rows=True,
)
print_table(
rows,
headers,
separate_rows=True,
)


@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe
from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.remove import ModelRemove
from llama_stack.cli.model.verify_download import ModelVerifyDownload
from llama_stack.cli.subcommand import Subcommand
@ -35,3 +36,4 @@ class ModelParser(Subcommand):
ModelPromptFormat.create(subparsers)
ModelDescribe.create(subparsers)
ModelVerifyDownload.create(subparsers)
ModelRemove.create(subparsers)


@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import shutil
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.models.llama.sku_list import resolve_model
class ModelRemove(Subcommand):
"""Remove the downloaded llama model"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"remove",
prog="llama model remove",
description="Remove the downloaded llama model",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_remove_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model",
required=True,
help="Specify the llama downloaded model name, see `llama model list --downloaded`",
)
self.parser.add_argument(
"-f",
"--force",
action="store_true",
help="Used to forcefully remove the llama model from the storage without further confirmation",
)
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku
prompt_guard = prompt_guard_model_sku()
if args.model == prompt_guard.model_id:
model = prompt_guard
else:
model = resolve_model(args.model)
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-"))
if model is None or not os.path.isdir(model_path):
print(f"'{args.model}' is not a valid llama model or does not exist.")
return
if args.force:
shutil.rmtree(model_path)
print(f"{args.model} removed.")
else:
if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y":
shutil.rmtree(model_path)
print(f"{args.model} removed.")
else:
print("Removal aborted.")


@ -9,6 +9,7 @@ import importlib.resources
import json
import os
import shutil
import sys
import textwrap
from functools import lru_cache
from pathlib import Path
@ -23,10 +24,10 @@ from termcolor import cprint
from llama_stack.cli.table import print_table
from llama_stack.distribution.build import (
SERVER_DEPENDENCIES,
ImageType,
build_image,
get_provider_dependencies,
)
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import (
BuildConfig,
DistributionSpec,
@ -37,6 +38,8 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, in_notebook, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
@ -59,8 +62,16 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
if args.list_templates:
return _run_template_list_cmd()
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = args.image_name or current_conda_env
if args.image_type == "venv":
current_venv = os.environ.get("VIRTUAL_ENV")
image_name = args.image_name or current_venv
if not image_name and in_notebook():
image_name = "__system__"
elif args.image_type == "conda":
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = args.image_name or current_conda_env
else:
image_name = args.image_name
if args.template:
available_templates = available_templates_specs()
@ -69,7 +80,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
color="red",
)
return
sys.exit(1)
build_config = available_templates[args.template]
if args.image_type:
build_config.image_type = args.image_type
@ -78,7 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
f"Please specify a image-type (container | conda | venv) for {args.template}",
color="red",
)
return
sys.exit(1)
elif not args.config and not args.template:
name = prompt(
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
@ -159,14 +170,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
f"Could not parse config file {args.config}: {e}",
color="red",
)
return
sys.exit(1)
if build_config.image_type == ImageType.container.value and not args.image_name:
cprint(
"Please specify --image-name when building a container from a config file",
color="red",
)
return
sys.exit(1)
if args.print_deps_only:
print(f"# Dependencies for {args.template or args.config or image_name}")
@ -177,19 +188,41 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
print(f"uv pip install {special_dep}")
return
_run_stack_build_command_from_build_config(
build_config,
image_name=image_name,
config_path=args.config,
template_name=args.template,
)
try:
run_config = _run_stack_build_command_from_build_config(
build_config,
image_name=image_name,
config_path=args.config,
template_name=args.template,
)
except (Exception, RuntimeError) as exc:
cprint(
f"Error building stack: {exc}",
color="red",
)
sys.exit(1)
if run_config is None:
cprint(
"Run config path is empty",
color="red",
)
sys.exit(1)
if args.run:
run_config = Path(run_config)
config_dict = yaml.safe_load(run_config.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
run_with_pty(run_args)
def _generate_run_config(
build_config: BuildConfig,
build_dir: Path,
image_name: str,
) -> None:
) -> str:
"""
Generate a run.yaml template file for user to edit from a build.yaml file
"""
@ -239,6 +272,7 @@ def _generate_run_config(
f"You can now run your stack with `llama stack run {run_config_file}`",
color="green",
)
return run_config_file
def _run_stack_build_command_from_build_config(
@ -246,7 +280,7 @@ def _run_stack_build_command_from_build_config(
image_name: Optional[str] = None,
template_name: Optional[str] = None,
config_path: Optional[str] = None,
) -> None:
) -> str:
if build_config.image_type == ImageType.container.value:
if template_name:
image_name = f"distribution-{template_name}"
@ -256,6 +290,9 @@ def _run_stack_build_command_from_build_config(
elif build_config.image_type == ImageType.conda.value:
if not image_name:
raise ValueError("Please specify an image name when building a conda image")
elif build_config.image_type == ImageType.venv.value:
if not image_name:
raise ValueError("Please specify an image name when building a venv image")
if template_name:
build_dir = DISTRIBS_BASE_DIR / template_name
@ -276,7 +313,7 @@ def _run_stack_build_command_from_build_config(
template_or_config=template_name or config_path,
)
if return_code != 0:
return
raise RuntimeError(f"Failed to build image {image_name}")
if template_name:
# copy run.yaml from template to build_dir instead of generating it again
@ -286,8 +323,9 @@ def _run_stack_build_command_from_build_config(
shutil.copy(path, run_config_file)
cprint("Build Successful!", color="green")
return template_path
else:
_generate_run_config(build_config, build_dir, image_name)
return _generate_run_config(build_config, build_dir, image_name)
def _run_template_list_cmd() -> None:


@ -68,6 +68,13 @@ the build. If not specified, currently active Conda environment will be used if
help="Print the dependencies for the stack only, without building the stack",
)
self.parser.add_argument(
"--run",
action="store_true",
default=False,
help="Run the stack after building using the same image type, name, and other applicable arguments",
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
# always keep implementation completely silo-ed away from CLI so CLI
# can be fast to load and reduces dependencies


@ -1,46 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class StackConfigure(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama stack configure",
description="Configure a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_configure_cmd)
def _add_arguments(self):
self.parser.add_argument(
"config",
type=str,
help="Path to the build config file (e.g. ~/.llama/builds/<image_type>/<name>-build.yaml). For container, this could also be the name of the container image. ",
)
self.parser.add_argument(
"--output-dir",
type=str,
help="Path to the output directory to store generated run.yaml config file. If not specified, will use ~/.llama/build/<image_type>/<name>-run.yaml",
)
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
self.parser.error(
"""
DEPRECATED! llama stack configure has been deprecated.
Please use llama stack run <path/to/run.yaml> instead.
Please see example run.yaml in /distributions folder.
"""
)


@ -74,10 +74,6 @@ class StackRun(Subcommand):
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import importlib.resources
import json
import subprocess
import yaml
from termcolor import cprint
@ -87,7 +83,7 @@ class StackRun(Subcommand):
BUILDS_BASE_DIR,
DISTRIBS_BASE_DIR,
)
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
if not args.config:
self.parser.error("Must specify a config file to run")
@ -125,64 +121,7 @@ class StackRun(Subcommand):
config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if args.image_type == ImageType.container.value or config.container_image:
script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
image_name = f"distribution-{template_name}" if template_name else config.container_image
run_args = [script, image_name]
elif args.image_type == ImageType.conda.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = args.image_name or current_conda_env
if not image_name:
cprint(
"No current conda environment detected, please specify a conda environment name with --image-name",
color="red",
)
return
def get_conda_prefix(env_name):
# Conda "base" environment does not end with "base" in the
# prefix, so should be handled separately.
if env_name == "base":
return os.environ.get("CONDA_PREFIX")
# Get conda environments info
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
envs = conda_env_info["envs"]
for envpath in envs:
if envpath.endswith(env_name):
return envpath
return None
print(f"Using conda environment: {image_name}")
conda_prefix = get_conda_prefix(image_name)
if not conda_prefix:
cprint(
f"Conda environment {image_name} does not exist.",
color="red",
)
return
build_file = Path(conda_prefix) / "llamastack-build.yaml"
if not build_file.exists():
cprint(
f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
color="red",
)
return
script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
run_args = [
script,
image_name,
]
else:
# else must be venv since that is the only valid option left.
current_venv = os.environ.get("VIRTUAL_ENV")
venv = args.image_name or current_venv
script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
run_args = [
script,
venv,
]
run_args = formulate_run_args(args.image_type, args.image_name, config, template_name)
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:
@ -206,5 +145,4 @@ class StackRun(Subcommand):
if args.tls_keyfile and args.tls_certfile:
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
run_with_pty(run_args)


@ -10,7 +10,6 @@ from importlib.metadata import version
from llama_stack.cli.subcommand import Subcommand
from .build import StackBuild
from .configure import StackConfigure
from .list_apis import StackListApis
from .list_providers import StackListProviders
from .run import StackRun
@ -37,7 +36,6 @@ class StackParser(Subcommand):
# Add sub-commands
StackBuild.create(subparsers)
StackConfigure.create(subparsers)
StackListApis.create(subparsers)
StackListProviders.create(subparsers)
StackRun.create(subparsers)


@ -7,7 +7,6 @@
import importlib.resources
import logging
import sys
from enum import Enum
from pathlib import Path
from typing import Dict, List
@ -18,6 +17,7 @@ from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
@ -33,12 +33,6 @@ SERVER_DEPENDENCIES = [
]
class ImageType(Enum):
container = "container"
conda = "conda"
venv = "venv"
class ApiInput(BaseModel):
api: Api
provider: str


@ -8,6 +8,8 @@
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
BUILD_PLATFORM=${BUILD_PLATFORM:-}
@ -32,7 +34,7 @@ container_base="$3"
build_file_path="$4"
host_build_dir="$5"
pip_dependencies="$6"
special_pip_deps="$7"
special_pip_deps="${7:-}"
# Define color codes
@ -106,26 +108,39 @@ fi
stack_mount="/app/llama-stack-source"
models_mount="/app/llama-models-source"
client_mount="/app/llama-stack-client-source"
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2
install_local_package() {
local dir="$1"
local mount_point="$2"
local name="$3"
if [ ! -d "$dir" ]; then
echo "${RED}Warning: $name is set but directory does not exist: $dir${NC}" >&2
exit 1
fi
# Install in editable format. We will mount the source code into the container
# so that changes will be reflected in the container without having to do a
# rebuild. This is just for development convenience.
if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
add_to_container << EOF
COPY $LLAMA_STACK_DIR $stack_mount
COPY $dir $mount_point
EOF
fi
add_to_container << EOF
RUN uv pip install --no-cache -e $stack_mount
RUN uv pip install --no-cache -e $mount_point
EOF
}
if [ -n "$LLAMA_MODELS_DIR" ]; then
install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
fi
if [ -n "$LLAMA_STACK_DIR" ]; then
install_local_package "$LLAMA_STACK_DIR" "$stack_mount" "LLAMA_STACK_DIR"
else
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
@ -134,6 +149,7 @@ RUN uv pip install fastapi libcst
EOF
add_to_container << EOF
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
EOF
@ -149,23 +165,6 @@ EOF
fi
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
exit 1
fi
if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
add_to_container << EOF
COPY $LLAMA_MODELS_DIR $models_mount
EOF
fi
add_to_container << EOF
RUN uv pip uninstall llama-models
RUN uv pip install --no-cache $models_mount
EOF
fi
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
if [[ "$template_or_config" != *.yaml ]]; then
add_to_container << EOF
@ -177,6 +176,15 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
EOF
fi
# Add other required commands generic to all containers
add_to_container << EOF
# Allows running as non-root user
RUN mkdir -p /.llama /.cache
RUN chmod -R g+rw /app /.llama /.cache
EOF
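# (The group-writable permissions above are what let the image run under an
# arbitrary non-root UID, e.g. on platforms that assign random user IDs.)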
printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n"
cat $TEMP_DIR/Containerfile
printf "\n"
@ -189,6 +197,9 @@ if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
fi
fi
if command -v selinuxenabled &>/dev/null && selinuxenabled; then

View file

@ -16,6 +16,7 @@ TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
# Reference: https://github.com/astral-sh/uv/pull/1694
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-}
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
@ -24,8 +25,8 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
fi
if [ "$#" -lt 3 ]; then
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies> [<special_pip_deps>]" >&2
if [ "$#" -lt 2 ]; then
echo "Usage: $0 <distribution_type> <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
exit 1
fi
@ -34,8 +35,7 @@ special_pip_deps="$3"
set -euo pipefail
build_name="$1"
env_name="llamastack-$build_name"
env_name="$1"
pip_dependencies="$2"
# Define color codes
@ -74,9 +74,13 @@ run() {
local env_name="$1"
local pip_dependencies="$2"
local special_pip_deps="$3"
if [ -n "$UV_SYSTEM_PYTHON" ]; then
if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
echo "Installing dependencies in system Python environment"
# if env == __system__, ensure we set UV_SYSTEM_PYTHON
export UV_SYSTEM_PYTHON=1
elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
echo "Virtual environment $env_name is already active"
else
echo "Using virtual environment $env_name"
uv venv "$env_name"
@ -90,6 +94,7 @@ run() {
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
$pip_dependencies
if [ -n "$special_pip_deps" ]; then

View file

@ -41,6 +41,7 @@ from llama_stack.distribution.stack import (
redact_sensitive_fields,
replace_env_vars,
)
from llama_stack.distribution.utils.exec import in_notebook
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
@ -52,19 +53,6 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
def in_notebook():
try:
from IPython import get_ipython
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
return False
except ImportError:
return False
except AttributeError:
return False
return True
def convert_pydantic_to_json_value(value: Any) -> Any:
if isinstance(value, Enum):
return value.value
@ -230,12 +218,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if Api.telemetry in self.impls:
setup_logger(self.impls[Api.telemetry])
console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
# Redact sensitive information before printing
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
if not os.environ.get("PYTEST_CURRENT_TEST"):
console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
endpoints = get_all_api_endpoints()
endpoint_impls = {}

View file

@ -6,7 +6,11 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
BenchmarkConfig,
@ -17,11 +21,13 @@ from llama_stack.apis.eval import (
)
from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -214,7 +220,10 @@ class InferenceRouter(Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
@ -224,6 +233,9 @@ class InferenceRouter(Inference):
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)

View file

@ -55,6 +55,7 @@ while [[ $# -gt 0 ]]; do
esac
done
echo "Using virtual environment: $venv_path"
# Activate virtual environment
if [ ! -d "$venv_path" ]; then
echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2

View file

@ -12,8 +12,78 @@ import signal
import subprocess
import sys
from termcolor import cprint
log = logging.getLogger(__name__)
import importlib
import json
from pathlib import Path
from llama_stack.distribution.utils.image_types import ImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list:
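    # Select the matching start_* script and argv prefix for the requested
    # image type (container / conda / venv); the CLI then appends the run
    # config path and port before handing the list to run_with_pty().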
if image_type == ImageType.container.value or config.container_image:
script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
image_name = f"distribution-{template_name}" if template_name else config.container_image
run_args = [script, image_name]
elif image_type == ImageType.conda.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = image_name or current_conda_env
if not image_name:
cprint(
"No current conda environment detected, please specify a conda environment name with --image-name",
color="red",
)
return
def get_conda_prefix(env_name):
# Conda "base" environment does not end with "base" in the
# prefix, so should be handled separately.
if env_name == "base":
return os.environ.get("CONDA_PREFIX")
# Get conda environments info
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
envs = conda_env_info["envs"]
for envpath in envs:
if envpath.endswith(env_name):
return envpath
return None
print(f"Using conda environment: {image_name}")
conda_prefix = get_conda_prefix(image_name)
if not conda_prefix:
cprint(
f"Conda environment {image_name} does not exist.",
color="red",
)
return
build_file = Path(conda_prefix) / "llamastack-build.yaml"
if not build_file.exists():
cprint(
f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
color="red",
)
return
script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
run_args = [
script,
image_name,
]
else:
# else must be venv since that is the only valid option left.
current_venv = os.environ.get("VIRTUAL_ENV")
venv = image_name or current_venv
script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
run_args = [
script,
venv,
]
return run_args
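# Illustrative result (hypothetical names): for image_type="conda" and
# image_name="mystack", run_args becomes
#   [<path to distribution/start_conda_env.sh>, "mystack"]
# before the caller appends the config file path and port.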
def run_with_pty(command):
if sys.platform.startswith("win"):
@ -22,6 +92,19 @@ def run_with_pty(command):
return _run_with_pty_unix(command)
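# Heuristic notebook detection: IPython kernels expose "IPKernelApp" in their
# config; plain shells and scripts do not (and may not have IPython at all).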
def in_notebook():
try:
from IPython import get_ipython
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
return False
except ImportError:
return False
except AttributeError:
return False
return True
# run a command in a pseudo-terminal, with interrupt handling,
# useful when you want to run interactive things
def _run_with_pty_unix(command):

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
class ImageType(Enum):
container = "container"
conda = "conda"
venv = "venv"
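# Callers typically compare CLI input against the enum's string values,
# e.g. `image_type == ImageType.conda.value` in formulate_run_args above.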

View file

@ -30,8 +30,10 @@ from llama_stack.apis.agents import (
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepStartPayload,
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest,
Attachment,
Document,
InferenceStep,
@ -60,9 +62,13 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
ToolParamDefinition,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
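        # Rebuild the model-visible history: the (optional) system
        # instructions followed by every prior turn's messages, in order.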
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for turn in turns:
messages.extend(self.turn_to_messages(turn))
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages = await self.get_messages_from_turns(turns)
messages.extend(request.messages)
turn_id = str(uuid.uuid4())
@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin):
)
await self.storage.add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
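        # When the model issued client tool calls and the caller opted in via
        # allow_turn_resume, emit an awaiting-input payload instead of marking
        # the turn complete; resume_turn() continues it once tool responses
        # arrive.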
if output_message.tool_calls and request.allow_turn_resume:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = await self.get_messages_from_turns(turns)
messages.extend(request.tool_responses)
            last_turn_messages = [
                x for x in messages if isinstance(x, (UserMessage, ToolResponseMessage))
            ]
# get the steps from the turn id
steps = []
if len(turns) > 0:
steps = turns[-1].steps
            # mark the tool execution step as complete; if there is no
            # in-progress tool execution step (e.g. due to storage, or
            # tool-call parsing on the client), create a new one stamped
            # with the current time
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now()
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=[
ToolResponse(
call_id=x.call_id,
tool_name=x.tool_name,
content=x.content,
)
for x in request.tool_responses
],
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=tool_execution_step.step_id,
step_details=tool_execution_step,
)
)
)
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=request.turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
assert output_message is not None
last_turn_start_time = datetime.now()
if len(turns) > 0:
last_turn_start_time = turns[-1].started_at
turn = Turn(
turn_id=request.turn_id,
session_id=request.session_id,
input_messages=last_turn_messages,
output_message=output_message,
started_at=last_turn_start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def run(
@ -456,6 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
call_id="",
tool_name=MEMORY_QUERY_TOOL,
content=retrieved_context or [],
metadata=result.metadata,
)
],
),
@ -611,11 +743,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message]
else:
log.info(f"{str(message)}")
tool_call = message.tool_calls[0]
if tool_call.tool_name in client_tools:
yield message
return
# 1. Start the tool execution step and progress
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -625,6 +753,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@ -639,6 +768,23 @@ class ChatAgent(ShieldRunnerMixin):
)
)
        # If the tool is a client tool, yield a CompletionMessage and return
if tool_call.tool_name in client_tools:
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[],
started_at=datetime.now(),
),
)
yield message
return
        # If the tool is a builtin server tool, execute it
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
@ -650,13 +796,21 @@ class ChatAgent(ShieldRunnerMixin):
},
) as span:
tool_execution_start_time = datetime.now()
result_messages = await execute_tool_call_maybe(
tool_call = message.tool_calls[0]
tool_result = await execute_tool_call_maybe(
self.tool_runtime_api,
session_id,
[message],
tool_call,
toolgroup_args,
tool_to_group,
)
result_messages = [
ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
)
]
assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
@ -675,6 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
call_id=result_message.call_id,
tool_name=result_message.tool_name,
content=result_message.content,
metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time,
@ -913,19 +1068,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
async def execute_tool_call_maybe(
tool_runtime_api: ToolRuntime,
session_id: str,
messages: List[CompletionMessage],
tool_call: ToolCall,
toolgroup_args: Dict[str, Dict[str, Any]],
tool_to_group: Dict[str, str],
) -> List[ToolResponseMessage]:
    # While the Tools.run interface takes a list of messages,
    # all tools currently only run on a single message.
    # When this changes, we can drop this assert.
    # Whether to call tools on each message and aggregate,
    # or aggregate and call the tool once, remains to be seen.
assert len(messages) == 1, "Expected single message"
message = messages[0]
tool_call = message.tool_calls[0]
) -> ToolInvocationResult:
name = tool_call.tool_name
group_name = tool_to_group.get(name, None)
if group_name is None:
@ -946,14 +1092,7 @@ async def execute_tool_call_maybe(
**tool_call_args,
),
)
return [
ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=result.content,
)
]
return result
def _interpret_content_as_attachment(

View file

@ -11,8 +11,6 @@ import tempfile
import uuid
from typing import AsyncGenerator, List, Optional, Union
from termcolor import colored
from llama_stack.apis.agents import (
AgentConfig,
AgentCreateResponse,
@ -21,6 +19,7 @@ from llama_stack.apis.agents import (
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
Session,
Turn,
@ -68,12 +67,7 @@ class MetaReferenceAgentsImpl(Agents):
# check if "bwrap" is available
if not shutil.which("bwrap"):
print(
colored(
"Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
"yellow",
)
)
logger.warning("Warning: `bwrap` is not available. Code interpreter tool will not work correctly.")
async def create_agent(
self,
@ -146,6 +140,7 @@ class MetaReferenceAgentsImpl(Agents):
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@ -155,6 +150,7 @@ class MetaReferenceAgentsImpl(Agents):
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
allow_turn_resume=allow_turn_resume,
)
if stream:
return self._create_agent_turn_streaming(request)
@ -169,6 +165,34 @@ class MetaReferenceAgentsImpl(Agents):
async for event in agent.create_and_execute_turn(request):
yield event
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponseMessage],
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
agent_id=agent_id,
session_id=session_id,
turn_id=turn_id,
tool_responses=tool_responses,
stream=stream,
)
if stream:
return self._continue_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _continue_agent_turn_streaming(
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)

View file

@ -12,7 +12,7 @@ from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@ -84,3 +84,15 @@ class AgentPersistence:
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
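    # The two helpers below let a turn pause on a client tool call: the agent
    # records the in-progress ToolExecutionStep before yielding, and
    # resume_turn() reads it back to stitch the client's tool responses into
    # the same step.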
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None

View file

@ -44,7 +44,6 @@ class SentenceTransformersInferenceImpl(
pass
async def register_model(self, model: Model) -> None:
_ = self._load_sentence_transformer_model(model.provider_resource_id)
return model
async def unregister_model(self, model_id: str) -> None:

View file

@ -22,11 +22,14 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -230,5 +233,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -119,10 +119,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
chunks = chunks[: query_config.max_chunks]
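        # Greedily keep the highest-scoring chunks until the token budget
        # (query_config.max_tokens_in_context) is exhausted.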
tokens = 0
picked = []
for c in chunks[: query_config.max_chunks]:
for c in chunks:
metadata = c.metadata
tokens += metadata["token_count"]
if tokens > query_config.max_tokens_in_context:
@ -146,6 +146,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
text="\n=== END-RETRIEVED-CONTEXT ===\n",
),
],
metadata={
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
},
)
async def list_runtime_tools(

View file

@ -4,26 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# config.py
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
class SQLiteVectorIOConfig(BaseModel):
db_path: str
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="sqlite_vec.db",
)
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + "sqlite_vec.db",
}

View file

@ -61,7 +61,10 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.inference,
provider_type="inline::sentence-transformers",
pip_packages=["sentence-transformers"],
pip_packages=[
"torch torchvision --index-url https://download.pytorch.org/whl/cpu",
"sentence-transformers --no-deps",
],
module="llama_stack.providers.inline.inference.sentence_transformers",
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
),

View file

@ -20,7 +20,18 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=[],
pip_packages=[
"blobfile",
"chardet",
"pypdf",
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
],
module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
api_dependencies=[Api.vector_io, Api.inference],

View file

@ -14,33 +14,13 @@ from llama_stack.providers.datatypes import (
remote_provider_spec,
)
EMBEDDING_DEPS = [
"blobfile",
"chardet",
"pypdf",
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
# this happens to work because special dependencies are always installed last
# so if there was a regular torch installed first, this would be ignored
# we need a better way to do this to identify potential conflicts, etc.
# for now, this lets us significantly reduce the size of the container which
# does not have any "local" inference code (and hence does not need GPU-enabled torch)
"torch torchvision --index-url https://download.pytorch.org/whl/cpu",
"sentence-transformers --no-deps",
]
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::meta-reference",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
pip_packages=["faiss-cpu"],
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
@ -49,24 +29,33 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::faiss",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
pip_packages=["faiss-cpu"],
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference],
),
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite_vec",
pip_packages=EMBEDDING_DEPS + ["sqlite-vec"],
provider_type="inline::sqlite-vec",
pip_packages=["sqlite-vec"],
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
api_dependencies=[Api.inference],
),
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite_vec",
pip_packages=["sqlite-vec"],
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
pip_packages=["chromadb-client"],
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
),
@ -75,7 +64,7 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb"],
pip_packages=["chromadb"],
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
@ -84,7 +73,7 @@ def available_providers() -> List[ProviderSpec]:
Api.vector_io,
AdapterSpec(
adapter_type="pgvector",
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
pip_packages=["psycopg2-binary"],
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
),
@ -94,7 +83,7 @@ def available_providers() -> List[ProviderSpec]:
Api.vector_io,
AdapterSpec(
adapter_type="weaviate",
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
pip_packages=["weaviate-client"],
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
@ -115,7 +104,7 @@ def available_providers() -> List[ProviderSpec]:
Api.vector_io,
AdapterSpec(
adapter_type="qdrant",
pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
pip_packages=["qdrant-client"],
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
),

View file

@ -9,17 +9,22 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from botocore.client import BaseClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -162,7 +167,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []

View file

@ -8,17 +8,22 @@ from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -172,6 +177,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -8,16 +8,21 @@ from typing import AsyncGenerator, List, Optional
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
@ -130,7 +135,10 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model: str,
contents: List[InterleavedContent],
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -8,19 +8,24 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -204,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
if media_present:
if media_present or not llama_model:
input_dict["messages"] = [
await convert_message_to_openai_dict(m, download=True) for m in request.messages
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model)
)
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Fireworks does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
@ -232,7 +236,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

View file

@ -17,16 +17,23 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
InterleavedContent,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.models.llama.datatypes import (
SamplingParams,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
@ -140,7 +147,10 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
@ -46,6 +48,14 @@ _MODEL_ENTRIES = [
"meta/llama-3.2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
ProviderModelEntry(
provider_model_id="baai/bge-m3",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 1024,
"context_length": 8192,
},
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]

View file

@ -10,6 +10,11 @@ from typing import AsyncIterator, List, Optional, Union
from openai import APIConnectionError, AsyncOpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@ -18,15 +23,20 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.models.llama.datatypes import (
SamplingParams,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -117,9 +127,41 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
if any(content_has_media(content) for content in contents):
raise NotImplementedError("Media is not supported")
#
# Llama Stack: contents = List[str] | List[InterleavedContentItem]
# ->
# OpenAI: input = str | List[str]
#
# we can ignore str and always pass List[str] to OpenAI
#
flat_contents = [
item.text if isinstance(item, TextContentItem) else item
for content in contents
for item in (content if isinstance(content, list) else [content])
]
        input_texts = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
        model = self.get_provider_model_id(model_id)
        response = await self._client.embeddings.create(
            model=model,
            input=input_texts,
            # extra_body={"input_type": "passage"|"query"},  # TODO(mf): how to tell caller's intent?
        )
#
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
# ->
# Llama Stack: EmbeddingsResponse(embeddings=List[List[float]])
#
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
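        # Hypothetical usage sketch (model id from the bge-m3 entry added
        # above; assumes the adapter is registered and the endpoint reachable):
        #   resp = await adapter.embeddings("baai/bge-m3", ["what is a llama?"])
        #   assert len(resp.embeddings[0]) == 1024  # embedding_dimension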
async def chat_completion(
self,

View file

@ -13,6 +13,7 @@ from ollama import AsyncClient
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
@ -20,11 +21,13 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -175,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict = {}
media_present = request_has_media(request)
llama_model = self.register_helper.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
if media_present:
if media_present or not llama_model:
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
# flatten the list of lists
input_dict["messages"] = [item for sublist in contents for item in sublist]
@ -184,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["raw"] = True
input_dict["prompt"] = await chat_completion_request_to_prompt(
request,
self.register_helper.get_llama_model(request.model),
llama_model,
)
else:
assert not media_present, "Ollama does not support media for Completion requests"
@ -258,7 +262,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
@ -274,7 +281,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.model_type == ModelType.embedding:
log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
await self.client.pull(model.provider_resource_id)
response = await self.client.list()
else:
response = await self.client.ps()
@ -284,7 +294,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
)
return await self.register_helper.register_model(model)
return model
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:

View file

@ -11,11 +11,13 @@ from llama_stack_client import LlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -138,6 +140,9 @@ class PassthroughInferenceAdapter(Inference):
self,
model_id: str,
contents: List[InterleavedContent],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
client = self._get_client()
model = await self.model_store.get_model(model_id)
@ -145,4 +150,7 @@ class PassthroughInferenceAdapter(Inference):
return client.inference.embeddings(
model_id=model.provider_resource_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)

View file

@ -69,9 +69,10 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
@ -119,6 +120,9 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -5,16 +5,38 @@
# the root directory of this source tree.
import json
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional
from openai import OpenAI
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
StopReason,
SystemMessage,
TextTruncation,
ToolCall,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy,
@ -119,7 +141,10 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -10,18 +10,23 @@ from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -268,7 +273,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -8,18 +8,23 @@ from typing import AsyncGenerator, List, Optional, Union
from together import Together
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -198,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
if media_present:
if media_present or not llama_model:
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model)
)
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
@ -219,7 +223,10 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(not content_has_media(content) for content in contents), (

View file

@ -10,7 +10,13 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import StopReason, ToolCall
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@ -22,18 +28,21 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
@ -112,10 +121,16 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
if tool_param.required:
compat_required.append(tool_key)
# The tool.tool_name can be a str or a BuiltinTool enum. If
# it's the latter, convert to a string.
tool_name = tool.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
compat_tool = {
"type": "function",
"function": {
"name": tool.tool_name,
"name": tool_name,
"description": tool.description,
"parameters": {
"type": "object",
@ -369,7 +384,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

View file

@ -8,7 +8,10 @@ import unittest
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionMessage,
StopReason,
SystemMessage,
ToolCall,
ToolConfig,
UserMessage,
)
@ -20,6 +23,7 @@ from llama_stack.models.llama.datatypes import (
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
chat_completion_request_to_prompt,
)
MODEL = "Llama3.1-8B-Instruct"
@ -119,6 +123,46 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
self.assertTrue("Return function calls in JSON format" in messages[1].content)
self.assertEqual(messages[-1].content, content)
async def test_completion_message_encoding(self):
request = ChatCompletionRequest(
model=MODEL3_2,
messages=[
UserMessage(content="hello"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
tool_name="custom1",
arguments={"param1": "value1"},
call_id="123",
)
],
),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
)
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn('[custom1(param1="value1")]', prompt)
request.model = MODEL
request.tool_config.tool_prompt_format = ToolPromptFormat.json
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
async def test_user_provided_system_message(self):
content = "Hello !"
system_prompt = "You are a pirate"

View file

@ -61,7 +61,7 @@ def vector_io_sqlite_vec() -> ProviderFixture:
providers=[
Provider(
provider_id="sqlite_vec",
provider_type="inline::sqlite_vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig(
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
).model_dump(),

View file

@ -5,13 +5,16 @@
# the root directory of this source tree.
import logging
from typing import List
from typing import List, Optional
from llama_stack.apis.inference import (
EmbeddingsResponse,
InterleavedContent,
EmbeddingTaskType,
InterleavedContentItem,
ModelStore,
TextTruncation,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
EMBEDDING_MODELS = {}
@ -25,11 +28,16 @@ class SentenceTransformerEmbeddingMixin:
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
embeddings = embedding_model.encode(contents)
embeddings = embedding_model.encode(
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
)
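        # encode() yields one embedding vector per flattened input string.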
return EmbeddingsResponse(embeddings=embeddings)
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":

View file

@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
if provider_resource_id:
model.provider_resource_id = provider_resource_id
else:
if model.metadata.get("llama_model") is None:
raise ValueError(
f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
"Please specify a llama_model in metadata or use a supported model identifier"
)
llama_model = model.metadata.get("llama_model")
if llama_model is None:
return model
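        # A llama_model alias was supplied: ensure it does not conflict with
        # an existing registration and maps to a known HuggingFace repo.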
existing_llama_model = self.get_llama_model(model.provider_resource_id)
if existing_llama_model:
if existing_llama_model != model.metadata["llama_model"]:
if existing_llama_model != llama_model:
raise ValueError(
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
)
else:
if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
raise ValueError(
f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
f"Invalid llama_model '{llama_model}' specified in metadata. "
f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
)
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
return model

View file

@ -252,7 +252,9 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(request.messages)
model_input = formatter.encode_dialog_prompt(
request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
)
return formatter.tokenizer.decode(model_input.tokens)
@ -264,7 +266,9 @@ async def chat_completion_request_to_model_input_info(
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(request.messages)
model_input = formatter.encode_dialog_prompt(
request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
)
return (
formatter.tokenizer.decode(model_input.tokens),
len(model_input.tokens),

View file

@ -5,12 +5,10 @@
# the root directory of this source tree.
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, TypeVar
from typing import Any, Callable, List, Optional, Protocol, TypeVar
from .strong_typing.schema import json_schema_type, register_schema # noqa: F401
T = TypeVar("T")
@dataclass
class WebMethod:
@ -22,6 +20,13 @@ class WebMethod:
raw_bytes_request_body: Optional[bool] = False
class HasWebMethod(Protocol):
__webmethod__: WebMethod
T = TypeVar("T", bound=HasWebMethod) # Bound T to classes that match this protocol
def webmethod(
route: Optional[str] = None,
method: Optional[str] = None,

View file

@ -11,7 +11,7 @@ import subprocess
import sys
from functools import partial
from pathlib import Path
from typing import Iterator
from typing import Iterable
from rich.progress import Progress, SpinnerColumn, TextColumn
@ -39,7 +39,7 @@ class ChangedPathTracker:
return self._changed_paths
def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
def find_template_dirs(templates_dir: Path) -> Iterable[Path]:
"""Find immediate subdirectories in the templates folder."""
if not templates_dir.exists():
raise FileNotFoundError(f"Templates directory not found: {templates_dir}")
@ -90,7 +90,7 @@ def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
return has_changes
def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[str]]:
try:
module_name = f"llama_stack.templates.{template_dir.name}"
module = importlib.import_module(module_name)
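Both annotation changes are about honesty at the boundaries: `Iterable` admits returning a concrete list as well as a generator, and `str | None` records the failure branch. A minimal sketch under those assumptions:

```python
# Minimal sketch of why these annotations fit; the function bodies are
# simplified stand-ins for the template scripts.
import importlib
from pathlib import Path
from typing import Iterable

def find_dirs(root: Path) -> Iterable[Path]:
    # Iterable (not Iterator) also admits returning a concrete list,
    # which sorted() produces here.
    return sorted(p for p in root.iterdir() if p.is_dir())

def collect_deps(module_name: str) -> tuple[str | None, list[str]]:
    try:
        module = importlib.import_module(module_name)
        return module_name, list(getattr(module, "__all__", []))
    except ImportError:
        # This branch is why the first tuple element is typed str | None.
        return None, []
```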

View file

@ -52,7 +52,7 @@ def main(parser: argparse.ArgumentParser):
pytest_args,
"-s",
"-v",
REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH,
str(REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH),
]
)
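The cast matters because `pytest.main` expects its argument list as strings, and a bare `Path` can trip argument parsing in some versions. A short sketch; `repo_root` and the test path are illustrative stand-ins:

```python
# Why the str(...) cast: argument lists passed to pytest.main() must be
# strings; a bare Path can break argument parsing in some pytest versions.
from pathlib import Path
import pytest  # assumed available in the dev environment

repo_root = Path(".")  # illustrative stand-in for REPO_ROOT
pytest.main(["-s", "-v", str(repo_root / "tests" / "client-sdk")])
```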

View file

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- remote::cerebras
- inline::sentence-transformers
safety:
- inline::llama-guard
vector_io:
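This addition — repeated below for the tgi, fireworks, hf::serverless, and vllm templates — gives each remote-inference distribution a local embedding provider. A hedged sketch of what that enables once the stack is running; the client package, port, and model id are assumptions about a typical local deployment:

```python
# Hedged sketch: embeddings served locally by inline::sentence-transformers
# while chat completions go to the remote provider. The base_url port and
# model id are assumptions, not values from this diff.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")
response = client.inference.embeddings(
    model_id="all-MiniLM-L6-v2",  # default sentence-transformers model
    contents=["What is retrieval-augmented generation?"],
)
print(len(response.embeddings[0]))  # 384 dimensions for MiniLM-L6-v2
```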

View file

@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::cerebras"],
"inference": ["remote::cerebras", "inline::sentence-transformers"],
"safety": ["inline::llama-guard"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"agents": ["inline::meta-reference"],

View file

@ -5,6 +5,7 @@ distribution_spec:
providers:
inference:
- remote::tgi
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::tgi"],
"inference": ["remote::tgi", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

View file

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- remote::fireworks
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -25,7 +25,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::fireworks"],
"inference": ["remote::fireworks", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

View file

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- remote::hf::serverless
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -21,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::hf::serverless"],
"inference": ["remote::hf::serverless", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

View file

@ -55,9 +55,11 @@ def get_distribution_template() -> DistributionTemplate:
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model],
model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
provider_model_id=m.provider_model_id,
provider_id="nvidia",
model_type=m.model_type,
metadata=m.metadata,
)
for m in _MODEL_ENTRIES
]
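The fallback handles embedding entries, which carry no `llama_model` and so use their provider id directly. A runnable sketch of the same comprehension; `Entry` is a simplified stand-in for the real model-entry type:

```python
# Standalone sketch of the fallback above; Entry is a simplified stand-in
# for a provider model entry, and the repo map is illustrative.
from dataclasses import dataclass, field

@dataclass
class Entry:
    provider_model_id: str
    llama_model: str | None = None
    model_type: str = "llm"
    metadata: dict = field(default_factory=dict)

repo = {"Llama3.1-8B": "meta-llama/Llama-3.1-8B-Instruct"}
entries = [
    Entry("meta/llama-3.1-8b-instruct", llama_model="Llama3.1-8B"),
    Entry("baai/bge-m3", model_type="embedding", metadata={"embedding_dimension": 1024}),
]
ids = [repo[e.llama_model] if e.llama_model else e.provider_model_id for e in entries]
assert ids == ["meta-llama/Llama-3.1-8B-Instruct", "baai/bge-m3"]
```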

View file

@ -135,6 +135,13 @@ models:
provider_id: nvidia
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata:
embedding_dimension: 1024
context_length: 8192
model_id: baai/bge-m3
provider_id: nvidia
provider_model_id: baai/bge-m3
model_type: embedding
shields: []
vector_dbs: []
datasets: []
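The `embedding_dimension` declared here is what downstream vector-DB registration has to match. A hedged sketch — the client construction and keyword names are assumptions:

```python
# Hedged sketch: registering a vector DB against the embedding model above.
# Client construction and keyword names are assumptions.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")
client.vector_dbs.register(
    vector_db_id="nvidia-docs",
    embedding_model="baai/bge-m3",
    embedding_dimension=1024,  # must agree with the metadata declared above
)
```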

View file

@ -5,7 +5,7 @@ distribution_spec:
inference:
- remote::ollama
vector_io:
- inline::faiss
- inline::sqlite-vec
- remote::chromadb
- remote::pgvector
safety:

View file

@ -13,10 +13,6 @@ from llama_stack.distribution.datatypes import (
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -25,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::ollama"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
@ -45,19 +41,9 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::ollama",
config=OllamaImplConfig.sample_run_config(),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
vector_io_provider_faiss = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
)
vector_io_provider_sqlite = Provider(
provider_id="sqlite_vec",
provider_type="inline::sqlite_vec",
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
)
@ -104,19 +90,16 @@ def get_distribution_template() -> DistributionTemplate:
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
"vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite],
"inference": [inference_provider],
"vector_io": [vector_io_provider_sqlite],
},
default_models=[inference_model, embedding_model],
default_models=[inference_model],
default_tool_groups=default_tool_groups,
),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
embedding_provider,
],
"vector_io": [vector_io_provider_faiss, vector_io_provider_faiss],
"inference": [inference_provider],
"vector_io": [vector_io_provider_sqlite],
"safety": [
Provider(
provider_id="llama-guard",

View file

@ -16,24 +16,11 @@ providers:
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:http://localhost:11434}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -16,24 +16,11 @@ providers:
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:http://localhost:11434}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
- provider_id: sqlite_vec
provider_type: inline::sqlite_vec
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -100,11 +87,6 @@ models:
model_id: ${env.INFERENCE_MODEL}
provider_id: ollama
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields: []
vector_dbs: []
datasets: []

View file

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- remote::vllm
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::vllm"],
"inference": ["remote::vllm", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

View file

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- remote::tgi
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::tgi"],
"inference": ["remote::tgi", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

Some files were not shown because too many files have changed in this diff.