Merge branch 'meta-llama:main' into main

Commit b41afa5843 by Jamie Land, 2025-02-25 10:58:53 -05:00, committed by GitHub
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
106 changed files with 2223 additions and 853 deletions

View file

@ -30,6 +30,7 @@ repos:
rev: v0.9.4
hooks:
- id: ruff
args: [ --fix ]
exclude: ^llama_stack/strong_typing/.*$
- id: ruff-format
@ -45,23 +46,26 @@ repos:
hooks:
- id: uv-export
args: [
"--frozen",
"--no-hashes",
"--no-emit-project",
"--frozen",
"--no-hashes",
"--no-emit-project",
"--output-file=requirements.txt"
]
files: ^pyproject\.toml$
- id: uv-sync
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.14.0
# hooks:
# - id: mypy
# additional_dependencies:
# - types-requests
# - types-setuptools
# - pydantic
# args: [--ignore-missing-imports]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
hooks:
- id: mypy
additional_dependencies:
- uv==0.6.2
- mypy
- pytest
- rich
- types-requests
- pydantic
pass_filenames: false
# - repo: https://github.com/jsh9/pydoclint
# rev: d88180a8632bb1602a4d81344085cf320f288c5a

.python-version (new file, 1 line)
View file

@ -0,0 +1 @@
3.10

View file

@ -134,9 +134,11 @@ If you are making changes to the documentation at [https://llama-stack.readthedo
$ cd llama-stack/docs
$ uv sync --extra docs
# This rebuilds the documentation pages.
$ uv run make html
# This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
$ make html
$ uv run sphinx-autobuild source build/html
$ uv run sphinx-autobuild source build/html --write-all
```
### Update API Documentation

View file

@ -3,3 +3,4 @@ include distributions/dependencies.json
include llama_stack/distribution/*.sh
include llama_stack/cli/scripts/*.sh
include llama_stack/templates/*/*.yaml
include llama_stack/providers/tests/test_cases/*.json

View file

@ -30,9 +30,7 @@
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"uvicorn"
],
"cerebras": [
"aiosqlite",
@ -68,6 +66,41 @@
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"ci-tests": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"dell": [
"aiohttp",
"aiosqlite",
@ -170,9 +203,7 @@
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"uvicorn"
],
"hf-serverless": [
"aiohttp",
@ -247,9 +278,7 @@
"tqdm",
"transformers",
"uvicorn",
"zmq",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"zmq"
],
"meta-reference-quantized-gpu": [
"accelerate",
@ -290,9 +319,7 @@
"tqdm",
"transformers",
"uvicorn",
"zmq",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"zmq"
],
"nvidia": [
"aiosqlite",
@ -323,9 +350,7 @@
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"uvicorn"
],
"ollama": [
"aiohttp",
@ -335,7 +360,6 @@
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
@ -356,11 +380,10 @@
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"uvicorn"
],
"remote-vllm": [
"aiosqlite",
@ -423,9 +446,7 @@
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
"uvicorn"
],
"tgi": [
"aiohttp",

View file

@ -2315,6 +2315,70 @@
}
}
},
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
"post": {
"responses": {
"200": {
"description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Turn"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
}
}
}
}
},
"tags": [
"Agents"
],
"description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.",
"parameters": [
{
"name": "agent_id",
"in": "path",
"description": "The ID of the agent to resume.",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "session_id",
"in": "path",
"description": "The ID of the session to resume.",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "turn_id",
"in": "path",
"description": "The ID of the turn to resume.",
"required": true,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ResumeAgentTurnRequest"
}
}
},
"required": true
}
}
},
"/v1/eval/benchmarks/{benchmark_id}/jobs": {
"post": {
"responses": {
@ -4226,6 +4290,9 @@
},
"tool_config": {
"$ref": "#/components/schemas/ToolConfig"
},
"allow_turn_resume": {
"type": "boolean"
}
},
"additionalProperties": false,
@ -4454,6 +4521,31 @@
},
"content": {
"$ref": "#/components/schemas/InterleavedContent"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
@ -4612,6 +4704,9 @@
},
{
"$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
},
{
"$ref": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
}
],
"discriminator": {
@ -4621,7 +4716,8 @@
"step_progress": "#/components/schemas/AgentTurnResponseStepProgressPayload",
"step_complete": "#/components/schemas/AgentTurnResponseStepCompletePayload",
"turn_start": "#/components/schemas/AgentTurnResponseTurnStartPayload",
"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload",
"turn_awaiting_input": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
}
}
},
@ -4784,6 +4880,25 @@
"title": "AgentTurnResponseStreamChunk",
"description": "streamed agent turn completion response."
},
"AgentTurnResponseTurnAwaitingInputPayload": {
"type": "object",
"properties": {
"event_type": {
"type": "string",
"const": "turn_awaiting_input",
"default": "turn_awaiting_input"
},
"turn": {
"$ref": "#/components/schemas/Turn"
}
},
"additionalProperties": false,
"required": [
"event_type",
"turn"
],
"title": "AgentTurnResponseTurnAwaitingInputPayload"
},
"AgentTurnResponseTurnCompletePayload": {
"type": "object",
"properties": {
@ -6656,6 +6771,31 @@
},
"error_code": {
"type": "integer"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
@ -7505,9 +7645,37 @@
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"metadata"
],
"title": "RAGQueryResult"
},
"QueryChunksRequest": {
@ -8046,6 +8214,27 @@
],
"title": "RegisterVectorDbRequest"
},
"ResumeAgentTurnRequest": {
"type": "object",
"properties": {
"tool_responses": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolResponseMessage"
},
"description": "The tool call responses to resume the turn with."
},
"stream": {
"type": "boolean",
"description": "Whether to stream the response."
}
},
"additionalProperties": false,
"required": [
"tool_responses"
],
"title": "ResumeAgentTurnRequest"
},
"RunEvalRequest": {
"type": "object",
"properties": {

View file

@ -1401,6 +1401,53 @@ paths:
schema:
$ref: '#/components/schemas/QueryTracesRequest'
required: true
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
post:
responses:
'200':
description: >-
A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk
objects.
content:
application/json:
schema:
$ref: '#/components/schemas/Turn'
text/event-stream:
schema:
$ref: '#/components/schemas/AgentTurnResponseStreamChunk'
tags:
- Agents
description: >-
Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client
side tool calls, this endpoint can be used to submit the outputs from the
tool calls once they are ready.
parameters:
- name: agent_id
in: path
description: The ID of the agent to resume.
required: true
schema:
type: string
- name: session_id
in: path
description: The ID of the session to resume.
required: true
schema:
type: string
- name: turn_id
in: path
description: The ID of the turn to resume.
required: true
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/ResumeAgentTurnRequest'
required: true
/v1/eval/benchmarks/{benchmark_id}/jobs:
post:
responses:
@ -2740,6 +2787,8 @@ components:
$ref: '#/components/schemas/AgentTool'
tool_config:
$ref: '#/components/schemas/ToolConfig'
allow_turn_resume:
type: boolean
additionalProperties: false
required:
- messages
@ -2896,6 +2945,16 @@ components:
- type: string
content:
$ref: '#/components/schemas/InterleavedContent'
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- call_id
@ -2992,6 +3051,7 @@ components:
- $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload'
- $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload'
- $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
- $ref: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
discriminator:
propertyName: event_type
mapping:
@ -3000,6 +3060,7 @@ components:
step_complete: '#/components/schemas/AgentTurnResponseStepCompletePayload'
turn_start: '#/components/schemas/AgentTurnResponseTurnStartPayload'
turn_complete: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
turn_awaiting_input: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
AgentTurnResponseStepCompletePayload:
type: object
properties:
@ -3106,6 +3167,21 @@ components:
- event
title: AgentTurnResponseStreamChunk
description: streamed agent turn completion response.
"AgentTurnResponseTurnAwaitingInputPayload":
type: object
properties:
event_type:
type: string
const: turn_awaiting_input
default: turn_awaiting_input
turn:
$ref: '#/components/schemas/Turn'
additionalProperties: false
required:
- event_type
- turn
title: >-
AgentTurnResponseTurnAwaitingInputPayload
AgentTurnResponseTurnCompletePayload:
type: object
properties:
@ -4315,6 +4391,16 @@ components:
type: string
error_code:
type: integer
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- content
@ -4888,7 +4974,19 @@ components:
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- metadata
title: RAGQueryResult
QueryChunksRequest:
type: object
@ -5205,6 +5303,22 @@ components:
- vector_db_id
- embedding_model
title: RegisterVectorDbRequest
ResumeAgentTurnRequest:
type: object
properties:
tool_responses:
type: array
items:
$ref: '#/components/schemas/ToolResponseMessage'
description: >-
The tool call responses to resume the turn with.
stream:
type: boolean
description: Whether to stream the response.
additionalProperties: false
required:
- tool_responses
title: ResumeAgentTurnRequest
RunEvalRequest:
type: object
properties:

View file

@ -86,10 +86,8 @@
"# NBVAL_SKIP\n",
"\n",
"!apt-get install -y bubblewrap\n",
"import os\n",
"os.environ[\"UV_SYSTEM_PYTHON\"] = \"1\"\n",
"!pip install uv\n",
"!uv pip install llama-stack"
"!uv pip install llama-stack --system"
]
},
{
@ -128,7 +126,7 @@
"source": [
"# NBVAL_SKIP\n",
"# This will build all the dependencies you will need\n",
"!llama stack build --template together --image-type venv"
"!llama stack build --template together --image-type venv --image-name __system__"
]
},
{
@ -3632,7 +3630,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "master",
"display_name": "toolchain",
"language": "python",
"name": "python3"
},

View file

@ -311,7 +311,7 @@
],
"source": [
"# NBVAL_SKIP\n",
"!llama stack build --template together --image-type venv"
"!llama stack build --template together --image-type venv --image-name __system__"
]
},
{

View file

@ -25,7 +25,7 @@ We are working on adding a few more APIs to complete the application lifecycle.
## API Providers
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, etc.),
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
@ -33,7 +33,7 @@ Providers come in two flavors:
- **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
- **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.
Most importantly, Llama Stack always strives to provide at least one fully "local" provider for each API so you can iterate on a fully featured environment locally.
Most importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
## Resources
Some of these APIs are associated with a set of **Resources**. Here is the mapping of APIs to resources:

View file

@ -15,7 +15,7 @@
from docutils import nodes
project = "llama-stack"
copyright = "2024, Meta"
copyright = "2025, Meta"
author = "Meta"
# -- General configuration ---------------------------------------------------

View file

@ -17,7 +17,7 @@ Which templates / distributions to choose depends on the hardware you have for r
- {dockerhub}`distribution-nvidia` ([Guide](self_hosted_distro/nvidia))
- **Are you running on a "regular" desktop or laptop?** We suggest using the ollama template for quick prototyping and getting started without having to worry about needing GPUs.
- {dockerhub}`distribution-ollama` ([link](self_hosted_distro/ollama))
- {dockerhub}`distribution-ollama` ([Guide](self_hosted_distro/ollama))
- **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest:
- {dockerhub}`distribution-together` ([Guide](self_hosted_distro/together))
@ -28,7 +28,7 @@ Which templates / distributions to choose depends on the hardware you have for r
- [Android](ondevice_distro/android_sdk)
- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro).**
- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro.md).**
### Distribution Details

View file

@ -8,7 +8,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::cerebras` |
| inference | `remote::cerebras`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |

View file

@ -19,7 +19,7 @@ The `llamastack/distribution-dell` distribution consists of the following provid
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::tgi` |
| inference | `remote::tgi`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |

View file

@ -18,7 +18,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::fireworks` |
| inference | `remote::fireworks`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |

View file

@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.

View file

@ -17,7 +17,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::vllm` |
| inference | `remote::vllm`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |

View file

@ -19,7 +19,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::tgi` |
| inference | `remote::tgi`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |

View file

@ -18,7 +18,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::together` |
| inference | `remote::together`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |

View file

@ -67,6 +67,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
| **Provider** | **Environments** |
| :----: | :----: |
| FAISS | Single Node |
| SQLite-Vec| Single Node |
| Chroma | Hosted and Single Node |
| Postgres (PGVector) | Hosted and Single Node |
| Weaviate | Hosted |
@ -88,6 +89,7 @@ self
introduction/index
getting_started/index
concepts/index
providers/index
distributions/index
distributions/selection
building_applications/index

View file

@ -0,0 +1,59 @@
# Providers Overview
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
Providers come in two flavors:
- **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
- **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.
Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
## Agents
Runs multi-step agentic workflows with LLMs, including tool usage, memory (RAG), etc.
## DatasetIO
Interfaces with datasets and data loaders.
## Eval
Generates outputs (via Inference or Agents) and performs scoring.
## Inference
Runs inference with an LLM.
## Post Training
Fine-tunes a model.
## Safety
Applies safety policies to outputs at a system (not just model) level.
## Scoring
Evaluates the outputs of the system.
## Telemetry
Collects telemetry data from the system.
## Tool Runtime
Is associated with ToolGroup resources.
## Vector IO
Vector IO refers to operations on vector databases, such as adding documents, searching, and deleting documents.
Vector IO plays a crucial role in [Retrieval Augmented Generation (RAG)](../../building_applications/rag), where the vector
database is used to store and retrieve documents.
#### Vector IO Providers
The following providers (i.e., databases) are available for Vector IO:
```{toctree}
:maxdepth: 1
vector_io/faiss
vector_io/sqlite-vec
vector_io/chromadb
vector_io/pgvector
vector_io/qdrant
vector_io/weaviate
```

View file

@ -0,0 +1,36 @@
---
orphan: true
---
# Chroma
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
That means you're not limited to storing vectors in memory or in a separate service.
## Features
Chroma supports:
- Storing embeddings and their metadata
- Vector search
- Full-text search
- Document storage
- Metadata filtering
- Multi-modal retrieval
## Usage
To use Chroma in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Chroma.
3. Start storing and querying vectors.
## Installation
You can install Chroma using pip:
```bash
pip install chromadb
```
## Documentation
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
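
For step 2 above, here is a minimal sketch of what the `vector_io` provider entry could look like in a distribution's `run.yaml`. The `remote::chromadb` provider type matches the distribution tables in this change; the config layout and the `url` key are assumptions and may differ between releases.

```yaml
# Illustrative run.yaml fragment (not part of this change): route vector_io to a
# remote Chroma server. The `url` config key is an assumed field name.
providers:
  vector_io:
  - provider_id: chromadb
    provider_type: remote::chromadb
    config:
      url: http://localhost:8000   # address of your running Chroma server
```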

View file

@ -0,0 +1,33 @@
---
orphan: true
---
# Faiss
[Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.
That means you'll get fast and efficient vector retrieval.
## Features
- Lightweight and easy to use
- Fully integrated with Llama Stack
- GPU support
## Usage
To use Faiss in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Faiss.
3. Start storing and querying vectors.
## Installation
You can install Faiss using pip:
```bash
pip install faiss-cpu
```
## Documentation
See [Faiss' documentation](https://faiss.ai/) or the [Faiss Wiki](https://github.com/facebookresearch/faiss/wiki) for
more details about Faiss in general.

View file

@ -0,0 +1,31 @@
---
orphan: true
---
# Postgres PGVector
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Postgres database.
That means your vectors live alongside the rest of your relational data rather than in memory.
## Features
- Easy to use
- Fully integrated with Llama Stack
## Usage
To use PGVector in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use PGVector.
3. Start storing and querying vectors.
## Installation
You can install PGVector using docker:
```bash
docker pull pgvector/pgvector:pg17
```
## Documentation
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.

View file

@ -0,0 +1,31 @@
---
orphan: true
---
# Qdrant
[Qdrant](https://qdrant.tech/documentation/) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors in a Qdrant instance running as a separate service.
That means you'll get fast and efficient vector retrieval.
## Features
- Easy to use
- Fully integrated with Llama Stack
## Usage
To use Qdrant in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Qdrant.
3. Start storing and querying vectors.
## Installation
You can install Qdrant using docker:
```bash
docker pull qdrant/qdrant
```
## Documentation
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.

View file

@ -0,0 +1,33 @@
---
orphan: true
---
# SQLite-Vec
[SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly within an SQLite database.
That means you're not limited to storing vectors in memory or in a separate service.
## Features
- Lightweight and easy to use
- Fully integrated with Llama Stack
## Usage
To use SQLite-Vec in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use SQLite-Vec.
3. Start storing and querying vectors.
## Installation
You can install SQLite-Vec using pip:
```bash
pip install sqlite-vec
```
## Documentation
See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.
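
For step 2 above, here is a minimal sketch of an inline SQLite-Vec `vector_io` entry in a distribution's `run.yaml`. The `inline::sqlite-vec` provider type matches the Ollama distribution table in this change; the config layout and the `db_path` key are assumptions and may differ between releases.

```yaml
# Illustrative run.yaml fragment (not part of this change): store vectors in a
# local SQLite file via the inline sqlite-vec provider. `db_path` is an assumed key.
providers:
  vector_io:
  - provider_id: sqlite-vec
    provider_type: inline::sqlite-vec
    config:
      db_path: ~/.llama/distributions/ollama/sqlite_vec.db   # local SQLite database file
```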

View file

@ -0,0 +1,33 @@
---
orphan: true
---
# Weaviate
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
It allows you to store and query vectors directly within a Weaviate database.
That means you're not limited to storing vectors in memory.
## Features
Weaviate supports:
- Storing embeddings and their metadata
- Vector search
- Full-text search
- Hybrid search
- Document storage
- Metadata filtering
- Multi-modal retrieval
## Usage
To use Weaviate in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Weaviate.
3. Start storing and querying vectors.
## Installation
To install Weaviate, see the [Weaviate quickstart documentation](https://weaviate.io/developers/weaviate/quickstart).
## Documentation
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.

View file

@ -171,7 +171,7 @@ The `llama model` command helps you explore the models interface.
llama model --help
```
```
usage: llama model [-h] {download,list,prompt-format,describe} ...
usage: llama model [-h] {download,list,prompt-format,describe,verify-download,remove} ...
Work with llama models
@ -179,15 +179,15 @@ options:
-h, --help show this help message and exit
model_subcommands:
{download,list,prompt-format,describe}
{download,list,prompt-format,describe,verify-download,remove}
```
### Describe
You can use the describe command to know more about a model:
```
llama model describe -m Llama3.2-3B-Instruct
```
### Describe
```
+-----------------------------+----------------------------------+
| Model | Llama3.2-3B-Instruct |
@ -234,3 +234,10 @@ llama model prompt-format -m Llama3.2-3B-Instruct
You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
**NOTE**: Outputs in terminal are color printed to show special tokens.
### Remove model
You can run `llama model remove` to remove an unnecessary model:
```
llama model remove -m Llama-Guard-3-8B-int8
```

View file

@ -194,6 +194,7 @@ class AgentTurnResponseEventType(Enum):
turn_start = "turn_start"
turn_complete = "turn_complete"
turn_awaiting_input = "turn_awaiting_input"
@json_schema_type
@ -235,6 +236,14 @@ class AgentTurnResponseTurnCompletePayload(BaseModel):
turn: Turn
@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
AgentTurnResponseEventType.turn_awaiting_input.value
)
turn: Turn
AgentTurnResponseEventPayload = register_schema(
Annotated[
Union[
@ -243,6 +252,7 @@ AgentTurnResponseEventPayload = register_schema(
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnAwaitingInputPayload,
],
Field(discriminator="event_type"),
],
@ -286,6 +296,18 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None
# TODO (xiyan): temporary flag, will remove for 0.1.5
allow_turn_resume: Optional[bool] = False
@json_schema_type
class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: List[ToolResponseMessage]
stream: Optional[bool] = False
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
@ -333,8 +355,34 @@ class Agents(Protocol):
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
)
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponseMessage],
stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
:param agent_id: The ID of the agent to resume.
:param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with.
:param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
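
For illustration, here is a minimal client-side sketch of the new resume endpoint, posting tool call results over HTTP with `httpx`. The route and request body follow the OpenAPI spec in this change; the base URL, IDs, and the exact `ToolResponseMessage` field names are placeholder assumptions.

```python
# Hedged sketch (not part of this change): resume a turn that is awaiting_input
# by submitting the client-side tool call outputs.
import httpx

BASE_URL = "http://localhost:8321"                              # assumed local server
agent_id, session_id, turn_id = "agent-1", "sess-1", "turn-1"   # placeholder IDs

tool_response = {
    "role": "tool",       # illustrative; see the ToolResponseMessage schema
    "call_id": "call-1",  # for the authoritative field names
    "content": "42",
}

resp = httpx.post(
    f"{BASE_URL}/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
    json={"tool_responses": [tool_response], "stream": False},
    timeout=60.0,
)
resp.raise_for_status()
turn = resp.json()  # a Turn object when stream is False
```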

View file

@ -91,15 +91,18 @@ ParamType = register_schema(
name="ParamType",
)
"""
# TODO: recursive definition of ParamType in these containers
# will cause infinite recursion in OpenAPI generation script
# since we are going with ChatCompletionInputType and CompletionInputType
# we don't need to worry about ArrayType/ObjectType/UnionType for now
# ArrayType.model_rebuild()
# ObjectType.model_rebuild()
# UnionType.model_rebuild()
ArrayType.model_rebuild()
ObjectType.model_rebuild()
UnionType.model_rebuild()
# class CustomType(BaseModel):
# type: Literal["custom"] = "custom"
# validator_class: str
class CustomType(BaseModel):
pylint: disable=syntax-error
type: Literal["custom"] = "custom"
validator_class: str
"""

View file

@ -165,6 +165,7 @@ class ToolResponse(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
content: InterleavedContent
metadata: Optional[Dict[str, Any]] = None
@field_validator("tool_name", mode="before")
@classmethod

View file

@ -26,6 +26,7 @@ class RAGDocument(BaseModel):
@json_schema_type
class RAGQueryResult(BaseModel):
content: Optional[InterleavedContent] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type

View file

@ -72,6 +72,7 @@ class ToolInvocationResult(BaseModel):
content: InterleavedContent
error_message: Optional[str] = None
error_code: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
class ToolStore(Protocol):

View file

@ -19,6 +19,13 @@ def _get_model_size(model_dir):
return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())
def _convert_to_model_descriptor(model):
for m in all_registered_models():
if model == m.descriptor().replace(":", "-"):
return str(m.descriptor())
return str(model)
def _run_model_list_downloaded_cmd() -> None:
headers = ["Model", "Size", "Modified Time"]
@ -30,7 +37,7 @@ def _run_model_list_downloaded_cmd() -> None:
modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path)))
rows.append(
[
model,
_convert_to_model_descriptor(model),
model_size,
modified_time,
]
@ -68,6 +75,13 @@ class ModelList(Subcommand):
action="store_true",
help="List the downloaded models",
)
self.parser.add_argument(
"-s",
"--search",
type=str,
required=False,
help="Search for the input string as a substring in the model descriptor(ID)",
)
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku
@ -87,15 +101,19 @@ class ModelList(Subcommand):
continue
descriptor = model.descriptor()
rows.append(
[
descriptor,
model.huggingface_repo,
f"{model.max_seq_length // 1024}K",
]
if not args.search or args.search.lower() in descriptor.lower():
rows.append(
[
descriptor,
model.huggingface_repo,
f"{model.max_seq_length // 1024}K",
]
)
if len(rows) == 0:
print(f"Did not find any model matching `{args.search}`.")
else:
print_table(
rows,
headers,
separate_rows=True,
)
print_table(
rows,
headers,
separate_rows=True,
)
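
As a usage note, the new `--search` flag filters the listing by a case-insensitive substring match on the model descriptor, for example (the descriptor fragment below is illustrative):

```bash
llama model list --search 3.2-3B
```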

View file

@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe
from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.remove import ModelRemove
from llama_stack.cli.model.verify_download import ModelVerifyDownload
from llama_stack.cli.subcommand import Subcommand
@ -35,3 +36,4 @@ class ModelParser(Subcommand):
ModelPromptFormat.create(subparsers)
ModelDescribe.create(subparsers)
ModelVerifyDownload.create(subparsers)
ModelRemove.create(subparsers)

View file

@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import shutil
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.models.llama.sku_list import resolve_model
class ModelRemove(Subcommand):
"""Remove the downloaded llama model"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"remove",
prog="llama model remove",
description="Remove the downloaded llama model",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_remove_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model",
required=True,
help="Specify the llama downloaded model name, see `llama model list --downloaded`",
)
self.parser.add_argument(
"-f",
"--force",
action="store_true",
help="Used to forcefully remove the llama model from the storage without further confirmation",
)
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku
prompt_guard = prompt_guard_model_sku()
if args.model == prompt_guard.model_id:
model = prompt_guard
else:
model = resolve_model(args.model)
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-"))
if model is None or not os.path.isdir(model_path):
print(f"'{args.model}' is not a valid llama model or does not exist.")
return
if args.force:
shutil.rmtree(model_path)
print(f"{args.model} removed.")
else:
if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y":
shutil.rmtree(model_path)
print(f"{args.model} removed.")
else:
print("Removal aborted.")

View file

@ -9,6 +9,7 @@ import importlib.resources
import json
import os
import shutil
import sys
import textwrap
from functools import lru_cache
from pathlib import Path
@ -23,10 +24,10 @@ from termcolor import cprint
from llama_stack.cli.table import print_table
from llama_stack.distribution.build import (
SERVER_DEPENDENCIES,
ImageType,
build_image,
get_provider_dependencies,
)
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import (
BuildConfig,
DistributionSpec,
@ -37,6 +38,8 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, in_notebook, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
@ -59,8 +62,16 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
if args.list_templates:
return _run_template_list_cmd()
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = args.image_name or current_conda_env
if args.image_type == "venv":
current_venv = os.environ.get("VIRTUAL_ENV")
image_name = args.image_name or current_venv
if not image_name and in_notebook():
image_name = "__system__"
elif args.image_type == "conda":
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = args.image_name or current_conda_env
else:
image_name = args.image_name
if args.template:
available_templates = available_templates_specs()
@ -69,7 +80,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
color="red",
)
return
sys.exit(1)
build_config = available_templates[args.template]
if args.image_type:
build_config.image_type = args.image_type
@ -78,7 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
f"Please specify a image-type (container | conda | venv) for {args.template}",
color="red",
)
return
sys.exit(1)
elif not args.config and not args.template:
name = prompt(
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
@ -159,14 +170,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
f"Could not parse config file {args.config}: {e}",
color="red",
)
return
sys.exit(1)
if build_config.image_type == ImageType.container.value and not args.image_name:
cprint(
"Please specify --image-name when building a container from a config file",
color="red",
)
return
sys.exit(1)
if args.print_deps_only:
print(f"# Dependencies for {args.template or args.config or image_name}")
@ -177,19 +188,41 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
print(f"uv pip install {special_dep}")
return
_run_stack_build_command_from_build_config(
build_config,
image_name=image_name,
config_path=args.config,
template_name=args.template,
)
try:
run_config = _run_stack_build_command_from_build_config(
build_config,
image_name=image_name,
config_path=args.config,
template_name=args.template,
)
except (Exception, RuntimeError) as exc:
cprint(
f"Error building stack: {exc}",
color="red",
)
sys.exit(1)
if run_config is None:
cprint(
"Run config path is empty",
color="red",
)
sys.exit(1)
if args.run:
run_config = Path(run_config)
config_dict = yaml.safe_load(run_config.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
run_with_pty(run_args)
def _generate_run_config(
build_config: BuildConfig,
build_dir: Path,
image_name: str,
) -> None:
) -> str:
"""
Generate a run.yaml template file for user to edit from a build.yaml file
"""
@ -239,6 +272,7 @@ def _generate_run_config(
f"You can now run your stack with `llama stack run {run_config_file}`",
color="green",
)
return run_config_file
def _run_stack_build_command_from_build_config(
@ -246,7 +280,7 @@ def _run_stack_build_command_from_build_config(
image_name: Optional[str] = None,
template_name: Optional[str] = None,
config_path: Optional[str] = None,
) -> None:
) -> str:
if build_config.image_type == ImageType.container.value:
if template_name:
image_name = f"distribution-{template_name}"
@ -256,6 +290,9 @@ def _run_stack_build_command_from_build_config(
elif build_config.image_type == ImageType.conda.value:
if not image_name:
raise ValueError("Please specify an image name when building a conda image")
elif build_config.image_type == ImageType.venv.value:
if not image_name:
raise ValueError("Please specify an image name when building a venv image")
if template_name:
build_dir = DISTRIBS_BASE_DIR / template_name
@ -276,7 +313,7 @@ def _run_stack_build_command_from_build_config(
template_or_config=template_name or config_path,
)
if return_code != 0:
return
raise RuntimeError(f"Failed to build image {image_name}")
if template_name:
# copy run.yaml from template to build_dir instead of generating it again
@ -286,8 +323,9 @@ def _run_stack_build_command_from_build_config(
shutil.copy(path, run_config_file)
cprint("Build Successful!", color="green")
return template_path
else:
_generate_run_config(build_config, build_dir, image_name)
return _generate_run_config(build_config, build_dir, image_name)
def _run_template_list_cmd() -> None:

View file

@ -68,6 +68,13 @@ the build. If not specified, currently active Conda environment will be used if
help="Print the dependencies for the stack only, without building the stack",
)
self.parser.add_argument(
"--run",
action="store_true",
default=False,
help="Run the stack after building using the same image type, name, and other applicable arguments",
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
# always keep implementation completely silo-ed away from CLI so CLI
# can be fast to load and reduces dependencies
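
For example, the new `--run` flag can be combined with the template build command used in the notebooks above; the server port falls back to `LLAMA_STACK_PORT` or 8321, per the build code above.

```bash
# Build the "together" template into a venv and immediately start the stack.
llama stack build --template together --image-type venv --run
```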

View file

@ -1,46 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class StackConfigure(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama stack configure",
description="Configure a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_configure_cmd)
def _add_arguments(self):
self.parser.add_argument(
"config",
type=str,
help="Path to the build config file (e.g. ~/.llama/builds/<image_type>/<name>-build.yaml). For container, this could also be the name of the container image. ",
)
self.parser.add_argument(
"--output-dir",
type=str,
help="Path to the output directory to store generated run.yaml config file. If not specified, will use ~/.llama/build/<image_type>/<name>-run.yaml",
)
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
self.parser.error(
"""
DEPRECATED! llama stack configure has been deprecated.
Please use llama stack run <path/to/run.yaml> instead.
Please see example run.yaml in /distributions folder.
"""
)

View file

@ -74,10 +74,6 @@ class StackRun(Subcommand):
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import importlib.resources
import json
import subprocess
import yaml
from termcolor import cprint
@ -87,7 +83,7 @@ class StackRun(Subcommand):
BUILDS_BASE_DIR,
DISTRIBS_BASE_DIR,
)
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
if not args.config:
self.parser.error("Must specify a config file to run")
@ -125,64 +121,7 @@ class StackRun(Subcommand):
config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if args.image_type == ImageType.container.value or config.container_image:
script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
image_name = f"distribution-{template_name}" if template_name else config.container_image
run_args = [script, image_name]
elif args.image_type == ImageType.conda.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = args.image_name or current_conda_env
if not image_name:
cprint(
"No current conda environment detected, please specify a conda environment name with --image-name",
color="red",
)
return
def get_conda_prefix(env_name):
# Conda "base" environment does not end with "base" in the
# prefix, so should be handled separately.
if env_name == "base":
return os.environ.get("CONDA_PREFIX")
# Get conda environments info
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
envs = conda_env_info["envs"]
for envpath in envs:
if envpath.endswith(env_name):
return envpath
return None
print(f"Using conda environment: {image_name}")
conda_prefix = get_conda_prefix(image_name)
if not conda_prefix:
cprint(
f"Conda environment {image_name} does not exist.",
color="red",
)
return
build_file = Path(conda_prefix) / "llamastack-build.yaml"
if not build_file.exists():
cprint(
f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
color="red",
)
return
script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
run_args = [
script,
image_name,
]
else:
# else must be venv since that is the only valid option left.
current_venv = os.environ.get("VIRTUAL_ENV")
venv = args.image_name or current_venv
script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
run_args = [
script,
venv,
]
run_args = formulate_run_args(args.image_type, args.image_name, config, template_name)
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:
@ -206,5 +145,4 @@ class StackRun(Subcommand):
if args.tls_keyfile and args.tls_certfile:
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
run_with_pty(run_args)

View file

@ -10,7 +10,6 @@ from importlib.metadata import version
from llama_stack.cli.subcommand import Subcommand
from .build import StackBuild
from .configure import StackConfigure
from .list_apis import StackListApis
from .list_providers import StackListProviders
from .run import StackRun
@ -37,7 +36,6 @@ class StackParser(Subcommand):
# Add sub-commands
StackBuild.create(subparsers)
StackConfigure.create(subparsers)
StackListApis.create(subparsers)
StackListProviders.create(subparsers)
StackRun.create(subparsers)

View file

@ -7,7 +7,6 @@
import importlib.resources
import logging
import sys
from enum import Enum
from pathlib import Path
from typing import Dict, List
@ -18,6 +17,7 @@ from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
@ -33,12 +33,6 @@ SERVER_DEPENDENCIES = [
]
class ImageType(Enum):
container = "container"
conda = "conda"
venv = "venv"
class ApiInput(BaseModel):
api: Api
provider: str

View file

@ -8,6 +8,8 @@
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
BUILD_PLATFORM=${BUILD_PLATFORM:-}
@ -32,7 +34,7 @@ container_base="$3"
build_file_path="$4"
host_build_dir="$5"
pip_dependencies="$6"
special_pip_deps="$7"
special_pip_deps="${7:-}"
# Define color codes
@ -106,26 +108,39 @@ fi
stack_mount="/app/llama-stack-source"
models_mount="/app/llama-models-source"
client_mount="/app/llama-stack-client-source"
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2
install_local_package() {
local dir="$1"
local mount_point="$2"
local name="$3"
if [ ! -d "$dir" ]; then
echo "${RED}Warning: $name is set but directory does not exist: $dir${NC}" >&2
exit 1
fi
# Install in editable format. We will mount the source code into the container
# so that changes will be reflected in the container without having to do a
# rebuild. This is just for development convenience.
if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
add_to_container << EOF
COPY $LLAMA_STACK_DIR $stack_mount
COPY $dir $mount_point
EOF
fi
add_to_container << EOF
RUN uv pip install --no-cache -e $stack_mount
RUN uv pip install --no-cache -e $mount_point
EOF
}
if [ -n "$LLAMA_MODELS_DIR" ]; then
install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
fi
if [ -n "$LLAMA_STACK_DIR" ]; then
install_local_package "$LLAMA_STACK_DIR" "$stack_mount" "LLAMA_STACK_DIR"
else
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
@ -134,6 +149,7 @@ RUN uv pip install fastapi libcst
EOF
add_to_container << EOF
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
EOF
@ -149,23 +165,6 @@ EOF
fi
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
exit 1
fi
if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
add_to_container << EOF
COPY $LLAMA_MODELS_DIR $models_mount
EOF
fi
add_to_container << EOF
RUN uv pip uninstall llama-models
RUN uv pip install --no-cache $models_mount
EOF
fi
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
if [[ "$template_or_config" != *.yaml ]]; then
add_to_container << EOF
@ -198,6 +197,9 @@ if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
fi
fi
if command -v selinuxenabled &>/dev/null && selinuxenabled; then

View file

@ -16,6 +16,7 @@ TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
# Reference: https://github.com/astral-sh/uv/pull/1694
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-}
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
@ -24,8 +25,8 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
fi
if [ "$#" -lt 3 ]; then
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies> [<special_pip_deps>]" >&2
if [ "$#" -lt 2 ]; then
echo "Usage: $0 <distribution_type> <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
exit 1
fi
@ -34,8 +35,7 @@ special_pip_deps="$3"
set -euo pipefail
build_name="$1"
env_name="llamastack-$build_name"
env_name="$1"
pip_dependencies="$2"
# Define color codes
@ -74,9 +74,13 @@ run() {
local env_name="$1"
local pip_dependencies="$2"
local special_pip_deps="$3"
if [ -n "$UV_SYSTEM_PYTHON" ]; then
if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
echo "Installing dependencies in system Python environment"
# if env == __system__, ensure we set UV_SYSTEM_PYTHON
export UV_SYSTEM_PYTHON=1
elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
echo "Virtual environment $env_name is already active"
else
echo "Using virtual environment $env_name"
uv venv "$env_name"
@ -90,6 +94,7 @@ run() {
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
$pip_dependencies
if [ -n "$special_pip_deps" ]; then

View file

@ -41,6 +41,7 @@ from llama_stack.distribution.stack import (
redact_sensitive_fields,
replace_env_vars,
)
from llama_stack.distribution.utils.exec import in_notebook
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
@ -52,19 +53,6 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
def in_notebook():
try:
from IPython import get_ipython
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
return False
except ImportError:
return False
except AttributeError:
return False
return True
def convert_pydantic_to_json_value(value: Any) -> Any:
if isinstance(value, Enum):
return value.value
@ -230,12 +218,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if Api.telemetry in self.impls:
setup_logger(self.impls[Api.telemetry])
console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
# Redact sensitive information before printing
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
if not os.environ.get("PYTEST_CURRENT_TEST"):
console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
endpoints = get_all_api_endpoints()
endpoint_impls = {}

View file

@ -52,6 +52,7 @@ from llama_stack.apis.tools import (
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
class VectorIORouter(VectorIO):
@ -158,6 +159,8 @@ class InferenceRouter(Inference):
params["tool_prompt_format"] = tool_prompt_format
tool_config = ToolConfig(**params)
tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
tools = tools or []
if tool_config.tool_choice == ToolChoice.none:
tools = []

View file

@ -1,67 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <build_name> <yaml_config> <port> <script_args...>"
exit 1
fi
env_name="$1"
shift
yaml_config="$1"
shift
port="$1"
shift
# Process environment variables from --env arguments
env_vars=""
other_args=""
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
if [[ -n "$2" ]]; then
# collect environment variables so we can set them after activating the conda env
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
set -x
$CONDA_PREFIX/bin/python \
-m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
$env_vars \
$other_args

View file

@ -1,105 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:-}
LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <build_name> <yaml_config> <port> <other_args...>"
exit 1
fi
image_name="$1"
container_image="localhost/$image_name"
shift
yaml_config="$1"
shift
port="$1"
shift
# Initialize other_args
other_args=""
# Process environment variables from --env arguments
env_vars=""
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
echo "env = $2"
if [[ -n "$2" ]]; then
env_vars="$env_vars -e $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
set -x
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
# Disable SELinux labels
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
fi
mounts=""
if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
fi
if [ -n "$LLAMA_CHECKPOINT_DIR" ]; then
mounts="$mounts -v $LLAMA_CHECKPOINT_DIR:/root/.llama"
CONTAINER_OPTS="$CONTAINER_OPTS --gpus=all"
fi
if [ -n "$PYPI_VERSION" ]; then
version_tag="$PYPI_VERSION"
elif [ -n "$LLAMA_STACK_DIR" ]; then
version_tag="dev"
elif [ -n "$TEST_PYPI_VERSION" ]; then
version_tag="test-$TEST_PYPI_VERSION"
else
URL="https://pypi.org/pypi/llama-stack/json"
version_tag=$(curl -s $URL | jq -r '.info.version')
fi
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \
$env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \
--env LLAMA_STACK_PORT=$port \
--entrypoint python \
$container_image:$version_tag \
-m llama_stack.distribution.server.server \
--yaml-config /app/config.yaml \
$other_args

View file

@ -0,0 +1,150 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:-}
LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <env_type> <env_path_or_name> <yaml_config> <port> <script_args...>"
exit 1
fi
env_type="$1"
shift
env_path_or_name="$1"
container_image="localhost/$env_path_or_name"
shift
yaml_config="$1"
shift
port="$1"
shift
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Initialize env_vars as an empty string
env_vars=""
other_args=""
# Process environment variables from --env arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
PYTHON_BINARY="python"
case "$env_type" in
"venv")
# Activate virtual environment
if [ ! -d "$env_path_or_name" ]; then
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
exit 1
fi
if [ ! -f "$env_path_or_name/bin/activate" ]; then
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
exit 1
fi
source "$env_path_or_name/bin/activate"
;;
"conda")
if ! is_command_available conda; then
echo -e "${RED}Error: conda not found" >&2
exit 1
fi
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_path_or_name"
PYTHON_BINARY="$CONDA_PREFIX/bin/python"
;;
*)
esac
set -x
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
$PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
$env_vars \
$other_args
elif [[ "$env_type" == "container" ]]; then
if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
# Disable SELinux labels
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
fi
mounts=""
if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
fi
if [ -n "$LLAMA_CHECKPOINT_DIR" ]; then
mounts="$mounts -v $LLAMA_CHECKPOINT_DIR:/root/.llama"
CONTAINER_OPTS="$CONTAINER_OPTS --gpus=all"
fi
if [ -n "$PYPI_VERSION" ]; then
version_tag="$PYPI_VERSION"
elif [ -n "$LLAMA_STACK_DIR" ]; then
version_tag="dev"
elif [ -n "$TEST_PYPI_VERSION" ]; then
version_tag="test-$TEST_PYPI_VERSION"
else
if ! is_command_available jq; then
echo -e "${RED}Error: jq not found" >&2
exit 1
fi
URL="https://pypi.org/pypi/llama-stack/json"
version_tag=$(curl -s $URL | jq -r '.info.version')
fi
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \
$env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \
--env LLAMA_STACK_PORT=$port \
--entrypoint python \
$container_image:$version_tag \
-m llama_stack.distribution.server.server \
--yaml-config /app/config.yaml \
$other_args
fi

View file

@ -55,6 +55,7 @@ while [[ $# -gt 0 ]]; do
esac
done
echo "Using virtual environment: $venv_path"
# Activate virtual environment
if [ ! -d "$venv_path" ]; then
echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2

View file

@ -12,8 +12,81 @@ import signal
import subprocess
import sys
from termcolor import cprint
log = logging.getLogger(__name__)
import importlib
import json
from pathlib import Path
from llama_stack.distribution.utils.image_types import ImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list:
env_name = ""
if image_type == ImageType.container.value or config.container_image:
env_name = f"distribution-{template_name}" if template_name else config.container_image
elif image_type == ImageType.conda.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
env_name = image_name or current_conda_env
if not env_name:
cprint(
"No current conda environment detected, please specify a conda environment name with --image-name",
color="red",
)
return
def get_conda_prefix(env_name):
# Conda "base" environment does not end with "base" in the
# prefix, so should be handled separately.
if env_name == "base":
return os.environ.get("CONDA_PREFIX")
# Get conda environments info
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
envs = conda_env_info["envs"]
for envpath in envs:
if envpath.endswith(env_name):
return envpath
return None
print(f"Using conda environment: {env_name}")
conda_prefix = get_conda_prefix(env_name)
if not conda_prefix:
cprint(
f"Conda environment {env_name} does not exist.",
color="red",
)
return
build_file = Path(conda_prefix) / "llamastack-build.yaml"
if not build_file.exists():
cprint(
f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
color="red",
)
return
else:
# else must be venv since that is the only valid option left.
current_venv = os.environ.get("VIRTUAL_ENV")
env_name = image_name or current_venv
if not env_name:
cprint(
"No current virtual environment detected, please specify a virtual environment name with --image-name",
color="red",
)
return
print(f"Using virtual environment: {env_name}")
script = importlib.resources.files("llama_stack") / "distribution/start_stack.sh"
run_args = [
script,
image_type,
env_name,
]
return run_args
def run_with_pty(command):
if sys.platform.startswith("win"):
@ -22,6 +95,19 @@ def run_with_pty(command):
return _run_with_pty_unix(command)
def in_notebook():
try:
from IPython import get_ipython
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
return False
except ImportError:
return False
except AttributeError:
return False
return True
# run a command in a pseudo-terminal, with interrupt handling,
# useful when you want to run interactive things
def _run_with_pty_unix(command):

View file
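A rough sketch of how the relocated helpers fit together; the SimpleNamespace stands in for the parsed build config (only its container_image attribute is consulted on this path), and the trailing arguments mirror the <yaml_config> <port> positionals that start_stack.sh expects.

import os
from types import SimpleNamespace

from llama_stack.distribution.utils.exec import formulate_run_args, in_notebook, run_with_pty

config = SimpleNamespace(container_image=None)  # stand-in for a real build config
run_args = formulate_run_args("venv", None, config, None)  # falls back to $VIRTUAL_ENV
if run_args:
    run_args.extend(["./my-run.yaml", "8321"])  # <yaml_config> <port>
    if not in_notebook():  # this sketch avoids grabbing a PTY from inside a notebook
        run_with_pty(run_args)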

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
class ImageType(Enum):
container = "container"
conda = "conda"
venv = "venv"

View file

@ -30,8 +30,10 @@ from llama_stack.apis.agents import (
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepStartPayload,
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest,
Attachment,
Document,
InferenceStep,
@ -60,9 +62,13 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
ToolParamDefinition,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for turn in turns:
messages.extend(self.turn_to_messages(turn))
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages = await self.get_messages_from_turns(turns)
messages.extend(request.messages)
turn_id = str(uuid.uuid4())
@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin):
)
await self.storage.add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
if output_message.tool_calls and request.allow_turn_resume:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = await self.get_messages_from_turns(turns)
messages.extend(request.tool_responses)
last_turn_messages = [
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]
# get the steps from the turn id
steps = []
if len(turns) > 0:
steps = turns[-1].steps
# mark the tool execution step as complete
# if there is no in-progress tool execution step (e.g. due to storage limits or client-side tool call parsing),
# create a new tool execution step stamped with the current time
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now()
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=[
ToolResponse(
call_id=x.call_id,
tool_name=x.tool_name,
content=x.content,
)
for x in request.tool_responses
],
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=tool_execution_step.step_id,
step_details=tool_execution_step,
)
)
)
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=request.turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
assert output_message is not None
last_turn_start_time = datetime.now()
if len(turns) > 0:
last_turn_start_time = turns[-1].started_at
turn = Turn(
turn_id=request.turn_id,
session_id=request.session_id,
input_messages=last_turn_messages,
output_message=output_message,
started_at=last_turn_start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def run(
@ -456,6 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
call_id="",
tool_name=MEMORY_QUERY_TOOL,
content=retrieved_context or [],
metadata=result.metadata,
)
],
),
@ -611,11 +743,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message]
else:
log.info(f"{str(message)}")
tool_call = message.tool_calls[0]
if tool_call.tool_name in client_tools:
yield message
return
# 1. Start the tool execution step and progress
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -625,6 +753,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@ -639,6 +768,23 @@ class ChatAgent(ShieldRunnerMixin):
)
)
# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[],
started_at=datetime.now(),
),
)
yield message
return
# If tool is a builtin server tool, execute it
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
@ -650,13 +796,21 @@ class ChatAgent(ShieldRunnerMixin):
},
) as span:
tool_execution_start_time = datetime.now()
result_messages = await execute_tool_call_maybe(
tool_call = message.tool_calls[0]
tool_result = await execute_tool_call_maybe(
self.tool_runtime_api,
session_id,
[message],
tool_call,
toolgroup_args,
tool_to_group,
)
result_messages = [
ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
)
]
assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
@ -675,6 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
call_id=result_message.call_id,
tool_name=result_message.tool_name,
content=result_message.content,
metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time,
@ -913,19 +1068,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
async def execute_tool_call_maybe(
tool_runtime_api: ToolRuntime,
session_id: str,
messages: List[CompletionMessage],
tool_call: ToolCall,
toolgroup_args: Dict[str, Dict[str, Any]],
tool_to_group: Dict[str, str],
) -> List[ToolResponseMessage]:
# While the Tools.run interface takes a list of messages,
# all tools currently only run on a single message.
# When this changes, we can drop this assert.
# Whether to call tools on each message and aggregate,
# or to aggregate and call the tool once, remains to be seen.
assert len(messages) == 1, "Expected single message"
message = messages[0]
tool_call = message.tool_calls[0]
) -> ToolInvocationResult:
name = tool_call.tool_name
group_name = tool_to_group.get(name, None)
if group_name is None:
@ -946,14 +1092,7 @@ async def execute_tool_call_maybe(
**tool_call_args,
),
)
return [
ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=result.content,
)
]
return result
def _interpret_content_as_attachment(

View file

@ -11,8 +11,6 @@ import tempfile
import uuid
from typing import AsyncGenerator, List, Optional, Union
from termcolor import colored
from llama_stack.apis.agents import (
AgentConfig,
AgentCreateResponse,
@ -21,6 +19,7 @@ from llama_stack.apis.agents import (
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
Session,
Turn,
@ -68,12 +67,7 @@ class MetaReferenceAgentsImpl(Agents):
# check if "bwrap" is available
if not shutil.which("bwrap"):
print(
colored(
"Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
"yellow",
)
)
logger.warning("Warning: `bwrap` is not available. Code interpreter tool will not work correctly.")
async def create_agent(
self,
@ -146,6 +140,7 @@ class MetaReferenceAgentsImpl(Agents):
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@ -155,6 +150,7 @@ class MetaReferenceAgentsImpl(Agents):
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
allow_turn_resume=allow_turn_resume,
)
if stream:
return self._create_agent_turn_streaming(request)
@ -169,6 +165,34 @@ class MetaReferenceAgentsImpl(Agents):
async for event in agent.create_and_execute_turn(request):
yield event
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponseMessage],
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
agent_id=agent_id,
session_id=session_id,
turn_id=turn_id,
tool_responses=tool_responses,
stream=stream,
)
if stream:
return self._continue_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _continue_agent_turn_streaming(
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)

View file
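To show the awaiting-input/resume round trip end to end, a hedged caller-side sketch; imports are omitted, and the "turn_awaiting_input" event-type literal plus the exact keyword set are assumptions based on the payloads and signatures added above.

# 1) create a turn that is allowed to pause on a client tool call
chunks = await agents_impl.create_agent_turn(
    agent_id=agent_id,
    session_id=session_id,
    messages=[UserMessage(content="What is the boiling point of polyjuice?")],
    stream=True,
    allow_turn_resume=True,
)
paused_turn = None
async for chunk in chunks:
    payload = chunk.event.payload
    if payload.event_type == "turn_awaiting_input":  # assumed literal for the awaiting-input payload
        paused_turn = payload.turn

# 2) run the client tool locally, then resume the same turn with its output
if paused_turn is not None:
    call = paused_turn.output_message.tool_calls[0]
    resumed = await agents_impl.resume_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        turn_id=paused_turn.turn_id,
        tool_responses=[
            ToolResponseMessage(call_id=call.call_id, tool_name=call.tool_name, content="-100")
        ],
        stream=True,
    )
    async for chunk in resumed:
        pass  # stream the rest of the turn to completion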

@ -12,7 +12,7 @@ from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@ -84,3 +84,15 @@ class AgentPersistence:
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None

View file
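A small sketch of the round trip the two new helpers enable when a turn pauses on a client tool call; the identifiers and the pending tool_call are illustrative.

from datetime import datetime
from uuid import uuid4

from llama_stack.apis.agents import ToolExecutionStep

step = ToolExecutionStep(
    step_id=str(uuid4()),
    turn_id=turn_id,
    tool_calls=[tool_call],  # the client tool call we are waiting on
    tool_responses=[],
    started_at=datetime.now(),
)
await storage.set_in_progress_tool_call_step(session_id, turn_id, step)

# later, when the client returns with tool responses:
pending = await storage.get_in_progress_tool_call_step(session_id, turn_id)
assert pending is not None and pending.step_id == step.step_id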

@ -44,7 +44,6 @@ class SentenceTransformersInferenceImpl(
pass
async def register_model(self, model: Model) -> None:
_ = self._load_sentence_transformer_model(model.provider_resource_id)
return model
async def unregister_model(self, model_id: str) -> None:

View file

@ -119,10 +119,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
chunks = chunks[: query_config.max_chunks]
tokens = 0
picked = []
for c in chunks[: query_config.max_chunks]:
for c in chunks:
metadata = c.metadata
tokens += metadata["token_count"]
if tokens > query_config.max_tokens_in_context:
@ -146,6 +146,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
text="\n=== END-RETRIEVED-CONTEXT ===\n",
),
],
metadata={
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
},
)
async def list_runtime_tools(

View file
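Because the RAG tool now reports which documents actually made it into the packed context, callers can surface citations; a hedged sketch (the db id is illustrative and the call follows the existing RAGToolRuntime.query shape).

from llama_stack.apis.tools import RAGQueryConfig

result = await rag_tool.query(
    content="How do I configure the sqlite-vec provider?",
    vector_db_ids=["docs-db"],
    query_config=RAGQueryConfig(max_chunks=5, max_tokens_in_context=2048),
)
print(result.metadata["document_ids"])  # ids of the chunks packed into the prompt, in ranked order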

@ -61,7 +61,10 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.inference,
provider_type="inline::sentence-transformers",
pip_packages=["sentence-transformers"],
pip_packages=[
"torch torchvision --index-url https://download.pytorch.org/whl/cpu",
"sentence-transformers --no-deps",
],
module="llama_stack.providers.inline.inference.sentence_transformers",
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
),

View file

@ -20,7 +20,18 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=[],
pip_packages=[
"blobfile",
"chardet",
"pypdf",
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
],
module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
api_dependencies=[Api.vector_io, Api.inference],

View file

@ -14,33 +14,13 @@ from llama_stack.providers.datatypes import (
remote_provider_spec,
)
EMBEDDING_DEPS = [
"blobfile",
"chardet",
"pypdf",
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
# this happens to work because special dependencies are always installed last
# so if there was a regular torch installed first, this would be ignored
# we need a better way to do this to identify potential conflicts, etc.
# for now, this lets us significantly reduce the size of the container which
# does not have any "local" inference code (and hence does not need GPU-enabled torch)
"torch torchvision --index-url https://download.pytorch.org/whl/cpu",
"sentence-transformers --no-deps",
]
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::meta-reference",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
pip_packages=["faiss-cpu"],
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
@ -49,24 +29,33 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::faiss",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
pip_packages=["faiss-cpu"],
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference],
),
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite_vec",
pip_packages=EMBEDDING_DEPS + ["sqlite-vec"],
provider_type="inline::sqlite-vec",
pip_packages=["sqlite-vec"],
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
api_dependencies=[Api.inference],
),
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite_vec",
pip_packages=["sqlite-vec"],
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
pip_packages=["chromadb-client"],
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
),
@ -75,7 +64,7 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb"],
pip_packages=["chromadb"],
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
@ -84,7 +73,7 @@ def available_providers() -> List[ProviderSpec]:
Api.vector_io,
AdapterSpec(
adapter_type="pgvector",
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
pip_packages=["psycopg2-binary"],
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
),
@ -94,7 +83,7 @@ def available_providers() -> List[ProviderSpec]:
Api.vector_io,
AdapterSpec(
adapter_type="weaviate",
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
pip_packages=["weaviate-client"],
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
@ -115,7 +104,7 @@ def available_providers() -> List[ProviderSpec]:
Api.vector_io,
AdapterSpec(
adapter_type="qdrant",
pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
pip_packages=["qdrant-client"],
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
),

View file

@ -209,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
if media_present:
if media_present or not llama_model:
input_dict["messages"] = [
await convert_message_to_openai_dict(m, download=True) for m in request.messages
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model)
)
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Fireworks does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)

View file

@ -52,7 +52,7 @@ _MODEL_ENTRIES = [
provider_model_id="baai/bge-m3",
model_type=ModelType.embedding,
metadata={
"embedding_dimensions": 1024,
"embedding_dimension": 1024,
"context_length": 8192,
},
),

View file

@ -178,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict = {}
media_present = request_has_media(request)
llama_model = self.register_helper.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
if media_present:
if media_present or not llama_model:
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
# flatten the list of lists
input_dict["messages"] = [item for sublist in contents for item in sublist]
@ -187,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["raw"] = True
input_dict["prompt"] = await chat_completion_request_to_prompt(
request,
self.register_helper.get_llama_model(request.model),
llama_model,
)
else:
assert not media_present, "Ollama does not support media for Completion requests"
@ -280,7 +281,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.model_type == ModelType.embedding:
log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
await self.client.pull(model.provider_resource_id)
response = await self.client.list()
else:
response = await self.client.ps()
@ -290,7 +294,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
)
return await self.register_helper.register_model(model)
return model
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:

View file

@ -203,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
if media_present:
if media_present or not llama_model:
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model)
)
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)

View file

@ -61,7 +61,7 @@ def vector_io_sqlite_vec() -> ProviderFixture:
providers=[
Provider(
provider_id="sqlite_vec",
provider_type="inline::sqlite_vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig(
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
).model_dump(),

View file

@ -14,6 +14,7 @@ from llama_stack.apis.inference import (
ModelStore,
TextTruncation,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
EMBEDDING_MODELS = {}
@ -34,7 +35,9 @@ class SentenceTransformerEmbeddingMixin:
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
embeddings = embedding_model.encode(contents)
embeddings = embedding_model.encode(
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
)
return EmbeddingsResponse(embeddings=embeddings)
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":

View file
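Since each item is now normalized with interleaved_content_as_str before encoding, plain strings and interleaved content items both work; a short hedged sketch (import path as assumed from this tree).

from llama_stack.apis.common.content_types import TextContentItem

response = await inference.embeddings(
    model_id="all-MiniLM-L6-v2",
    contents=["first passage", TextContentItem(text="second passage")],
)
print(len(response.embeddings), len(response.embeddings[0]))  # 2 vectors, 384 dims for this model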

@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
if provider_resource_id:
model.provider_resource_id = provider_resource_id
else:
if model.metadata.get("llama_model") is None:
raise ValueError(
f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
"Please specify a llama_model in metadata or use a supported model identifier"
)
llama_model = model.metadata.get("llama_model")
if llama_model is None:
return model
existing_llama_model = self.get_llama_model(model.provider_resource_id)
if existing_llama_model:
if existing_llama_model != model.metadata["llama_model"]:
if existing_llama_model != llama_model:
raise ValueError(
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
)
else:
if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
raise ValueError(
f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
f"Invalid llama_model '{llama_model}' specified in metadata. "
f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
)
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
return model

View file
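With the relaxed check, a provider model id missing from the static alias list can still be registered by naming the underlying Llama model in metadata; a hedged sketch with illustrative ids.

from llama_stack.apis.models import Model

model = Model(
    identifier="my-finetune",
    provider_resource_id="my-org/llama3.1-8b-finetune",
    provider_id="vllm",
    metadata={"llama_model": "meta-llama/Llama-3.1-8B-Instruct"},
)
model = await registry_helper.register_model(model)  # maps the id to Llama-3.1-8B for prompt formatting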

@ -456,3 +456,20 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
else:
# specific tool
return f"You MUST use the tool `{tool_choice}` to answer the user query."
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
llama_model = resolve_model(model)
if llama_model is None:
return ToolPromptFormat.json
if llama_model.model_family == ModelFamily.llama3_1 or (
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
return ToolPromptFormat.json
elif llama_model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
# llama3.2 and llama3.3 models follow the same tool prompt format
return ToolPromptFormat.python_list
else:
return ToolPromptFormat.json

View file
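For intuition, the defaults the helper is expected to produce; the model ids are illustrative and the ToolPromptFormat import path is assumed from the surrounding tree (anything resolve_model cannot identify falls back to json).

from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format

assert get_default_tool_prompt_format("meta-llama/Llama-3.1-8B-Instruct") == ToolPromptFormat.json
assert get_default_tool_prompt_format("meta-llama/Llama-3.3-70B-Instruct") == ToolPromptFormat.python_list
assert get_default_tool_prompt_format("a-model-resolve-model-does-not-know") == ToolPromptFormat.json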

@ -5,12 +5,10 @@
# the root directory of this source tree.
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, TypeVar
from typing import Any, Callable, List, Optional, Protocol, TypeVar
from .strong_typing.schema import json_schema_type, register_schema # noqa: F401
T = TypeVar("T")
@dataclass
class WebMethod:
@ -22,6 +20,13 @@ class WebMethod:
raw_bytes_request_body: Optional[bool] = False
class HasWebMethod(Protocol):
__webmethod__: WebMethod
T = TypeVar("T", bound=HasWebMethod) # Bound T to classes that match this protocol
def webmethod(
route: Optional[str] = None,
method: Optional[str] = None,

View file
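The new Protocol bound just spells out what decorated endpoint methods already look like; a minimal illustrative sketch (the route, method name, and import path are made up for the example).

from typing import Protocol

from llama_stack.schema_utils import webmethod

class ExampleAPI(Protocol):
    @webmethod(route="/example/ping", method="GET")
    async def ping(self) -> str: ...

# ExampleAPI.ping now carries __webmethod__, so it satisfies HasWebMethod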

@ -11,7 +11,7 @@ import subprocess
import sys
from functools import partial
from pathlib import Path
from typing import Iterator
from typing import Iterable
from rich.progress import Progress, SpinnerColumn, TextColumn
@ -39,7 +39,7 @@ class ChangedPathTracker:
return self._changed_paths
def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
def find_template_dirs(templates_dir: Path) -> Iterable[Path]:
"""Find immediate subdirectories in the templates folder."""
if not templates_dir.exists():
raise FileNotFoundError(f"Templates directory not found: {templates_dir}")
@ -90,7 +90,7 @@ def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
return has_changes
def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[str]]:
try:
module_name = f"llama_stack.templates.{template_dir.name}"
module = importlib.import_module(module_name)

View file

@ -52,7 +52,7 @@ def main(parser: argparse.ArgumentParser):
pytest_args,
"-s",
"-v",
REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH,
str(REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH),
]
)

View file

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- remote::cerebras
- inline::sentence-transformers
safety:
- inline::llama-guard
vector_io:

View file

@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::cerebras"],
"inference": ["remote::cerebras", "inline::sentence-transformers"],
"safety": ["inline::llama-guard"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"agents": ["inline::meta-reference"],

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .ci_tests import get_distribution_template # noqa: F401

View file

@ -0,0 +1,33 @@
version: '2'
distribution_spec:
description: Distribution for running e2e tests in CI
providers:
inference:
- remote::fireworks
- inline::sentence-transformers
vector_io:
- inline::sqlite-vec
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol
image_type: conda

View file

@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::fireworks", "inline::sentence-transformers"],
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
name = "ci-tests"
inference_provider = Provider(
provider_id="fireworks",
provider_type="remote::fireworks",
config=FireworksImplConfig.sample_run_config(),
)
vector_io_provider = Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
provider_model_id=m.provider_model_id,
provider_id="fireworks",
metadata=m.metadata,
model_type=m.model_type,
)
for m in MODEL_ENTRIES
]
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Distribution for running e2e tests in CI",
container_image=None,
template_path=None,
providers=providers,
default_models=default_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
"vector_io": [vector_io_provider],
},
default_models=default_models + [embedding_model],
default_tool_groups=default_tool_groups,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"FIREWORKS_API_KEY": (
"",
"Fireworks API Key",
),
},
)

View file

@ -0,0 +1,169 @@
version: '2'
image_name: ci-tests
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/sqlite_vec.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ci-tests/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config: {}
- provider_id: localfs
provider_type: inline::localfs
config: {}
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db
models:
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-8B
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-11B-Vision
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 8192
model_id: nomic-ai/nomic-embed-text-v1.5
provider_id: fireworks
provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server:
port: 8321

View file

@ -5,6 +5,7 @@ distribution_spec:
providers:
inference:
- remote::tgi
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::tgi"],
"inference": ["remote::tgi", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

View file

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- remote::fireworks
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -18,14 +18,14 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::fireworks"],
"inference": ["remote::fireworks", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

View file

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- remote::hf::serverless
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -21,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::hf::serverless"],
"inference": ["remote::hf::serverless", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

View file

@ -136,7 +136,7 @@ models:
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata:
embedding_dimensions: 1024
embedding_dimension: 1024
context_length: 8192
model_id: baai/bge-m3
provider_id: nvidia

View file

@ -5,7 +5,7 @@ distribution_spec:
inference:
- remote::ollama
vector_io:
- inline::faiss
- inline::sqlite-vec
- remote::chromadb
- remote::pgvector
safety:

View file

@ -13,10 +13,6 @@ from llama_stack.distribution.datatypes import (
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -25,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::ollama"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
@ -45,19 +41,9 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::ollama",
config=OllamaImplConfig.sample_run_config(),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
vector_io_provider_faiss = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
)
vector_io_provider_sqlite = Provider(
provider_id="sqlite_vec",
provider_type="inline::sqlite_vec",
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
)
@ -104,19 +90,16 @@ def get_distribution_template() -> DistributionTemplate:
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
"vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite],
"inference": [inference_provider],
"vector_io": [vector_io_provider_sqlite],
},
default_models=[inference_model, embedding_model],
default_models=[inference_model],
default_tool_groups=default_tool_groups,
),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
embedding_provider,
],
"vector_io": [vector_io_provider_faiss, vector_io_provider_faiss],
"inference": [inference_provider],
"vector_io": [vector_io_provider_sqlite],
"safety": [
Provider(
provider_id="llama-guard",

View file

@ -16,24 +16,11 @@ providers:
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:http://localhost:11434}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -16,19 +16,9 @@ providers:
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:http://localhost:11434}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
- provider_id: sqlite_vec
provider_type: inline::sqlite_vec
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
safety:
@ -97,12 +87,6 @@ models:
model_id: ${env.INFERENCE_MODEL}
provider_id: ollama
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: ollama
provider_model_id: all-minilm:latest
model_type: embedding
shields: []
vector_dbs: []
datasets: []

View file
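With the default ollama-backed embedding entry removed from run.yaml, and register_model above now pulling embedding models on demand, the same model can still be added at runtime; a hedged client-side sketch (the register call shape follows the llama-stack-client API as understood here).

client.models.register(
    model_id="all-MiniLM-L6-v2",
    provider_id="ollama",
    provider_model_id="all-minilm:latest",
    model_type="embedding",
    metadata={"embedding_dimension": 384},
)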

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- remote::vllm
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::vllm"],
"inference": ["remote::vllm", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

View file

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- remote::tgi
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::tgi"],
"inference": ["remote::tgi", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

View file

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- remote::together
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -25,7 +25,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::together"],
"inference": ["remote::together", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

View file

@ -4,6 +4,7 @@ distribution_spec:
providers:
inference:
- inline::vllm
- inline::sentence-transformers
vector_io:
- inline::faiss
- remote::chromadb

View file

@ -20,7 +20,7 @@ from llama_stack.templates.template import (
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["inline::vllm"],
"inference": ["inline::vllm", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "llama_stack"
version = "0.1.3"
version = "0.1.4"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"
@ -26,8 +26,8 @@ dependencies = [
"httpx",
"huggingface-hub",
"jsonschema",
"llama-models>=0.1.3",
"llama-stack-client>=0.1.3",
"llama-models>=0.1.4",
"llama-stack-client>=0.1.4",
"prompt-toolkit",
"python-dotenv",
"pydantic>=2",
@ -158,3 +158,26 @@ ignore = [
"B007",
"B008",
]
[tool.mypy]
mypy_path = ["llama_stack"]
packages = ["llama_stack"]
disable_error_code = []
warn_return_any = true
# honor excludes by not following them through imports
follow_imports = "silent"
exclude = [
# As we fix more and more of these, we should remove them from the list
"llama_stack/providers",
"llama_stack/distribution",
"llama_stack/apis",
"llama_stack/cli",
"llama_stack/models",
"llama_stack/strong_typing",
"llama_stack/templates",
]
[[tool.mypy.overrides]]
# packages that lack typing annotations, do not have stubs, or are unavailable.
module = ["llama_models.*", "yaml", "fire"]
ignore_missing_imports = true

View file

@ -16,13 +16,13 @@ fsspec==2025.2.0
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.28.1
huggingface-hub==0.29.0
idna==3.10
jinja2==3.1.5
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
llama-models==0.1.3
llama-stack-client==0.1.3
llama-models==0.1.4
llama-stack-client==0.1.4
lxml==5.3.1
markdown-it-py==3.0.0
markupsafe==3.0.2

@@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
from llama_stack_client.types.tool_def_param import Parameter
from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig
from llama_stack.apis.agents.agents import ToolChoice
from llama_stack.apis.agents.agents import (
AgentConfig as Server__AgentConfig,
)
from llama_stack.apis.agents.agents import (
ToolChoice,
)
class TestClientTool(ClientTool):
@@ -86,7 +90,6 @@ class TestClientTool(ClientTool):
def agent_config(llama_stack_client, text_model_id):
available_shields = [shield.identifier for shield in llama_stack_client.shields.list()]
available_shields = available_shields[:1]
print(f"Using shield: {available_shields}")
agent_config = AgentConfig(
model=text_model_id,
instructions="You are a helpful assistant",
@@ -322,17 +325,16 @@ def test_custom_tool(llama_stack_client, agent_config):
def test_tool_choice(llama_stack_client, agent_config):
data = [
("required", '{"type": "function"'),
("none", None),
("get_boiling_point", '{"type": "function", "name": "get_boiling_point"'),
]
client_tool = TestClientTool()
for tool_choice, expected_tool in data:
agent_config["tool_config"] = {"tool_choice": tool_choice}
agent_config["client_tools"] = [client_tool.get_tool_definition()]
def run_agent(tool_choice):
client_tool = TestClientTool()
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
test_agent_config = {
**agent_config,
"tool_config": {"tool_choice": tool_choice},
"client_tools": [client_tool.get_tool_definition()],
}
agent = Agent(llama_stack_client, test_agent_config, client_tools=(client_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@@ -343,14 +345,19 @@ def test_tool_choice(llama_stack_client, agent_config):
},
],
session_id=session_id,
stream=False,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
if expected_tool:
assert expected_tool in logs_str
else:
assert '{"type": "function"' not in logs_str
return [step for step in response.steps if step.step_type == "tool_execution"]
tool_execution_steps = run_agent("required")
assert len(tool_execution_steps) > 0
tool_execution_steps = run_agent("none")
assert len(tool_execution_steps) == 0
tool_execution_steps = run_agent("get_boiling_point")
assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
# TODO: fix this flaky test
@@ -378,7 +385,6 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
print(logs_str)
# can't tell a joke: "I don't have a function"
assert "function" in logs_str
@@ -417,7 +423,6 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
print(logs_str)
assert "bicycle" in logs_str
response = agent.create_turn(
@@ -432,7 +437,6 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
print(logs_str)
assert "-100" in logs_str
assert "get_boiling_point" in logs_str
@@ -453,6 +457,7 @@ def test_rag_agent(llama_stack_client, agent_config):
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_id="faiss",
)
llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents,
@@ -484,15 +489,17 @@ def test_rag_agent(llama_stack_client, agent_config):
),
]
for prompt, expected_kw in user_prompts:
print(f"User> {prompt}")
response = rag_agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
stream=False,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:query_from_memory" in logs_str
assert expected_kw in logs_str.lower()
# rag is called
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory"
# document ids are present in metadata
assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"]
assert expected_kw in response.output_message.content.lower()
def test_rag_and_code_agent(llama_stack_client, agent_config):
@@ -548,7 +555,6 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
]
for prompt, docs, tool_name in user_prompts:
print(f"User> {prompt}")
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],

@@ -42,28 +42,30 @@ def pytest_addoption(parser):
)
parser.addoption(
"--inference-model",
action="store",
default=TEXT_MODEL,
help="Specify the inference model to use for testing",
)
parser.addoption(
"--vision-inference-model",
action="store",
default=VISION_MODEL,
help="Specify the vision inference model to use for testing",
)
parser.addoption(
"--safety-shield",
action="store",
default="meta-llama/Llama-Guard-3-1B",
help="Specify the safety shield model to use for testing",
)
parser.addoption(
"--embedding-model",
action="store",
default=TEXT_MODEL,
default=None,
help="Specify the embedding model to use for testing",
)
parser.addoption(
"--embedding-dimension",
type=int,
default=384,
help="Output dimensionality of the embedding model to use for testing",
)
@pytest.fixture(scope="session")
@@ -78,7 +80,7 @@ def provider_data():
@pytest.fixture(scope="session")
def llama_stack_client(provider_data):
def llama_stack_client(provider_data, text_model_id):
if os.environ.get("LLAMA_STACK_CONFIG"):
client = LlamaStackAsLibraryClient(
get_env_or_fail("LLAMA_STACK_CONFIG"),
@@ -95,25 +97,91 @@ def llama_stack_client(provider_data):
)
else:
raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
return client
@pytest.fixture(scope="session")
def inference_provider_type(llama_stack_client):
providers = llama_stack_client.providers.list()
inference_providers = [p for p in providers if p.api == "inference"]
assert len(inference_providers) > 0, "No inference providers found"
return inference_providers[0].provider_type
@pytest.fixture(scope="session")
def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension):
client = llama_stack_client
providers = [p for p in client.providers.list() if p.api == "inference"]
assert len(providers) > 0, "No inference providers found"
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
if text_model_id:
client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
if vision_model_id:
client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
if embedding_model_id and embedding_dimension:
# prefer the sentence-transformers provider for embeddings; otherwise fall back to the first inference provider
selected_provider = None
for p in providers:
if p.provider_type == "inline::sentence-transformers":
selected_provider = p
break
selected_provider = selected_provider or providers[0]
client.models.register(
model_id=embedding_model_id,
provider_id=selected_provider.provider_id,
model_type="embedding",
metadata={"embedding_dimension": embedding_dimension},
)
return client
MODEL_SHORT_IDS = {
"meta-llama/Llama-3.1-8B-Instruct": "8B",
"meta-llama/Llama-3.2-11B-Vision-Instruct": "11B",
"all-MiniLM-L6-v2": "MiniLM",
}
def get_short_id(value):
return MODEL_SHORT_IDS.get(value, value)
def pytest_generate_tests(metafunc):
params = []
values = []
id_parts = []
if "text_model_id" in metafunc.fixturenames:
metafunc.parametrize(
"text_model_id",
[metafunc.config.getoption("--inference-model")],
scope="session",
)
params.append("text_model_id")
val = metafunc.config.getoption("--inference-model")
values.append(val)
id_parts.append(f"txt={get_short_id(val)}")
if "vision_model_id" in metafunc.fixturenames:
metafunc.parametrize(
"vision_model_id",
[metafunc.config.getoption("--vision-inference-model")],
scope="session",
)
params.append("vision_model_id")
val = metafunc.config.getoption("--vision-inference-model")
values.append(val)
id_parts.append(f"vis={get_short_id(val)}")
if "embedding_model_id" in metafunc.fixturenames:
metafunc.parametrize(
"embedding_model_id",
[metafunc.config.getoption("--embedding-model")],
scope="session",
)
params.append("embedding_model_id")
val = metafunc.config.getoption("--embedding-model")
values.append(val)
if val is not None:
id_parts.append(f"emb={get_short_id(val)}")
if "embedding_dimension" in metafunc.fixturenames:
params.append("embedding_dimension")
val = metafunc.config.getoption("--embedding-dimension")
values.append(val)
if val != 384:
id_parts.append(f"dim={val}")
if params:
# Create a single test ID string
test_id = ":".join(id_parts)
metafunc.parametrize(params, [values], scope="session", ids=[test_id])
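
The reworked conftest.py above collapses the per-option parametrize calls into a single session-scoped parametrization that also builds a compact test ID (for example txt=8B, or txt=8B:emb=MiniLM when --embedding-model is passed) and registers the requested models through the client_with_models fixture. A minimal sketch of a test consuming these fixtures follows; it is hypothetical, assumed to live next to this conftest.py, and assumes Model objects expose an identifier field as elsewhere in the SDK.

import pytest

# Run with the new options, e.g.:
#   pytest -k embedding --embedding-model all-MiniLM-L6-v2 --embedding-dimension 384
def test_embedding_model_is_registered(client_with_models, embedding_model_id):
    if embedding_model_id is None:
        pytest.skip("no --embedding-model was provided")
    # client_with_models has already registered the text, vision, and embedding
    # models against the available inference providers.
    registered = [m.identifier for m in client_with_models.models.list()]
    assert embedding_model_id in registered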

Some files were not shown because too many files have changed in this diff.