Mirror of https://github.com/meta-llama/llama-stack.git (synced 2026-01-02 10:24:31 +00:00)

Merge branch 'meta-llama:main' into main

This commit is contained in: commit b41afa5843

106 changed files with 2223 additions and 853 deletions
@@ -30,6 +30,7 @@ repos:
     rev: v0.9.4
     hooks:
       - id: ruff
+        args: [ --fix ]
         exclude: ^llama_stack/strong_typing/.*$
       - id: ruff-format

@@ -45,23 +46,26 @@ repos:
     hooks:
       - id: uv-export
        args: [
          "--frozen",
          "--no-hashes",
          "--no-emit-project",
          "--output-file=requirements.txt"
        ]
        files: ^pyproject\.toml$
      - id: uv-sync

-# - repo: https://github.com/pre-commit/mirrors-mypy
-#   rev: v1.14.0
-#   hooks:
-#   - id: mypy
-#     additional_dependencies:
-#       - types-requests
-#       - types-setuptools
-#       - pydantic
-#     args: [--ignore-missing-imports]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.15.0
+    hooks:
+      - id: mypy
+        additional_dependencies:
+          - uv==0.6.2
+          - mypy
+          - pytest
+          - rich
+          - types-requests
+          - pydantic
+        pass_filenames: false

 # - repo: https://github.com/jsh9/pydoclint
 #   rev: d88180a8632bb1602a4d81344085cf320f288c5a
1  .python-version  Normal file
@@ -0,0 +1 @@
+3.10
@@ -134,9 +134,11 @@ If you are making changes to the documentation at [https://llama-stack.readthedo
 $ cd llama-stack/docs
 $ uv sync --extra docs
 
+# This rebuilds the documentation pages.
+$ uv run make html
+
 # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
-$ make html
-$ uv run sphinx-autobuild source build/html
+$ uv run sphinx-autobuild source build/html --write-all
 ```
 
 ### Update API Documentation
@@ -3,3 +3,4 @@ include distributions/dependencies.json
 include llama_stack/distribution/*.sh
 include llama_stack/cli/scripts/*.sh
 include llama_stack/templates/*/*.yaml
+include llama_stack/providers/tests/test_cases/*.json
@@ -30,9 +30,7 @@
 "sentencepiece",
 "tqdm",
 "transformers",
-"uvicorn",
-"sentence-transformers --no-deps",
-"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+"uvicorn"
 ],
 "cerebras": [
 "aiosqlite",

@@ -68,6 +66,41 @@
 "sentence-transformers --no-deps",
 "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
 ],
+"ci-tests": [
+"aiosqlite",
+"autoevals",
+"blobfile",
+"chardet",
+"chromadb-client",
+"datasets",
+"fastapi",
+"fire",
+"fireworks-ai",
+"httpx",
+"matplotlib",
+"mcp",
+"nltk",
+"numpy",
+"openai",
+"opentelemetry-exporter-otlp-proto-http",
+"opentelemetry-sdk",
+"pandas",
+"pillow",
+"psycopg2-binary",
+"pymongo",
+"pypdf",
+"redis",
+"requests",
+"scikit-learn",
+"scipy",
+"sentencepiece",
+"sqlite-vec",
+"tqdm",
+"transformers",
+"uvicorn",
+"sentence-transformers --no-deps",
+"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+],
 "dell": [
 "aiohttp",
 "aiosqlite",

@@ -170,9 +203,7 @@
 "sentencepiece",
 "tqdm",
 "transformers",
-"uvicorn",
-"sentence-transformers --no-deps",
-"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+"uvicorn"
 ],
 "hf-serverless": [
 "aiohttp",

@@ -247,9 +278,7 @@
 "tqdm",
 "transformers",
 "uvicorn",
-"zmq",
-"sentence-transformers --no-deps",
-"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+"zmq"
 ],
 "meta-reference-quantized-gpu": [
 "accelerate",

@@ -290,9 +319,7 @@
 "tqdm",
 "transformers",
 "uvicorn",
-"zmq",
-"sentence-transformers --no-deps",
-"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+"zmq"
 ],
 "nvidia": [
 "aiosqlite",

@@ -323,9 +350,7 @@
 "sentencepiece",
 "tqdm",
 "transformers",
-"uvicorn",
-"sentence-transformers --no-deps",
-"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+"uvicorn"
 ],
 "ollama": [
 "aiohttp",

@@ -335,7 +360,6 @@
 "chardet",
 "chromadb-client",
 "datasets",
-"faiss-cpu",
 "fastapi",
 "fire",
 "httpx",

@@ -356,11 +380,10 @@
 "scikit-learn",
 "scipy",
 "sentencepiece",
+"sqlite-vec",
 "tqdm",
 "transformers",
-"uvicorn",
-"sentence-transformers --no-deps",
-"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+"uvicorn"
 ],
 "remote-vllm": [
 "aiosqlite",

@@ -423,9 +446,7 @@
 "sentencepiece",
 "tqdm",
 "transformers",
-"uvicorn",
-"sentence-transformers --no-deps",
-"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+"uvicorn"
 ],
 "tgi": [
 "aiohttp",
191  docs/_static/llama-stack-spec.html  vendored

@@ -2315,6 +2315,70 @@
 }
 }
 },
+"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
+"post": {
+"responses": {
+"200": {
+"description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.",
+"content": {
+"application/json": {
+"schema": {
+"$ref": "#/components/schemas/Turn"
+}
+},
+"text/event-stream": {
+"schema": {
+"$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
+}
+}
+}
+}
+},
+"tags": [
+"Agents"
+],
+"description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.",
+"parameters": [
+{
+"name": "agent_id",
+"in": "path",
+"description": "The ID of the agent to resume.",
+"required": true,
+"schema": {
+"type": "string"
+}
+},
+{
+"name": "session_id",
+"in": "path",
+"description": "The ID of the session to resume.",
+"required": true,
+"schema": {
+"type": "string"
+}
+},
+{
+"name": "turn_id",
+"in": "path",
+"description": "The ID of the turn to resume.",
+"required": true,
+"schema": {
+"type": "string"
+}
+}
+],
+"requestBody": {
+"content": {
+"application/json": {
+"schema": {
+"$ref": "#/components/schemas/ResumeAgentTurnRequest"
+}
+}
+},
+"required": true
+}
+}
+},
 "/v1/eval/benchmarks/{benchmark_id}/jobs": {
 "post": {
 "responses": {

@@ -4226,6 +4290,9 @@
 },
 "tool_config": {
 "$ref": "#/components/schemas/ToolConfig"
+},
+"allow_turn_resume": {
+"type": "boolean"
 }
 },
 "additionalProperties": false,

@@ -4454,6 +4521,31 @@
 },
 "content": {
 "$ref": "#/components/schemas/InterleavedContent"
+},
+"metadata": {
+"type": "object",
+"additionalProperties": {
+"oneOf": [
+{
+"type": "null"
+},
+{
+"type": "boolean"
+},
+{
+"type": "number"
+},
+{
+"type": "string"
+},
+{
+"type": "array"
+},
+{
+"type": "object"
+}
+]
+}
 }
 },
 "additionalProperties": false,

@@ -4612,6 +4704,9 @@
 },
 {
 "$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
+},
+{
+"$ref": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
 }
 ],
 "discriminator": {

@@ -4621,7 +4716,8 @@
 "step_progress": "#/components/schemas/AgentTurnResponseStepProgressPayload",
 "step_complete": "#/components/schemas/AgentTurnResponseStepCompletePayload",
 "turn_start": "#/components/schemas/AgentTurnResponseTurnStartPayload",
-"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
+"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload",
+"turn_awaiting_input": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
 }
 }
 },

@@ -4784,6 +4880,25 @@
 "title": "AgentTurnResponseStreamChunk",
 "description": "streamed agent turn completion response."
 },
+"AgentTurnResponseTurnAwaitingInputPayload": {
+"type": "object",
+"properties": {
+"event_type": {
+"type": "string",
+"const": "turn_awaiting_input",
+"default": "turn_awaiting_input"
+},
+"turn": {
+"$ref": "#/components/schemas/Turn"
+}
+},
+"additionalProperties": false,
+"required": [
+"event_type",
+"turn"
+],
+"title": "AgentTurnResponseTurnAwaitingInputPayload"
+},
 "AgentTurnResponseTurnCompletePayload": {
 "type": "object",
 "properties": {

@@ -6656,6 +6771,31 @@
 },
 "error_code": {
 "type": "integer"
+},
+"metadata": {
+"type": "object",
+"additionalProperties": {
+"oneOf": [
+{
+"type": "null"
+},
+{
+"type": "boolean"
+},
+{
+"type": "number"
+},
+{
+"type": "string"
+},
+{
+"type": "array"
+},
+{
+"type": "object"
+}
+]
+}
 }
 },
 "additionalProperties": false,

@@ -7505,9 +7645,37 @@
 "properties": {
 "content": {
 "$ref": "#/components/schemas/InterleavedContent"
+},
+"metadata": {
+"type": "object",
+"additionalProperties": {
+"oneOf": [
+{
+"type": "null"
+},
+{
+"type": "boolean"
+},
+{
+"type": "number"
+},
+{
+"type": "string"
+},
+{
+"type": "array"
+},
+{
+"type": "object"
+}
+]
+}
+}
 }
 },
 "additionalProperties": false,
+"required": [
+"metadata"
+],
 "title": "RAGQueryResult"
 },
 "QueryChunksRequest": {

@@ -8046,6 +8214,27 @@
 ],
 "title": "RegisterVectorDbRequest"
 },
+"ResumeAgentTurnRequest": {
+"type": "object",
+"properties": {
+"tool_responses": {
+"type": "array",
+"items": {
+"$ref": "#/components/schemas/ToolResponseMessage"
+},
+"description": "The tool call responses to resume the turn with."
+},
+"stream": {
+"type": "boolean",
+"description": "Whether to stream the response."
+}
+},
+"additionalProperties": false,
+"required": [
+"tool_responses"
+],
+"title": "ResumeAgentTurnRequest"
+},
 "RunEvalRequest": {
 "type": "object",
 "properties": {
114  docs/_static/llama-stack-spec.yaml  vendored

@@ -1401,6 +1401,53 @@ paths:
 schema:
 $ref: '#/components/schemas/QueryTracesRequest'
 required: true
+/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
+post:
+responses:
+'200':
+description: >-
+A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk
+objects.
+content:
+application/json:
+schema:
+$ref: '#/components/schemas/Turn'
+text/event-stream:
+schema:
+$ref: '#/components/schemas/AgentTurnResponseStreamChunk'
+tags:
+- Agents
+description: >-
+Resume an agent turn with executed tool call responses.
+
+When a Turn has the status `awaiting_input` due to pending input from client
+side tool calls, this endpoint can be used to submit the outputs from the
+tool calls once they are ready.
+parameters:
+- name: agent_id
+in: path
+description: The ID of the agent to resume.
+required: true
+schema:
+type: string
+- name: session_id
+in: path
+description: The ID of the session to resume.
+required: true
+schema:
+type: string
+- name: turn_id
+in: path
+description: The ID of the turn to resume.
+required: true
+schema:
+type: string
+requestBody:
+content:
+application/json:
+schema:
+$ref: '#/components/schemas/ResumeAgentTurnRequest'
+required: true
 /v1/eval/benchmarks/{benchmark_id}/jobs:
 post:
 responses:

@@ -2740,6 +2787,8 @@ components:
 $ref: '#/components/schemas/AgentTool'
 tool_config:
 $ref: '#/components/schemas/ToolConfig'
+allow_turn_resume:
+type: boolean
 additionalProperties: false
 required:
 - messages

@@ -2896,6 +2945,16 @@ components:
 - type: string
 content:
 $ref: '#/components/schemas/InterleavedContent'
+metadata:
+type: object
+additionalProperties:
+oneOf:
+- type: 'null'
+- type: boolean
+- type: number
+- type: string
+- type: array
+- type: object
 additionalProperties: false
 required:
 - call_id

@@ -2992,6 +3051,7 @@ components:
 - $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload'
 - $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload'
 - $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
+- $ref: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
 discriminator:
 propertyName: event_type
 mapping:

@@ -3000,6 +3060,7 @@ components:
 step_complete: '#/components/schemas/AgentTurnResponseStepCompletePayload'
 turn_start: '#/components/schemas/AgentTurnResponseTurnStartPayload'
 turn_complete: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
+turn_awaiting_input: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
 AgentTurnResponseStepCompletePayload:
 type: object
 properties:

@@ -3106,6 +3167,21 @@ components:
 - event
 title: AgentTurnResponseStreamChunk
 description: streamed agent turn completion response.
+"AgentTurnResponseTurnAwaitingInputPayload":
+type: object
+properties:
+event_type:
+type: string
+const: turn_awaiting_input
+default: turn_awaiting_input
+turn:
+$ref: '#/components/schemas/Turn'
+additionalProperties: false
+required:
+- event_type
+- turn
+title: >-
+AgentTurnResponseTurnAwaitingInputPayload
 AgentTurnResponseTurnCompletePayload:
 type: object
 properties:

@@ -4315,6 +4391,16 @@ components:
 type: string
 error_code:
 type: integer
+metadata:
+type: object
+additionalProperties:
+oneOf:
+- type: 'null'
+- type: boolean
+- type: number
+- type: string
+- type: array
+- type: object
 additionalProperties: false
 required:
 - content

@@ -4888,7 +4974,19 @@ components:
 properties:
 content:
 $ref: '#/components/schemas/InterleavedContent'
+metadata:
+type: object
+additionalProperties:
+oneOf:
+- type: 'null'
+- type: boolean
+- type: number
+- type: string
+- type: array
+- type: object
 additionalProperties: false
+required:
+- metadata
 title: RAGQueryResult
 QueryChunksRequest:
 type: object

@@ -5205,6 +5303,22 @@ components:
 - vector_db_id
 - embedding_model
 title: RegisterVectorDbRequest
+ResumeAgentTurnRequest:
+type: object
+properties:
+tool_responses:
+type: array
+items:
+$ref: '#/components/schemas/ToolResponseMessage'
+description: >-
+The tool call responses to resume the turn with.
+stream:
+type: boolean
+description: Whether to stream the response.
+additionalProperties: false
+required:
+- tool_responses
+title: ResumeAgentTurnRequest
 RunEvalRequest:
 type: object
 properties:
@@ -86,10 +86,8 @@
 "# NBVAL_SKIP\n",
 "\n",
 "!apt-get install -y bubblewrap\n",
-"import os\n",
-"os.environ[\"UV_SYSTEM_PYTHON\"] = \"1\"\n",
 "!pip install uv\n",
-"!uv pip install llama-stack"
+"!uv pip install llama-stack --system"
 ]
 },
 {

@@ -128,7 +126,7 @@
 "source": [
 "# NBVAL_SKIP\n",
 "# This will build all the dependencies you will need\n",
-"!llama stack build --template together --image-type venv"
+"!llama stack build --template together --image-type venv --image-name __system__"
 ]
 },
 {

@@ -3632,7 +3630,7 @@
 "provenance": []
 },
 "kernelspec": {
-"display_name": "master",
+"display_name": "toolchain",
 "language": "python",
 "name": "python3"
 },
@@ -311,7 +311,7 @@
 ],
 "source": [
 "# NBVAL_SKIP\n",
-"!llama stack build --template together --image-type venv"
+"!llama stack build --template together --image-type venv --image-name __system__"
 ]
 },
 {
@@ -25,7 +25,7 @@ We are working on adding a few more APIs to complete the application lifecycle.
 ## API Providers
 
 The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
-- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, etc.),
+- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
 - Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
 - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)

@@ -33,7 +33,7 @@ Providers come in two flavors:
 - **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
 - **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.
 
-Most importantly, Llama Stack always strives to provide at least one fully "local" provider for each API so you can iterate on a fully featured environment locally.
+Most importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
 ## Resources
 
 Some of these APIs are associated with a set of **Resources**. Here is the mapping of APIs to resources:
@@ -15,7 +15,7 @@
 from docutils import nodes
 
 project = "llama-stack"
-copyright = "2024, Meta"
+copyright = "2025, Meta"
 author = "Meta"
 
 # -- General configuration ---------------------------------------------------
@@ -17,7 +17,7 @@ Which templates / distributions to choose depends on the hardware you have for r
 - {dockerhub}`distribution-nvidia` ([Guide](self_hosted_distro/nvidia))
 
 - **Are you running on a "regular" desktop or laptop ?** We suggest using the ollama template for quick prototyping and get started without having to worry about needing GPUs.
-  - {dockerhub}`distribution-ollama` ([link](self_hosted_distro/ollama))
+  - {dockerhub}`distribution-ollama` ([Guide](self_hosted_distro/ollama))
 
 - **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest:
   - {dockerhub}`distribution-together` ([Guide](self_hosted_distro/together))

@@ -28,7 +28,7 @@ Which templates / distributions to choose depends on the hardware you have for r
 - [Android](ondevice_distro/android_sdk)
 
 
-- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro).**
+- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro.md).**
 
 ### Distribution Details
 
@@ -8,7 +8,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::cerebras` |
+| inference | `remote::cerebras`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -19,7 +19,7 @@ The `llamastack/distribution-dell` distribution consists of the following provid
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::tgi` |
+| inference | `remote::tgi`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -18,7 +18,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::fireworks` |
+| inference | `remote::fireworks`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
 | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
-| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
+| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
 
 
 You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.
@@ -17,7 +17,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::vllm` |
+| inference | `remote::vllm`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -19,7 +19,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::tgi` |
+| inference | `remote::tgi`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -18,7 +18,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::together` |
+| inference | `remote::together`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -67,6 +67,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
 | **Provider** | **Environments** |
 | :----: | :----: |
 | FAISS | Single Node |
+| SQLite-Vec| Single Node |
 | Chroma | Hosted and Single Node |
 | Postgres (PGVector) | Hosted and Single Node |
 | Weaviate | Hosted |

@@ -88,6 +89,7 @@ self
 introduction/index
 getting_started/index
 concepts/index
+providers/index
 distributions/index
 distributions/selection
 building_applications/index
59  docs/source/providers/index.md  Normal file
@@ -0,0 +1,59 @@
+# Providers Overview
+
+The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
+- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
+- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
+- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
+
+Providers come in two flavors:
+- **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
+- **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.
+
+Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
+
+## Agents
+Run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc.
+
+## DatasetIO
+Interfaces with datasets and data loaders.
+
+## Eval
+Generates outputs (via Inference or Agents) and perform scoring.
+
+## Inference
+Runs inference with an LLM.
+
+## Post Training
+Fine-tunes a model.
+
+## Safety
+Applies safety policies to the output at a Systems (not only model) level.
+
+## Scoring
+Evaluates the outputs of the system.
+
+## Telemetry
+Collects telemetry data from the system.
+
+## Tool Runtime
+Is associated with the ToolGroup resouces.
+
+## Vector IO
+
+Vector IO refers to operations on vector databases, such as adding documents, searching, and deleting documents.
+Vector IO plays a crucial role in [Retreival Augmented Generation (RAG)](../..//building_applications/rag), where the vector
+io and database are used to store and retrieve documents for retrieval.
+
+#### Vector IO Providers
+The following providers (i.e., databases) are available for Vector IO:
+
+```{toctree}
+:maxdepth: 1
+
+vector_io/faiss
+vector_io/sqlite-vec
+vector_io/chromadb
+vector_io/pgvector
+vector_io/qdrant
+vector_io/weaviate
+```
36  docs/source/providers/vector_io/chromadb.md  Normal file
@@ -0,0 +1,36 @@
+---
+orphan: true
+---
+# Chroma
+
+[Chroma](https://www.trychroma.com/) is an inline and remote vector
+database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
+That means you're not limited to storing vectors in memory or in a separate service.
+
+## Features
+Chroma supports:
+- Store embeddings and their metadata
+- Vector search
+- Full-text search
+- Document storage
+- Metadata filtering
+- Multi-modal retrieval
+
+## Usage
+
+To use Chrome in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use chroma.
+3. Start storing and querying vectors.
+
+## Installation
+
+You can install chroma using pip:
+
+```bash
+pip install chromadb
+```
+
+## Documentation
+See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
33  docs/source/providers/vector_io/faiss.md  Normal file
@@ -0,0 +1,33 @@
+---
+orphan: true
+---
+# Faiss
+
+[Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It
+allows you to store and query vectors directly in memory.
+That means you'll get fast and efficient vector retrieval.
+
+## Features
+
+- Lightweight and easy to use
+- Fully integrated with Llama Stack
+- GPU support
+
+## Usage
+
+To use Faiss in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use Faiss.
+3. Start storing and querying vectors.
+
+## Installation
+
+You can install Faiss using pip:
+
+```bash
+pip install faiss-cpu
+```
+## Documentation
+See [Faiss' documentation](https://faiss.ai/) or the [Faiss Wiki](https://github.com/facebookresearch/faiss/wiki) for
+more details about Faiss in general.
31  docs/source/providers/vector_io/pgvector.md  Normal file
@@ -0,0 +1,31 @@
+---
+orphan: true
+---
+# Postgres PGVector
+
+[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
+allows you to store and query vectors directly in memory.
+That means you'll get fast and efficient vector retrieval.
+
+## Features
+
+- Easy to use
+- Fully integrated with Llama Stack
+
+## Usage
+
+To use PGVector in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use Faiss.
+3. Start storing and querying vectors.
+
+## Installation
+
+You can install PGVector using docker:
+
+```bash
+docker pull pgvector/pgvector:pg17
+```
+## Documentation
+See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
31  docs/source/providers/vector_io/qdrant.md  Normal file
@@ -0,0 +1,31 @@
+---
+orphan: true
+---
+# Qdrant
+
+[Qdrant](https://qdrant.tech/documentation/) is a remote vector database provider for Llama Stack. It
+allows you to store and query vectors directly in memory.
+That means you'll get fast and efficient vector retrieval.
+
+## Features
+
+- Easy to use
+- Fully integrated with Llama Stack
+
+## Usage
+
+To use Qdrant in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use Faiss.
+3. Start storing and querying vectors.
+
+## Installation
+
+You can install Qdrant using docker:
+
+```bash
+docker pull qdrant/qdrant
+```
+## Documentation
+See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
33  docs/source/providers/vector_io/sqlite-vec.md  Normal file
@@ -0,0 +1,33 @@
+---
+orphan: true
+---
+# SQLite-Vec
+
+[SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It
+allows you to store and query vectors directly within an SQLite database.
+That means you're not limited to storing vectors in memory or in a separate service.
+
+## Features
+
+- Lightweight and easy to use
+- Fully integrated with Llama Stack
+
+## Usage
+
+To use SQLite-Vec in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use SQLite-Vec.
+3. Start storing and querying vectors.
+
+## Installation
+
+You can install SQLite-Vec using pip:
+
+```bash
+pip install sqlite-vec
+```
+
+## Documentation
+
+See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.
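For a sense of what the sqlite-vec library does underneath this new provider, here is a rough, self-contained sketch of standalone sqlite-vec usage (outside Llama Stack, which wraps this for you). It assumes the `sqlite-vec` Python bindings installed via `pip install sqlite-vec`; the table name and vector values are made up, and the exact query syntax may differ between sqlite-vec versions:

```python
import sqlite3
import struct

import sqlite_vec  # assumption: the sqlite-vec Python package described in its quickstart

db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
sqlite_vec.load(db)          # load the sqlite-vec extension into this connection
db.enable_load_extension(False)

print("sqlite-vec version:", db.execute("SELECT vec_version()").fetchone()[0])

# A virtual table holding 4-dimensional float32 vectors.
db.execute("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4])")


def to_blob(vector):
    # Pack a Python list of floats into the little-endian float32 blob vec0 expects.
    return struct.pack(f"{len(vector)}f", *vector)


db.execute(
    "INSERT INTO vec_items(rowid, embedding) VALUES (?, ?)",
    (1, to_blob([0.1, 0.2, 0.3, 0.4])),
)

# K-nearest-neighbour query against the stored vectors.
rows = db.execute(
    "SELECT rowid, distance FROM vec_items WHERE embedding MATCH ? ORDER BY distance LIMIT 3",
    (to_blob([0.1, 0.2, 0.3, 0.4]),),
).fetchall()
print(rows)
```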
33  docs/source/providers/vector_io/weaviate.md  Normal file
@@ -0,0 +1,33 @@
+---
+orphan: true
+---
+# Weaviate
+
+[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
+It allows you to store and query vectors directly within a Weaviate database.
+That means you're not limited to storing vectors in memory or in a separate service.
+
+## Features
+Weaviate supports:
+- Store embeddings and their metadata
+- Vector search
+- Full-text search
+- Hybrid search
+- Document storage
+- Metadata filtering
+- Multi-modal retrieval
+
+## Usage
+
+To use Weaviate in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use chroma.
+3. Start storing and querying vectors.
+
+## Installation
+
+To install Weaviate see the [Weaviate quickstart documentation](https://weaviate.io/developers/weaviate/quickstart).
+
+## Documentation
+See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
@@ -171,7 +171,7 @@ The `llama model` command helps you explore the model’s interface.
 llama model --help
 ```
 ```
-usage: llama model [-h] {download,list,prompt-format,describe} ...
+usage: llama model [-h] {download,list,prompt-format,describe,verify-download,remove} ...
 
 Work with llama models
 

@@ -179,15 +179,15 @@ options:
 -h, --help show this help message and exit
 
 model_subcommands:
-  {download,list,prompt-format,describe}
+  {download,list,prompt-format,describe,verify-download,remove}
 ```
 
+### Describe
+
 You can use the describe command to know more about a model:
 ```
 llama model describe -m Llama3.2-3B-Instruct
 ```
-### Describe
 
 ```
 +-----------------------------+----------------------------------+
 | Model | Llama3.2-3B-Instruct |

@@ -234,3 +234,10 @@ llama model prompt-format -m Llama3.2-3B-Instruct
 You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
 
 **NOTE**: Outputs in terminal are color printed to show special tokens.
+
+### Remove model
+You can run `llama model remove` to remove unecessary model:
+
+```
+llama model remove -m Llama-Guard-3-8B-int8
+```
@@ -194,6 +194,7 @@ class AgentTurnResponseEventType(Enum):
     turn_start = "turn_start"
     turn_complete = "turn_complete"
+    turn_awaiting_input = "turn_awaiting_input"
 
 
 @json_schema_type

@@ -235,6 +236,14 @@ class AgentTurnResponseTurnCompletePayload(BaseModel):
     turn: Turn
 
 
+@json_schema_type
+class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
+        AgentTurnResponseEventType.turn_awaiting_input.value
+    )
+    turn: Turn
+
+
 AgentTurnResponseEventPayload = register_schema(
     Annotated[
         Union[

@@ -243,6 +252,7 @@ AgentTurnResponseEventPayload = register_schema(
             AgentTurnResponseStepCompletePayload,
             AgentTurnResponseTurnStartPayload,
             AgentTurnResponseTurnCompletePayload,
+            AgentTurnResponseTurnAwaitingInputPayload,
         ],
         Field(discriminator="event_type"),
     ],

@@ -286,6 +296,18 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     stream: Optional[bool] = False
     tool_config: Optional[ToolConfig] = None
 
+    # TODO (xiyan): temporary flag, will remove for 0.1.5
+    allow_turn_resume: Optional[bool] = False
+
+
+@json_schema_type
+class AgentTurnResumeRequest(BaseModel):
+    agent_id: str
+    session_id: str
+    turn_id: str
+    tool_responses: List[ToolResponseMessage]
+    stream: Optional[bool] = False
+
 
 @json_schema_type
 class AgentTurnResponseStreamChunk(BaseModel):

@@ -333,8 +355,34 @@ class Agents(Protocol):
         documents: Optional[List[Document]] = None,
         toolgroups: Optional[List[AgentToolGroup]] = None,
         tool_config: Optional[ToolConfig] = None,
+        allow_turn_resume: Optional[bool] = False,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
 
+    @webmethod(
+        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
+        method="POST",
+    )
+    async def resume_agent_turn(
+        self,
+        agent_id: str,
+        session_id: str,
+        turn_id: str,
+        tool_responses: List[ToolResponseMessage],
+        stream: Optional[bool] = False,
+    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        """Resume an agent turn with executed tool call responses.
+
+        When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
+
+        :param agent_id: The ID of the agent to resume.
+        :param session_id: The ID of the session to resume.
+        :param turn_id: The ID of the turn to resume.
+        :param tool_responses: The tool call responses to resume the turn with.
+        :param stream: Whether to stream the response.
+        :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
+        """
+        ...
+
     @webmethod(
         route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
         method="GET",
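The hunks above add `turn_awaiting_input` as a new member of the `AgentTurnResponseEventPayload` discriminated union, so streamed chunks can signal that a turn is paused waiting for client-side tool outputs. As a minimal, self-contained sketch of how that pattern behaves (assuming only `pydantic` v2 is installed; the two models below are simplified stand-ins, not the real llama-stack classes), parsing a chunk picks the payload class from the `event_type` field alone:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class TurnCompletePayload(BaseModel):
    # stand-in for AgentTurnResponseTurnCompletePayload
    event_type: Literal["turn_complete"] = "turn_complete"
    turn_id: str


class TurnAwaitingInputPayload(BaseModel):
    # stand-in for the new AgentTurnResponseTurnAwaitingInputPayload
    event_type: Literal["turn_awaiting_input"] = "turn_awaiting_input"
    turn_id: str


# Same shape as the AgentTurnResponseEventPayload union in the hunk above:
# the `event_type` literal acts as the discriminator.
Payload = Annotated[
    Union[TurnCompletePayload, TurnAwaitingInputPayload],
    Field(discriminator="event_type"),
]

chunk = {"event_type": "turn_awaiting_input", "turn_id": "turn-123"}
payload = TypeAdapter(Payload).validate_python(chunk)
print(type(payload).__name__)  # TurnAwaitingInputPayload
```

A client that observes this payload would then gather the pending tool outputs and submit them to the new `/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume` route via `resume_agent_turn`.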
@@ -91,15 +91,18 @@ ParamType = register_schema(
     name="ParamType",
 )
 
+"""
 # TODO: recursive definition of ParamType in these containers
 # will cause infinite recursion in OpenAPI generation script
 # since we are going with ChatCompletionInputType and CompletionInputType
 # we don't need to worry about ArrayType/ObjectType/UnionType for now
-# ArrayType.model_rebuild()
-# ObjectType.model_rebuild()
-# UnionType.model_rebuild()
+ArrayType.model_rebuild()
+ObjectType.model_rebuild()
+UnionType.model_rebuild()
 
 
-# class CustomType(BaseModel):
-#     type: Literal["custom"] = "custom"
-#     validator_class: str
+class CustomType(BaseModel):
+    pylint: disable=syntax-error
+    type: Literal["custom"] = "custom"
+    validator_class: str
+"""
@@ -165,6 +165,7 @@ class ToolResponse(BaseModel):
     call_id: str
     tool_name: Union[BuiltinTool, str]
     content: InterleavedContent
+    metadata: Optional[Dict[str, Any]] = None
 
     @field_validator("tool_name", mode="before")
     @classmethod
@@ -26,6 +26,7 @@ class RAGDocument(BaseModel):
 @json_schema_type
 class RAGQueryResult(BaseModel):
     content: Optional[InterleavedContent] = None
+    metadata: Dict[str, Any] = Field(default_factory=dict)
 
 
 @json_schema_type
@@ -72,6 +72,7 @@ class ToolInvocationResult(BaseModel):
     content: InterleavedContent
     error_message: Optional[str] = None
     error_code: Optional[int] = None
+    metadata: Optional[Dict[str, Any]] = None
 
 
 class ToolStore(Protocol):
@@ -19,6 +19,13 @@ def _get_model_size(model_dir):
     return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())
 
 
+def _convert_to_model_descriptor(model):
+    for m in all_registered_models():
+        if model == m.descriptor().replace(":", "-"):
+            return str(m.descriptor())
+    return str(model)
+
+
 def _run_model_list_downloaded_cmd() -> None:
     headers = ["Model", "Size", "Modified Time"]

@@ -30,7 +37,7 @@ def _run_model_list_downloaded_cmd() -> None:
         modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path)))
         rows.append(
             [
-                model,
+                _convert_to_model_descriptor(model),
                 model_size,
                 modified_time,
             ]

@@ -68,6 +75,13 @@ class ModelList(Subcommand):
             action="store_true",
             help="List the downloaded models",
         )
+        self.parser.add_argument(
+            "-s",
+            "--search",
+            type=str,
+            required=False,
+            help="Search for the input string as a substring in the model descriptor(ID)",
+        )
 
     def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
         from .safety_models import prompt_guard_model_sku

@@ -87,15 +101,19 @@ class ModelList(Subcommand):
                 continue
 
             descriptor = model.descriptor()
-            rows.append(
-                [
-                    descriptor,
-                    model.huggingface_repo,
-                    f"{model.max_seq_length // 1024}K",
-                ]
-            )
-        print_table(
-            rows,
-            headers,
-            separate_rows=True,
-        )
+            if not args.search or args.search.lower() in descriptor.lower():
+                rows.append(
+                    [
+                        descriptor,
+                        model.huggingface_repo,
+                        f"{model.max_seq_length // 1024}K",
+                    ]
+                )
+        if len(rows) == 0:
+            print(f"Did not find any model matching `{args.search}`.")
+        else:
+            print_table(
+                rows,
+                headers,
+                separate_rows=True,
+            )
@@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe
 from llama_stack.cli.model.download import ModelDownload
 from llama_stack.cli.model.list import ModelList
 from llama_stack.cli.model.prompt_format import ModelPromptFormat
+from llama_stack.cli.model.remove import ModelRemove
 from llama_stack.cli.model.verify_download import ModelVerifyDownload
 from llama_stack.cli.subcommand import Subcommand

@@ -35,3 +36,4 @@ class ModelParser(Subcommand):
         ModelPromptFormat.create(subparsers)
         ModelDescribe.create(subparsers)
         ModelVerifyDownload.create(subparsers)
+        ModelRemove.create(subparsers)
llama_stack/cli/model/remove.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import os
+import shutil
+
+from llama_stack.cli.subcommand import Subcommand
+from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
+from llama_stack.models.llama.sku_list import resolve_model
+
+
+class ModelRemove(Subcommand):
+    """Remove the downloaded llama model"""
+
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "remove",
+            prog="llama model remove",
+            description="Remove the downloaded llama model",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._run_model_remove_cmd)
+
+    def _add_arguments(self):
+        self.parser.add_argument(
+            "-m",
+            "--model",
+            required=True,
+            help="Specify the llama downloaded model name, see `llama model list --downloaded`",
+        )
+        self.parser.add_argument(
+            "-f",
+            "--force",
+            action="store_true",
+            help="Used to forcefully remove the llama model from the storage without further confirmation",
+        )
+
+    def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
+        from .safety_models import prompt_guard_model_sku
+
+        prompt_guard = prompt_guard_model_sku()
+        if args.model == prompt_guard.model_id:
+            model = prompt_guard
+        else:
+            model = resolve_model(args.model)
+
+        model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-"))
+
+        if model is None or not os.path.isdir(model_path):
+            print(f"'{args.model}' is not a valid llama model or does not exist.")
+            return
+
+        if args.force:
+            shutil.rmtree(model_path)
+            print(f"{args.model} removed.")
+        else:
+            if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y":
+                shutil.rmtree(model_path)
+                print(f"{args.model} removed.")
+            else:
+                print("Removal aborted.")
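
[Editor's note] With the file above, the CLI gains `llama model remove -m <model-descriptor>` (add `-f` to skip the confirmation prompt). The descriptor format matches what `llama model list --downloaded` prints, and the command only deletes the matching directory under the default checkpoint directory.
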
@@ -9,6 +9,7 @@ import importlib.resources
 import json
 import os
 import shutil
+import sys
 import textwrap
 from functools import lru_cache
 from pathlib import Path

@@ -23,10 +24,10 @@ from termcolor import cprint
 from llama_stack.cli.table import print_table
 from llama_stack.distribution.build import (
     SERVER_DEPENDENCIES,
-    ImageType,
     build_image,
     get_provider_dependencies,
 )
+from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import (
     BuildConfig,
     DistributionSpec,

@@ -37,6 +38,8 @@ from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
+from llama_stack.distribution.utils.exec import formulate_run_args, in_notebook, run_with_pty
+from llama_stack.distribution.utils.image_types import ImageType
 from llama_stack.providers.datatypes import Api
 
 TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"

@@ -59,8 +62,16 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
     if args.list_templates:
         return _run_template_list_cmd()
 
-    current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
-    image_name = args.image_name or current_conda_env
+    if args.image_type == "venv":
+        current_venv = os.environ.get("VIRTUAL_ENV")
+        image_name = args.image_name or current_venv
+        if not image_name and in_notebook():
+            image_name = "__system__"
+    elif args.image_type == "conda":
+        current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
+        image_name = args.image_name or current_conda_env
+    else:
+        image_name = args.image_name
 
     if args.template:
         available_templates = available_templates_specs()

@@ -69,7 +80,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
                 color="red",
             )
-            return
+            sys.exit(1)
         build_config = available_templates[args.template]
         if args.image_type:
             build_config.image_type = args.image_type

@@ -78,7 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 f"Please specify a image-type (container | conda | venv) for {args.template}",
                 color="red",
             )
-            return
+            sys.exit(1)
     elif not args.config and not args.template:
         name = prompt(
             "> Enter a name for your Llama Stack (e.g. my-local-stack): ",
@@ -159,14 +170,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 f"Could not parse config file {args.config}: {e}",
                 color="red",
             )
-            return
+            sys.exit(1)
 
         if build_config.image_type == ImageType.container.value and not args.image_name:
             cprint(
                 "Please specify --image-name when building a container from a config file",
                 color="red",
             )
-            return
+            sys.exit(1)
 
     if args.print_deps_only:
         print(f"# Dependencies for {args.template or args.config or image_name}")

@@ -177,19 +188,41 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             print(f"uv pip install {special_dep}")
         return
 
-    _run_stack_build_command_from_build_config(
-        build_config,
-        image_name=image_name,
-        config_path=args.config,
-        template_name=args.template,
-    )
+    try:
+        run_config = _run_stack_build_command_from_build_config(
+            build_config,
+            image_name=image_name,
+            config_path=args.config,
+            template_name=args.template,
+        )
+
+    except (Exception, RuntimeError) as exc:
+        cprint(
+            f"Error building stack: {exc}",
+            color="red",
+        )
+        sys.exit(1)
+    if run_config is None:
+        cprint(
+            "Run config path is empty",
+            color="red",
+        )
+        sys.exit(1)
+
+    if args.run:
+        run_config = Path(run_config)
+        config_dict = yaml.safe_load(run_config.read_text())
+        config = parse_and_maybe_upgrade_config(config_dict)
+        run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
+        run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
+        run_with_pty(run_args)
 
 
 def _generate_run_config(
     build_config: BuildConfig,
     build_dir: Path,
     image_name: str,
-) -> None:
+) -> str:
     """
     Generate a run.yaml template file for user to edit from a build.yaml file
     """

@@ -239,6 +272,7 @@ def _generate_run_config(
         f"You can now run your stack with `llama stack run {run_config_file}`",
         color="green",
     )
+    return run_config_file
 
 
 def _run_stack_build_command_from_build_config(
@@ -246,7 +280,7 @@ def _run_stack_build_command_from_build_config(
     image_name: Optional[str] = None,
     template_name: Optional[str] = None,
     config_path: Optional[str] = None,
-) -> None:
+) -> str:
     if build_config.image_type == ImageType.container.value:
         if template_name:
             image_name = f"distribution-{template_name}"

@@ -256,6 +290,9 @@ def _run_stack_build_command_from_build_config(
     elif build_config.image_type == ImageType.conda.value:
         if not image_name:
             raise ValueError("Please specify an image name when building a conda image")
+    elif build_config.image_type == ImageType.venv.value:
+        if not image_name:
+            raise ValueError("Please specify an image name when building a venv image")
 
     if template_name:
         build_dir = DISTRIBS_BASE_DIR / template_name

@@ -276,7 +313,7 @@ def _run_stack_build_command_from_build_config(
         template_or_config=template_name or config_path,
     )
     if return_code != 0:
-        return
+        raise RuntimeError(f"Failed to build image {image_name}")
 
     if template_name:
         # copy run.yaml from template to build_dir instead of generating it again

@@ -286,8 +323,9 @@ def _run_stack_build_command_from_build_config(
         shutil.copy(path, run_config_file)
 
         cprint("Build Successful!", color="green")
+        return template_path
     else:
-        _generate_run_config(build_config, build_dir, image_name)
+        return _generate_run_config(build_config, build_dir, image_name)
 
 
 def _run_template_list_cmd() -> None:
|
|
@ -68,6 +68,13 @@ the build. If not specified, currently active Conda environment will be used if
|
||||||
help="Print the dependencies for the stack only, without building the stack",
|
help="Print the dependencies for the stack only, without building the stack",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--run",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Run the stack after building using the same image type, name, and other applicable arguments",
|
||||||
|
)
|
||||||
|
|
||||||
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
||||||
# always keep implementation completely silo-ed away from CLI so CLI
|
# always keep implementation completely silo-ed away from CLI so CLI
|
||||||
# can be fast to load and reduces dependencies
|
# can be fast to load and reduces dependencies
|
||||||
|
|
|
||||||
|
|
@ -1,46 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
|
||||||
|
|
||||||
|
|
||||||
class StackConfigure(Subcommand):
|
|
||||||
"""Llama cli for configuring llama toolchain configs"""
|
|
||||||
|
|
||||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
|
||||||
super().__init__()
|
|
||||||
self.parser = subparsers.add_parser(
|
|
||||||
"configure",
|
|
||||||
prog="llama stack configure",
|
|
||||||
description="Configure a llama stack distribution",
|
|
||||||
formatter_class=argparse.RawTextHelpFormatter,
|
|
||||||
)
|
|
||||||
self._add_arguments()
|
|
||||||
self.parser.set_defaults(func=self._run_stack_configure_cmd)
|
|
||||||
|
|
||||||
def _add_arguments(self):
|
|
||||||
self.parser.add_argument(
|
|
||||||
"config",
|
|
||||||
type=str,
|
|
||||||
help="Path to the build config file (e.g. ~/.llama/builds/<image_type>/<name>-build.yaml). For container, this could also be the name of the container image. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.parser.add_argument(
|
|
||||||
"--output-dir",
|
|
||||||
type=str,
|
|
||||||
help="Path to the output directory to store generated run.yaml config file. If not specified, will use ~/.llama/build/<image_type>/<name>-run.yaml",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
|
|
||||||
self.parser.error(
|
|
||||||
"""
|
|
||||||
DEPRECATED! llama stack configure has been deprecated.
|
|
||||||
Please use llama stack run <path/to/run.yaml> instead.
|
|
||||||
Please see example run.yaml in /distributions folder.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
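
[Editor's note] Taken together, the two changes above replace the old two-step flow: the deprecated `llama stack configure` subcommand is removed entirely, while `llama stack build --template <name> --image-type <type> --run` now builds the image, writes the run config, and immediately starts the server on LLAMA_STACK_PORT (default 8321).
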
@@ -74,10 +74,6 @@ class StackRun(Subcommand):
         )
 
     def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
-        import importlib.resources
-        import json
-        import subprocess
-
         import yaml
         from termcolor import cprint

@@ -87,7 +83,7 @@ class StackRun(Subcommand):
             BUILDS_BASE_DIR,
             DISTRIBS_BASE_DIR,
         )
-        from llama_stack.distribution.utils.exec import run_with_pty
+        from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
 
         if not args.config:
             self.parser.error("Must specify a config file to run")

@@ -125,64 +121,7 @@ class StackRun(Subcommand):
         config_dict = yaml.safe_load(config_file.read_text())
         config = parse_and_maybe_upgrade_config(config_dict)
 
-        if args.image_type == ImageType.container.value or config.container_image:
-            script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
-            image_name = f"distribution-{template_name}" if template_name else config.container_image
-            run_args = [script, image_name]
-        elif args.image_type == ImageType.conda.value:
-            current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
-            image_name = args.image_name or current_conda_env
-            if not image_name:
-                cprint(
-                    "No current conda environment detected, please specify a conda environment name with --image-name",
-                    color="red",
-                )
-                return
-
-            def get_conda_prefix(env_name):
-                # Conda "base" environment does not end with "base" in the
-                # prefix, so should be handled separately.
-                if env_name == "base":
-                    return os.environ.get("CONDA_PREFIX")
-                # Get conda environments info
-                conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
-                envs = conda_env_info["envs"]
-                for envpath in envs:
-                    if envpath.endswith(env_name):
-                        return envpath
-                return None
-
-            print(f"Using conda environment: {image_name}")
-            conda_prefix = get_conda_prefix(image_name)
-            if not conda_prefix:
-                cprint(
-                    f"Conda environment {image_name} does not exist.",
-                    color="red",
-                )
-                return
-
-            build_file = Path(conda_prefix) / "llamastack-build.yaml"
-            if not build_file.exists():
-                cprint(
-                    f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
-                    color="red",
-                )
-                return
-
-            script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
-            run_args = [
-                script,
-                image_name,
-            ]
-        else:
-            # else must be venv since that is the only valid option left.
-            current_venv = os.environ.get("VIRTUAL_ENV")
-            venv = args.image_name or current_venv
-            script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
-            run_args = [
-                script,
-                venv,
-            ]
+        run_args = formulate_run_args(args.image_type, args.image_name, config, template_name)
 
         run_args.extend([str(config_file), str(args.port)])
         if args.disable_ipv6:

@@ -206,5 +145,4 @@ class StackRun(Subcommand):
 
         if args.tls_keyfile and args.tls_certfile:
             run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
-
         run_with_pty(run_args)
@@ -10,7 +10,6 @@ from importlib.metadata import version
 from llama_stack.cli.subcommand import Subcommand
 
 from .build import StackBuild
-from .configure import StackConfigure
 from .list_apis import StackListApis
 from .list_providers import StackListProviders
 from .run import StackRun

@@ -37,7 +36,6 @@ class StackParser(Subcommand):
 
         # Add sub-commands
         StackBuild.create(subparsers)
-        StackConfigure.create(subparsers)
         StackListApis.create(subparsers)
         StackListProviders.create(subparsers)
         StackRun.create(subparsers)
@@ -7,7 +7,6 @@
 import importlib.resources
 import logging
 import sys
-from enum import Enum
 from pathlib import Path
 from typing import Dict, List

@@ -18,6 +17,7 @@ from llama_stack.distribution.datatypes import BuildConfig, Provider
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
 from llama_stack.distribution.utils.exec import run_command, run_with_pty
+from llama_stack.distribution.utils.image_types import ImageType
 from llama_stack.providers.datatypes import Api
 
 log = logging.getLogger(__name__)

@@ -33,12 +33,6 @@ SERVER_DEPENDENCIES = [
 ]
 
 
-class ImageType(Enum):
-    container = "container"
-    conda = "conda"
-    venv = "venv"
-
-
 class ApiInput(BaseModel):
     api: Api
     provider: str
@@ -8,6 +8,8 @@
 
 LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
 LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
+LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
 
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
 PYPI_VERSION=${PYPI_VERSION:-}
 BUILD_PLATFORM=${BUILD_PLATFORM:-}

@@ -32,7 +34,7 @@ container_base="$3"
 build_file_path="$4"
 host_build_dir="$5"
 pip_dependencies="$6"
-special_pip_deps="$7"
+special_pip_deps="${7:-}"
 
 
 # Define color codes

@@ -106,26 +108,39 @@ fi
 
 stack_mount="/app/llama-stack-source"
 models_mount="/app/llama-models-source"
+client_mount="/app/llama-stack-client-source"
 
-if [ -n "$LLAMA_STACK_DIR" ]; then
-  if [ ! -d "$LLAMA_STACK_DIR" ]; then
-    echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2
+install_local_package() {
+  local dir="$1"
+  local mount_point="$2"
+  local name="$3"
+
+  if [ ! -d "$dir" ]; then
+    echo "${RED}Warning: $name is set but directory does not exist: $dir${NC}" >&2
     exit 1
   fi
 
-  # Install in editable format. We will mount the source code into the container
-  # so that changes will be reflected in the container without having to do a
-  # rebuild. This is just for development convenience.
-
   if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
     add_to_container << EOF
-COPY $LLAMA_STACK_DIR $stack_mount
+COPY $dir $mount_point
EOF
   fi
 
   add_to_container << EOF
-RUN uv pip install --no-cache -e $stack_mount
+RUN uv pip install --no-cache -e $mount_point
EOF
+}
+
+if [ -n "$LLAMA_MODELS_DIR" ]; then
+  install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR"
+fi
+
+if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
+  install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
+fi
+
+if [ -n "$LLAMA_STACK_DIR" ]; then
+  install_local_package "$LLAMA_STACK_DIR" "$stack_mount" "LLAMA_STACK_DIR"
 else
   if [ -n "$TEST_PYPI_VERSION" ]; then
     # these packages are damaged in test-pypi, so install them first

@@ -134,6 +149,7 @@ RUN uv pip install fastapi libcst
EOF
   add_to_container << EOF
 RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
+  --index-strategy unsafe-best-match \
   llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
 
EOF

@@ -149,23 +165,6 @@ EOF
   fi
 fi
 
-if [ -n "$LLAMA_MODELS_DIR" ]; then
-  if [ ! -d "$LLAMA_MODELS_DIR" ]; then
-    echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
-    exit 1
-  fi
-
-  if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
-    add_to_container << EOF
-COPY $LLAMA_MODELS_DIR $models_mount
-EOF
-  fi
-  add_to_container << EOF
-RUN uv pip uninstall llama-models
-RUN uv pip install --no-cache $models_mount
-EOF
-fi
-
 # if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
 if [[ "$template_or_config" != *.yaml ]]; then
   add_to_container << EOF

@@ -198,6 +197,9 @@ if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
   if [ -n "$LLAMA_MODELS_DIR" ]; then
     mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
   fi
+  if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
+    mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
+  fi
 fi
 
 if command -v selinuxenabled &>/dev/null && selinuxenabled; then
@@ -16,6 +16,7 @@ TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
 # Reference: https://github.com/astral-sh/uv/pull/1694
 UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
 UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
+VIRTUAL_ENV=${VIRTUAL_ENV:-}
 
 if [ -n "$LLAMA_STACK_DIR" ]; then
   echo "Using llama-stack-dir=$LLAMA_STACK_DIR"

@@ -24,8 +25,8 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
   echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
 fi
 
-if [ "$#" -lt 3 ]; then
-  echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies> [<special_pip_deps>]" >&2
+if [ "$#" -lt 2 ]; then
+  echo "Usage: $0 <distribution_type> <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
   echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
   exit 1
 fi

@@ -34,8 +35,7 @@ special_pip_deps="$3"
 
 set -euo pipefail
 
-build_name="$1"
-env_name="llamastack-$build_name"
+env_name="$1"
 pip_dependencies="$2"
 
 # Define color codes

@@ -74,9 +74,13 @@ run() {
   local env_name="$1"
   local pip_dependencies="$2"
   local special_pip_deps="$3"
 
-  if [ -n "$UV_SYSTEM_PYTHON" ]; then
+  if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
     echo "Installing dependencies in system Python environment"
+    # if env == __system__, ensure we set UV_SYSTEM_PYTHON
+    export UV_SYSTEM_PYTHON=1
+  elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
+    echo "Virtual environment $env_name is already active"
   else
     echo "Using virtual environment $env_name"
     uv venv "$env_name"

@@ -90,6 +94,7 @@ run() {
     # shellcheck disable=SC2086
     # we are building a command line so word splitting is expected
     uv pip install --extra-index-url https://test.pypi.org/simple/ \
+      --index-strategy unsafe-best-match \
       llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
       $pip_dependencies
     if [ -n "$special_pip_deps" ]; then
@@ -41,6 +41,7 @@ from llama_stack.distribution.stack import (
     redact_sensitive_fields,
     replace_env_vars,
 )
+from llama_stack.distribution.utils.exec import in_notebook
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
     setup_logger,

@@ -52,19 +53,6 @@ logger = logging.getLogger(__name__)
 T = TypeVar("T")
 
 
-def in_notebook():
-    try:
-        from IPython import get_ipython
-
-        if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
-            return False
-    except ImportError:
-        return False
-    except AttributeError:
-        return False
-    return True
-
-
 def convert_pydantic_to_json_value(value: Any) -> Any:
     if isinstance(value, Enum):
         return value.value

@@ -230,12 +218,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if Api.telemetry in self.impls:
             setup_logger(self.impls[Api.telemetry])
 
-        console = Console()
-        console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
-        # Redact sensitive information before printing
-        safe_config = redact_sensitive_fields(self.config.model_dump())
-        console.print(yaml.dump(safe_config, indent=2))
+        if not os.environ.get("PYTEST_CURRENT_TEST"):
+            console = Console()
+            console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
+            safe_config = redact_sensitive_fields(self.config.model_dump())
+            console.print(yaml.dump(safe_config, indent=2))
 
         endpoints = get_all_api_endpoints()
         endpoint_impls = {}
@@ -52,6 +52,7 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 
 
 class VectorIORouter(VectorIO):

@@ -158,6 +159,8 @@ class InferenceRouter(Inference):
             params["tool_prompt_format"] = tool_prompt_format
         tool_config = ToolConfig(**params)
 
+        tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
+
         tools = tools or []
         if tool_config.tool_choice == ToolChoice.none:
             tools = []
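
[Editor's note] The router change above means callers can omit tool_prompt_format entirely: when it is unset, the InferenceRouter now fills it in from get_default_tool_prompt_format(model_id) before dispatching the chat completion.
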
@@ -1,67 +0,0 @@
-#!/bin/bash
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-set -euo pipefail
-
-RED='\033[0;31m'
-NC='\033[0m' # No Color
-
-error_handler() {
-  echo "Error occurred in script at line: ${1}" >&2
-  exit 1
-}
-
-trap 'error_handler ${LINENO}' ERR
-
-if [ $# -lt 3 ]; then
-  echo "Usage: $0 <build_name> <yaml_config> <port> <script_args...>"
-  exit 1
-fi
-
-env_name="$1"
-shift
-
-yaml_config="$1"
-shift
-
-port="$1"
-shift
-
-# Process environment variables from --env arguments
-env_vars=""
-other_args=""
-while [[ $# -gt 0 ]]; do
-  case "$1" in
-    --env)
-
-      if [[ -n "$2" ]]; then
-        # collect environment variables so we can set them after activating the conda env
-        env_vars="$env_vars --env $2"
-        shift 2
-      else
-        echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
-        exit 1
-      fi
-      ;;
-    *)
-      other_args="$other_args $1"
-      shift
-      ;;
-  esac
-done
-
-eval "$(conda shell.bash hook)"
-conda deactivate && conda activate "$env_name"
-
-set -x
-$CONDA_PREFIX/bin/python \
-  -m llama_stack.distribution.server.server \
-  --yaml-config "$yaml_config" \
-  --port "$port" \
-  $env_vars \
-  $other_args
@@ -1,105 +0,0 @@
-#!/bin/bash
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
-CONTAINER_OPTS=${CONTAINER_OPTS:-}
-LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
-LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
-TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
-PYPI_VERSION=${PYPI_VERSION:-}
-
-set -euo pipefail
-
-RED='\033[0;31m'
-NC='\033[0m' # No Color
-
-error_handler() {
-  echo "Error occurred in script at line: ${1}" >&2
-  exit 1
-}
-
-trap 'error_handler ${LINENO}' ERR
-
-if [ $# -lt 3 ]; then
-  echo "Usage: $0 <build_name> <yaml_config> <port> <other_args...>"
-  exit 1
-fi
-
-image_name="$1"
-container_image="localhost/$image_name"
-shift
-
-yaml_config="$1"
-shift
-
-port="$1"
-shift
-
-# Initialize other_args
-other_args=""
-
-# Process environment variables from --env arguments
-env_vars=""
-
-while [[ $# -gt 0 ]]; do
-  case "$1" in
-    --env)
-      echo "env = $2"
-      if [[ -n "$2" ]]; then
-        env_vars="$env_vars -e $2"
-        shift 2
-      else
-        echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
-        exit 1
-      fi
-      ;;
-    *)
-      other_args="$other_args $1"
-      shift
-      ;;
-  esac
-done
-
-set -x
-
-if command -v selinuxenabled &> /dev/null && selinuxenabled; then
-  # Disable SELinux labels
-  CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
-fi
-
-mounts=""
-if [ -n "$LLAMA_STACK_DIR" ]; then
-  mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
-fi
-if [ -n "$LLAMA_CHECKPOINT_DIR" ]; then
-  mounts="$mounts -v $LLAMA_CHECKPOINT_DIR:/root/.llama"
-  CONTAINER_OPTS="$CONTAINER_OPTS --gpus=all"
-fi
-
-if [ -n "$PYPI_VERSION" ]; then
-  version_tag="$PYPI_VERSION"
-elif [ -n "$LLAMA_STACK_DIR" ]; then
-  version_tag="dev"
-elif [ -n "$TEST_PYPI_VERSION" ]; then
-  version_tag="test-$TEST_PYPI_VERSION"
-else
-  URL="https://pypi.org/pypi/llama-stack/json"
-  version_tag=$(curl -s $URL | jq -r '.info.version')
-fi
-
-$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-  -p $port:$port \
-  $env_vars \
-  -v "$yaml_config:/app/config.yaml" \
-  $mounts \
-  --env LLAMA_STACK_PORT=$port \
-  --entrypoint python \
-  $container_image:$version_tag \
-  -m llama_stack.distribution.server.server \
-  --yaml-config /app/config.yaml \
-  $other_args

llama_stack/distribution/start_stack.sh (new executable file, 150 lines)
@@ -0,0 +1,150 @@
+#!/usr/bin/env bash
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
+CONTAINER_OPTS=${CONTAINER_OPTS:-}
+LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
+LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
+TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
+PYPI_VERSION=${PYPI_VERSION:-}
+
+set -euo pipefail
+
+RED='\033[0;31m'
+NC='\033[0m' # No Color
+
+error_handler() {
+  echo "Error occurred in script at line: ${1}" >&2
+  exit 1
+}
+
+trap 'error_handler ${LINENO}' ERR
+
+if [ $# -lt 3 ]; then
+  echo "Usage: $0 <env_type> <env_path_or_name> <yaml_config> <port> <script_args...>"
+  exit 1
+fi
+
+env_type="$1"
+shift
+
+env_path_or_name="$1"
+container_image="localhost/$env_path_or_name"
+shift
+
+yaml_config="$1"
+shift
+
+port="$1"
+shift
+
+SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
+source "$SCRIPT_DIR/common.sh"
+
+# Initialize env_vars as an string
+env_vars=""
+other_args=""
+# Process environment variables from --env arguments
+while [[ $# -gt 0 ]]; do
+  case "$1" in
+    --env)
+
+      if [[ -n "$2" ]]; then
+        env_vars="$env_vars --env $2"
+        shift 2
+      else
+        echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
+        exit 1
+      fi
+      ;;
+    *)
+      other_args="$other_args $1"
+      shift
+      ;;
+  esac
+done
+
+PYTHON_BINARY="python"
+case "$env_type" in
+  "venv")
+    # Activate virtual environment
+    if [ ! -d "$env_path_or_name" ]; then
+      echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
+      exit 1
+    fi
+
+    if [ ! -f "$env_path_or_name/bin/activate" ]; then
+      echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
+      exit 1
+    fi
+
+    source "$env_path_or_name/bin/activate"
+    ;;
+  "conda")
+    if ! is_command_available conda; then
+      echo -e "${RED}Error: conda not found" >&2
+      exit 1
+    fi
+    eval "$(conda shell.bash hook)"
+    conda deactivate && conda activate "$env_path_or_name"
+    PYTHON_BINARY="$CONDA_PREFIX/bin/python"
+    ;;
+  *)
+esac
+
+set -x
+
+if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
+  $PYTHON_BINARY -m llama_stack.distribution.server.server \
+    --yaml-config "$yaml_config" \
+    --port "$port" \
+    $env_vars \
+    $other_args
+elif [[ "$env_type" == "container" ]]; then
+  if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
+    # Disable SELinux labels
+    CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
+  fi
+
+  mounts=""
+  if [ -n "$LLAMA_STACK_DIR" ]; then
+    mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
+  fi
+  if [ -n "$LLAMA_CHECKPOINT_DIR" ]; then
+    mounts="$mounts -v $LLAMA_CHECKPOINT_DIR:/root/.llama"
+    CONTAINER_OPTS="$CONTAINER_OPTS --gpus=all"
+  fi
+
+  if [ -n "$PYPI_VERSION" ]; then
+    version_tag="$PYPI_VERSION"
+  elif [ -n "$LLAMA_STACK_DIR" ]; then
+    version_tag="dev"
+  elif [ -n "$TEST_PYPI_VERSION" ]; then
+    version_tag="test-$TEST_PYPI_VERSION"
+  else
+    if ! is_command_available jq; then
+      echo -e "${RED}Error: jq not found" >&2
+      exit 1
+    fi
+    URL="https://pypi.org/pypi/llama-stack/json"
+    version_tag=$(curl -s $URL | jq -r '.info.version')
+  fi
+
+  $CONTAINER_BINARY run $CONTAINER_OPTS -it \
+    -p $port:$port \
+    $env_vars \
+    -v "$yaml_config:/app/config.yaml" \
+    $mounts \
+    --env LLAMA_STACK_PORT=$port \
+    --entrypoint python \
+    $container_image:$version_tag \
+    -m llama_stack.distribution.server.server \
+    --yaml-config /app/config.yaml \
+    $other_args
+fi
@@ -55,6 +55,7 @@ while [[ $# -gt 0 ]]; do
   esac
 done
 
+echo "Using virtual environment: $venv_path"
 # Activate virtual environment
 if [ ! -d "$venv_path" ]; then
   echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
@@ -12,8 +12,81 @@ import signal
 import subprocess
 import sys
 
+from termcolor import cprint
+
 log = logging.getLogger(__name__)
 
+import importlib
+import json
+from pathlib import Path
+
+from llama_stack.distribution.utils.image_types import ImageType
+
+
+def formulate_run_args(image_type, image_name, config, template_name) -> list:
+    env_name = ""
+    if image_type == ImageType.container.value or config.container_image:
+        env_name = f"distribution-{template_name}" if template_name else config.container_image
+    elif image_type == ImageType.conda.value:
+        current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
+        env_name = image_name or current_conda_env
+        if not env_name:
+            cprint(
+                "No current conda environment detected, please specify a conda environment name with --image-name",
+                color="red",
+            )
+            return
+
+        def get_conda_prefix(env_name):
+            # Conda "base" environment does not end with "base" in the
+            # prefix, so should be handled separately.
+            if env_name == "base":
+                return os.environ.get("CONDA_PREFIX")
+            # Get conda environments info
+            conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
+            envs = conda_env_info["envs"]
+            for envpath in envs:
+                if envpath.endswith(env_name):
+                    return envpath
+            return None
+
+        print(f"Using conda environment: {env_name}")
+        conda_prefix = get_conda_prefix(env_name)
+        if not conda_prefix:
+            cprint(
+                f"Conda environment {env_name} does not exist.",
+                color="red",
+            )
+            return
+
+        build_file = Path(conda_prefix) / "llamastack-build.yaml"
+        if not build_file.exists():
+            cprint(
+                f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
+                color="red",
+            )
+            return
+    else:
+        # else must be venv since that is the only valid option left.
+        current_venv = os.environ.get("VIRTUAL_ENV")
+        env_name = image_name or current_venv
+        if not env_name:
+            cprint(
+                "No current virtual environment detected, please specify a virtual environment name with --image-name",
+                color="red",
+            )
+            return
+        print(f"Using virtual environment: {env_name}")
+
+    script = importlib.resources.files("llama_stack") / "distribution/start_stack.sh"
+    run_args = [
+        script,
+        image_type,
+        env_name,
+    ]
+
+    return run_args
+
+
 def run_with_pty(command):
     if sys.platform.startswith("win"):

@@ -22,6 +95,19 @@ def run_with_pty(command):
     return _run_with_pty_unix(command)
 
 
+def in_notebook():
+    try:
+        from IPython import get_ipython
+
+        if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
+            return False
+    except ImportError:
+        return False
+    except AttributeError:
+        return False
+    return True
+
+
 # run a command in a pseudo-terminal, with interrupt handling,
 # useful when you want to run interactive things
 def _run_with_pty_unix(command):
llama_stack/distribution/utils/image_types.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+
+
+class ImageType(Enum):
+    container = "container"
+    conda = "conda"
+    venv = "venv"
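
[Editor's note] The new ImageType enum gives the build and run paths a single source of truth for the image-type strings. A small sketch of how a caller might normalize user input with it; the helper below is illustrative and not part of the diff.

    from llama_stack.distribution.utils.image_types import ImageType

    def parse_image_type(value: str) -> ImageType:
        try:
            return ImageType(value)  # accepts "container", "conda", or "venv"
        except ValueError:
            raise ValueError(f"image type must be one of: {[t.value for t in ImageType]}")
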
@@ -30,8 +30,10 @@ from llama_stack.apis.agents import (
     AgentTurnResponseStepProgressPayload,
     AgentTurnResponseStepStartPayload,
     AgentTurnResponseStreamChunk,
+    AgentTurnResponseTurnAwaitingInputPayload,
     AgentTurnResponseTurnCompletePayload,
     AgentTurnResponseTurnStartPayload,
+    AgentTurnResumeRequest,
     Attachment,
     Document,
     InferenceStep,
@@ -60,9 +62,13 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
+from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
-from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
+from llama_stack.models.llama.datatypes import (
+    BuiltinTool,
+    ToolCall,
+    ToolParamDefinition,
+)
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
@@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)

+    async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
+        messages = []
+        if self.agent_config.instructions != "":
+            messages.append(SystemMessage(content=self.agent_config.instructions))
+
+        for turn in turns:
+            messages.extend(self.turn_to_messages(turn))
+        return messages
+
     async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
         with tracing.span("create_and_execute_turn") as span:
             span.set_attribute("session_id", request.session_id)
@@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
                 raise ValueError(f"Session {request.session_id} not found")

             turns = await self.storage.get_session_turns(request.session_id)
-            messages = []
-            if self.agent_config.instructions != "":
-                messages.append(SystemMessage(content=self.agent_config.instructions))
-
-            for i, turn in enumerate(turns):
-                messages.extend(self.turn_to_messages(turn))
-
+            messages = await self.get_messages_from_turns(turns)
             messages.extend(request.messages)

             turn_id = str(uuid.uuid4())
@@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin):
             )
             await self.storage.add_turn_to_session(request.session_id, turn)

-            chunk = AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseTurnCompletePayload(
-                        turn=turn,
-                    )
-                )
-            )
+            if output_message.tool_calls and request.allow_turn_resume:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnAwaitingInputPayload(
+                            turn=turn,
+                        )
+                    )
+                )
+            else:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnCompletePayload(
+                            turn=turn,
+                        )
+                    )
+                )
+
+            yield chunk
+
+    async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
+        with tracing.span("resume_turn") as span:
+            span.set_attribute("agent_id", self.agent_id)
+            span.set_attribute("session_id", request.session_id)
+            span.set_attribute("turn_id", request.turn_id)
+            span.set_attribute("request", request.model_dump_json())
+            assert request.stream is True, "Non-streaming not supported"
+
+            session_info = await self.storage.get_session_info(request.session_id)
+            if session_info is None:
+                raise ValueError(f"Session {request.session_id} not found")
+
+            turns = await self.storage.get_session_turns(request.session_id)
+            messages = await self.get_messages_from_turns(turns)
+            messages.extend(request.tool_responses)
+
+            last_turn_messages = [
+                x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
+            ]
+
+            # get the steps from the turn id
+            steps = []
+            if len(turns) > 0:
+                steps = turns[-1].steps
+
+            # mark tool execution step as complete
+            # if there's no tool execution in progress step (due to storage, or tool call parsing on client),
+            # we'll create a new tool execution step with current time
+            in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
+                request.session_id, request.turn_id
+            )
+            now = datetime.now()
+            tool_execution_step = ToolExecutionStep(
+                step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
+                turn_id=request.turn_id,
+                tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
+                tool_responses=[
+                    ToolResponse(
+                        call_id=x.call_id,
+                        tool_name=x.tool_name,
+                        content=x.content,
+                    )
+                    for x in request.tool_responses
+                ],
+                completed_at=now,
+                started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
+            )
+            steps.append(tool_execution_step)
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepCompletePayload(
+                        step_type=StepType.tool_execution.value,
+                        step_id=tool_execution_step.step_id,
+                        step_details=tool_execution_step,
+                    )
+                )
+            )
+
+            output_message = None
+            async for chunk in self.run(
+                session_id=request.session_id,
+                turn_id=request.turn_id,
+                input_messages=messages,
+                sampling_params=self.agent_config.sampling_params,
+                stream=request.stream,
+            ):
+                if isinstance(chunk, CompletionMessage):
+                    output_message = chunk
+                    continue
+
+                assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
+                event = chunk.event
+                if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
+                    steps.append(event.payload.step_details)
+
+                yield chunk
+
+            assert output_message is not None
+
+            last_turn_start_time = datetime.now()
+            if len(turns) > 0:
+                last_turn_start_time = turns[-1].started_at
+
+            turn = Turn(
+                turn_id=request.turn_id,
+                session_id=request.session_id,
+                input_messages=last_turn_messages,
+                output_message=output_message,
+                started_at=last_turn_start_time,
+                completed_at=datetime.now(),
+                steps=steps,
+            )
+            await self.storage.add_turn_to_session(request.session_id, turn)
+
+            if output_message.tool_calls:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnAwaitingInputPayload(
+                            turn=turn,
+                        )
+                    )
+                )
+            else:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnCompletePayload(
+                            turn=turn,
+                        )
+                    )
+                )

             yield chunk

     async def run(
@@ -456,6 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
                         call_id="",
                         tool_name=MEMORY_QUERY_TOOL,
                         content=retrieved_context or [],
+                        metadata=result.metadata,
                     )
                 ],
             ),
@@ -611,11 +743,7 @@ class ChatAgent(ShieldRunnerMixin):
                 input_messages = input_messages + [message]
             else:
                 log.info(f"{str(message)}")
-                tool_call = message.tool_calls[0]
-                if tool_call.tool_name in client_tools:
-                    yield message
-                    return
-
+                # 1. Start the tool execution step and progress
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
@@ -625,6 +753,7 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
+                tool_call = message.tool_calls[0]
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(
@@ -639,6 +768,23 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )

+                # If tool is a client tool, yield CompletionMessage and return
+                if tool_call.tool_name in client_tools:
+                    await self.storage.set_in_progress_tool_call_step(
+                        session_id,
+                        turn_id,
+                        ToolExecutionStep(
+                            step_id=step_id,
+                            turn_id=turn_id,
+                            tool_calls=[tool_call],
+                            tool_responses=[],
+                            started_at=datetime.now(),
+                        ),
+                    )
+                    yield message
+                    return
+
+                # If tool is a builtin server tool, execute it
                 tool_name = tool_call.tool_name
                 if isinstance(tool_name, BuiltinTool):
                     tool_name = tool_name.value
@@ -650,13 +796,21 @@ class ChatAgent(ShieldRunnerMixin):
                     },
                 ) as span:
                     tool_execution_start_time = datetime.now()
-                    result_messages = await execute_tool_call_maybe(
+                    tool_call = message.tool_calls[0]
+                    tool_result = await execute_tool_call_maybe(
                         self.tool_runtime_api,
                         session_id,
-                        [message],
+                        tool_call,
                         toolgroup_args,
                         tool_to_group,
                     )
+                    result_messages = [
+                        ToolResponseMessage(
+                            call_id=tool_call.call_id,
+                            tool_name=tool_call.tool_name,
+                            content=tool_result.content,
+                        )
+                    ]
                     assert len(result_messages) == 1, "Currently not supporting multiple messages"
                     result_message = result_messages[0]
                     span.set_attribute("output", result_message.model_dump_json())
@@ -675,6 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
                                 call_id=result_message.call_id,
                                 tool_name=result_message.tool_name,
                                 content=result_message.content,
+                                metadata=tool_result.metadata,
                             )
                         ],
                         started_at=tool_execution_start_time,
@@ -913,19 +1068,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
 async def execute_tool_call_maybe(
     tool_runtime_api: ToolRuntime,
     session_id: str,
-    messages: List[CompletionMessage],
+    tool_call: ToolCall,
     toolgroup_args: Dict[str, Dict[str, Any]],
     tool_to_group: Dict[str, str],
-) -> List[ToolResponseMessage]:
-    # While Tools.run interface takes a list of messages,
-    # All tools currently only run on a single message
-    # When this changes, we can drop this assert
-    # Whether to call tools on each message and aggregate
-    # or aggregate and call tool once, reamins to be seen.
-    assert len(messages) == 1, "Expected single message"
-    message = messages[0]
-
-    tool_call = message.tool_calls[0]
+) -> ToolInvocationResult:
     name = tool_call.tool_name
     group_name = tool_to_group.get(name, None)
     if group_name is None:
@@ -946,14 +1092,7 @@ async def execute_tool_call_maybe(
             **tool_call_args,
         ),
     )
-    return [
-        ToolResponseMessage(
-            call_id=tool_call.call_id,
-            tool_name=tool_call.tool_name,
-            content=result.content,
-        )
-    ]
+    return result


 def _interpret_content_as_attachment(
@@ -11,8 +11,6 @@ import tempfile
 import uuid
 from typing import AsyncGenerator, List, Optional, Union

-from termcolor import colored
-
 from llama_stack.apis.agents import (
     AgentConfig,
     AgentCreateResponse,
@@ -21,6 +19,7 @@ from llama_stack.apis.agents import (
     AgentStepResponse,
     AgentToolGroup,
     AgentTurnCreateRequest,
+    AgentTurnResumeRequest,
     Document,
     Session,
     Turn,
@@ -68,12 +67,7 @@ class MetaReferenceAgentsImpl(Agents):

         # check if "bwrap" is available
         if not shutil.which("bwrap"):
-            print(
-                colored(
-                    "Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
-                    "yellow",
-                )
-            )
+            logger.warning("Warning: `bwrap` is not available. Code interpreter tool will not work correctly.")

     async def create_agent(
         self,
@@ -146,6 +140,7 @@ class MetaReferenceAgentsImpl(Agents):
         documents: Optional[List[Document]] = None,
         stream: Optional[bool] = False,
         tool_config: Optional[ToolConfig] = None,
+        allow_turn_resume: Optional[bool] = False,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
@@ -155,6 +150,7 @@ class MetaReferenceAgentsImpl(Agents):
             toolgroups=toolgroups,
             documents=documents,
             tool_config=tool_config,
+            allow_turn_resume=allow_turn_resume,
         )
         if stream:
             return self._create_agent_turn_streaming(request)
@@ -169,6 +165,34 @@ class MetaReferenceAgentsImpl(Agents):
         async for event in agent.create_and_execute_turn(request):
             yield event

+    async def resume_agent_turn(
+        self,
+        agent_id: str,
+        session_id: str,
+        turn_id: str,
+        tool_responses: List[ToolResponseMessage],
+        stream: Optional[bool] = False,
+    ) -> AsyncGenerator:
+        request = AgentTurnResumeRequest(
+            agent_id=agent_id,
+            session_id=session_id,
+            turn_id=turn_id,
+            tool_responses=tool_responses,
+            stream=stream,
+        )
+        if stream:
+            return self._continue_agent_turn_streaming(request)
+        else:
+            raise NotImplementedError("Non-streaming agent turns not yet implemented")
+
+    async def _continue_agent_turn_streaming(
+        self,
+        request: AgentTurnResumeRequest,
+    ) -> AsyncGenerator:
+        agent = await self.get_agent(request.agent_id)
+        async for event in agent.resume_turn(request):
+            yield event
+
     async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
         turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
         turn = json.loads(turn)
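Note (reviewer sketch, not part of the diff): with allow_turn_resume=True, a turn that ends in a client tool call is left awaiting input and the caller finishes it through resume_agent_turn. A rough illustration against the MetaReferenceAgentsImpl methods added above; agents_impl, the ids, and the tool result are placeholders, and only the field names taken from this diff are assumed to be real:

    from llama_stack.apis.inference import ToolResponseMessage

    async def finish_client_tool_turn(agents_impl, agent_id, session_id, turn_id, call_id, tool_name, result_text):
        # Feed the locally executed client-tool result back into the paused turn
        # and stream the remaining turn events.
        stream = await agents_impl.resume_agent_turn(
            agent_id=agent_id,
            session_id=session_id,
            turn_id=turn_id,
            tool_responses=[ToolResponseMessage(call_id=call_id, tool_name=tool_name, content=result_text)],
            stream=True,
        )
        async for event in stream:
            yield event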
@@ -12,7 +12,7 @@ from typing import List, Optional

 from pydantic import BaseModel

-from llama_stack.apis.agents import Turn
+from llama_stack.apis.agents import ToolExecutionStep, Turn
 from llama_stack.providers.utils.kvstore import KVStore

 log = logging.getLogger(__name__)
@@ -84,3 +84,15 @@ class AgentPersistence:
                 continue
         turns.sort(key=lambda x: (x.completed_at or datetime.min))
         return turns
+
+    async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
+        await self.kvstore.set(
+            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
+            value=step.model_dump_json(),
+        )
+
+    async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
+        value = await self.kvstore.get(
+            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
+        )
+        return ToolExecutionStep(**json.loads(value)) if value else None
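Note: the in-progress tool call step is persisted under a composite key so a later resume request can find it. Illustrative only; a plain dict stands in for the KVStore backend and the serialized payload is a made-up example:

    store: dict = {}

    def step_key(agent_id: str, session_id: str, turn_id: str) -> str:
        # same key layout as set_/get_in_progress_tool_call_step above
        return f"in_progress_tool_call_step:{agent_id}:{session_id}:{turn_id}"

    store[step_key("agent-1", "sess-1", "turn-1")] = '{"step_id": "s-1", "turn_id": "turn-1"}'
    assert store.get(step_key("agent-1", "sess-1", "turn-1")) is not None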
@@ -44,7 +44,6 @@ class SentenceTransformersInferenceImpl(
         pass

     async def register_model(self, model: Model) -> None:
-        _ = self._load_sentence_transformer_model(model.provider_resource_id)
         return model

     async def unregister_model(self, model_id: str) -> None:
@@ -119,10 +119,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):

         # sort by score
         chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
+        chunks = chunks[: query_config.max_chunks]
         tokens = 0
         picked = []
-        for c in chunks[: query_config.max_chunks]:
+        for c in chunks:
             metadata = c.metadata
             tokens += metadata["token_count"]
             if tokens > query_config.max_tokens_in_context:
@@ -146,6 +146,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                     text="\n=== END-RETRIEVED-CONTEXT ===\n",
                 ),
             ],
+            metadata={
+                "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
+            },
         )

     async def list_runtime_tools(
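Note: the change caps the candidate list by max_chunks before the token budget is applied, so the budget only ever sees the top-ranked chunks. A standalone restatement of that ordering; dict chunks stand in for the real chunk objects, and stopping the loop at the token limit is an assumption since the hunk does not show that branch:

    def pick_chunks(chunks, scores, max_chunks, max_tokens_in_context):
        # rank by score, keep the top max_chunks, then spend the token budget in order
        ranked = [c for c, _ in sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)]
        ranked = ranked[:max_chunks]
        picked, tokens = [], 0
        for c in ranked:
            tokens += c["token_count"]
            if tokens > max_tokens_in_context:
                break
            picked.append(c)
        return picked

    # top-2 by score are considered; only the first fits an 80-token budget
    print(pick_chunks([{"token_count": 60}, {"token_count": 50}, {"token_count": 40}], [0.2, 0.9, 0.5], 2, 80))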
@@ -61,7 +61,10 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.inference,
             provider_type="inline::sentence-transformers",
-            pip_packages=["sentence-transformers"],
+            pip_packages=[
+                "torch torchvision --index-url https://download.pytorch.org/whl/cpu",
+                "sentence-transformers --no-deps",
+            ],
             module="llama_stack.providers.inline.inference.sentence_transformers",
             config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
         ),
@@ -20,7 +20,18 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.tool_runtime,
             provider_type="inline::rag-runtime",
-            pip_packages=[],
+            pip_packages=[
+                "blobfile",
+                "chardet",
+                "pypdf",
+                "tqdm",
+                "numpy",
+                "scikit-learn",
+                "scipy",
+                "nltk",
+                "sentencepiece",
+                "transformers",
+            ],
             module="llama_stack.providers.inline.tool_runtime.rag",
             config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
             api_dependencies=[Api.vector_io, Api.inference],
@@ -14,33 +14,13 @@ from llama_stack.providers.datatypes import (
     remote_provider_spec,
 )

-EMBEDDING_DEPS = [
-    "blobfile",
-    "chardet",
-    "pypdf",
-    "tqdm",
-    "numpy",
-    "scikit-learn",
-    "scipy",
-    "nltk",
-    "sentencepiece",
-    "transformers",
-    # this happens to work because special dependencies are always installed last
-    # so if there was a regular torch installed first, this would be ignored
-    # we need a better way to do this to identify potential conflicts, etc.
-    # for now, this lets us significantly reduce the size of the container which
-    # does not have any "local" inference code (and hence does not need GPU-enabled torch)
-    "torch torchvision --index-url https://download.pytorch.org/whl/cpu",
-    "sentence-transformers --no-deps",
-]
-

 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::meta-reference",
-            pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
+            pip_packages=["faiss-cpu"],
             module="llama_stack.providers.inline.vector_io.faiss",
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
@@ -49,24 +29,33 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::faiss",
-            pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
+            pip_packages=["faiss-cpu"],
             module="llama_stack.providers.inline.vector_io.faiss",
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
             api=Api.vector_io,
-            provider_type="inline::sqlite_vec",
-            pip_packages=EMBEDDING_DEPS + ["sqlite-vec"],
+            provider_type="inline::sqlite-vec",
+            pip_packages=["sqlite-vec"],
             module="llama_stack.providers.inline.vector_io.sqlite_vec",
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             api_dependencies=[Api.inference],
         ),
+        InlineProviderSpec(
+            api=Api.vector_io,
+            provider_type="inline::sqlite_vec",
+            pip_packages=["sqlite-vec"],
+            module="llama_stack.providers.inline.vector_io.sqlite_vec",
+            config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
+            deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
+            api_dependencies=[Api.inference],
+        ),
         remote_provider_spec(
             Api.vector_io,
             AdapterSpec(
                 adapter_type="chromadb",
-                pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
+                pip_packages=["chromadb-client"],
                 module="llama_stack.providers.remote.vector_io.chroma",
                 config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
             ),
@@ -75,7 +64,7 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::chromadb",
-            pip_packages=EMBEDDING_DEPS + ["chromadb"],
+            pip_packages=["chromadb"],
             module="llama_stack.providers.inline.vector_io.chroma",
             config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
             api_dependencies=[Api.inference],
@@ -84,7 +73,7 @@ def available_providers() -> List[ProviderSpec]:
             Api.vector_io,
             AdapterSpec(
                 adapter_type="pgvector",
-                pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
+                pip_packages=["psycopg2-binary"],
                 module="llama_stack.providers.remote.vector_io.pgvector",
                 config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
             ),
@@ -94,7 +83,7 @@ def available_providers() -> List[ProviderSpec]:
             Api.vector_io,
             AdapterSpec(
                 adapter_type="weaviate",
-                pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
+                pip_packages=["weaviate-client"],
                 module="llama_stack.providers.remote.vector_io.weaviate",
                 config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
                 provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
@@ -115,7 +104,7 @@ def available_providers() -> List[ProviderSpec]:
             Api.vector_io,
             AdapterSpec(
                 adapter_type="qdrant",
-                pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
+                pip_packages=["qdrant-client"],
                 module="llama_stack.providers.remote.vector_io.qdrant",
                 config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
             ),
@@ -209,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
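Note: the same pattern is applied to the Ollama and Together adapters below. When the registered model has no known Llama model mapping (or the request carries media), the adapter now sends OpenAI-style chat messages instead of rendering a raw prompt. A minimal restatement of just that branch, illustrative and not adapter code:

    def choose_wire_format(media_present: bool, llama_model: str | None) -> str:
        # "messages": openai-compatible chat payload; "prompt": raw rendered prompt
        if media_present or not llama_model:
            return "messages"
        return "prompt"

    assert choose_wire_format(False, None) == "messages"
    assert choose_wire_format(False, "Llama-3.1-8B-Instruct") == "prompt"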
@@ -52,7 +52,7 @@ _MODEL_ENTRIES = [
         provider_model_id="baai/bge-m3",
         model_type=ModelType.embedding,
         metadata={
-            "embedding_dimensions": 1024,
+            "embedding_dimension": 1024,
             "context_length": 8192,
         },
     ),
@@ -178,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                 # flatten the list of lists
                 input_dict["messages"] = [item for sublist in contents for item in sublist]
@@ -187,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["raw"] = True
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
-                    self.register_helper.get_llama_model(request.model),
+                    llama_model,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
@@ -280,7 +281,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return EmbeddingsResponse(embeddings=embeddings)

     async def register_model(self, model: Model) -> Model:
+        model = await self.register_helper.register_model(model)
         if model.model_type == ModelType.embedding:
+            log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
+            await self.client.pull(model.provider_resource_id)
             response = await self.client.list()
         else:
             response = await self.client.ps()
@@ -290,7 +294,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
             )

-        return await self.register_helper.register_model(model)
+        return model


 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
@@ -203,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
@@ -61,7 +61,7 @@ def vector_io_sqlite_vec() -> ProviderFixture:
         providers=[
             Provider(
                 provider_id="sqlite_vec",
-                provider_type="inline::sqlite_vec",
+                provider_type="inline::sqlite-vec",
                 config=SQLiteVectorIOConfig(
                     kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
                 ).model_dump(),
@@ -14,6 +14,7 @@ from llama_stack.apis.inference import (
     ModelStore,
     TextTruncation,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str

 EMBEDDING_MODELS = {}

@@ -34,7 +35,9 @@ class SentenceTransformerEmbeddingMixin:
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
-        embeddings = embedding_model.encode(contents)
+        embeddings = embedding_model.encode(
+            [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
+        )
         return EmbeddingsResponse(embeddings=embeddings)

     def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
@@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             provider_resource_id = model.provider_resource_id
         else:
             provider_resource_id = self.get_provider_model_id(model.provider_resource_id)

         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
-            if model.metadata.get("llama_model") is None:
-                raise ValueError(
-                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
-                    "Please specify a llama_model in metadata or use a supported model identifier"
-                )
+            llama_model = model.metadata.get("llama_model")
+            if llama_model is None:
+                return model
+
             existing_llama_model = self.get_llama_model(model.provider_resource_id)
             if existing_llama_model:
-                if existing_llama_model != model.metadata["llama_model"]:
+                if existing_llama_model != llama_model:
                     raise ValueError(
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
-                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Invalid llama_model '{llama_model}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                 )

         return model
@@ -456,3 +456,20 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
     else:
         # specific tool
         return f"You MUST use the tool `{tool_choice}` to answer the user query."
+
+
+def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
+    llama_model = resolve_model(model)
+    if llama_model is None:
+        return ToolPromptFormat.json
+
+    if llama_model.model_family == ModelFamily.llama3_1 or (
+        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+    ):
+        # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
+        return ToolPromptFormat.json
+    elif llama_model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
+        # llama3.2 and llama3.3 models follow the same tool prompt format
+        return ToolPromptFormat.python_list
+    else:
+        return ToolPromptFormat.json
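Note: restating the new selection rule as plain data for quick reference; the family strings mirror the ModelFamily values used above and nothing here is imported from the tree:

    def default_tool_prompt_format(family: str, multimodal: bool = False) -> str:
        # llama3.1, and multimodal llama3.2, use json; text llama3.2 and llama3.3 use python_list
        if family == "llama3_1" or (family == "llama3_2" and multimodal):
            return "json"
        if family in ("llama3_2", "llama3_3"):
            return "python_list"
        return "json"

    assert default_tool_prompt_format("llama3_2") == "python_list"
    assert default_tool_prompt_format("llama3_2", multimodal=True) == "json"
    assert default_tool_prompt_format("llama3_1") == "json"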
@@ -5,12 +5,10 @@
 # the root directory of this source tree.

 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, TypeVar
+from typing import Any, Callable, List, Optional, Protocol, TypeVar

 from .strong_typing.schema import json_schema_type, register_schema  # noqa: F401

-T = TypeVar("T")
-

 @dataclass
 class WebMethod:
@@ -22,6 +20,13 @@ class WebMethod:
     raw_bytes_request_body: Optional[bool] = False


+class HasWebMethod(Protocol):
+    __webmethod__: WebMethod
+
+
+T = TypeVar("T", bound=HasWebMethod)  # Bound T to classes that match this protocol
+
+
 def webmethod(
     route: Optional[str] = None,
     method: Optional[str] = None,
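Note: binding T to the HasWebMethod protocol lets type checkers know that whatever the decorator returns carries a __webmethod__ attribute. The same pattern in isolation, as a self-contained sketch rather than this module's code:

    from dataclasses import dataclass
    from typing import Optional, Protocol, TypeVar

    @dataclass
    class Route:
        path: Optional[str] = None

    class HasRoute(Protocol):
        __route__: Route

    T = TypeVar("T", bound=HasRoute)

    def route(path: Optional[str] = None):
        def wrap(fn: T) -> T:
            fn.__route__ = Route(path=path)  # attribute promised by the protocol
            return fn
        return wrap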
@@ -11,7 +11,7 @@ import subprocess
 import sys
 from functools import partial
 from pathlib import Path
-from typing import Iterator
+from typing import Iterable

 from rich.progress import Progress, SpinnerColumn, TextColumn

@@ -39,7 +39,7 @@ class ChangedPathTracker:
         return self._changed_paths


-def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
+def find_template_dirs(templates_dir: Path) -> Iterable[Path]:
     """Find immediate subdirectories in the templates folder."""
     if not templates_dir.exists():
         raise FileNotFoundError(f"Templates directory not found: {templates_dir}")
@@ -90,7 +90,7 @@ def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
     return has_changes


-def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
+def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[str]]:
     try:
         module_name = f"llama_stack.templates.{template_dir.name}"
         module = importlib.import_module(module_name)
@@ -52,7 +52,7 @@ def main(parser: argparse.ArgumentParser):
             pytest_args,
             "-s",
             "-v",
-            REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH,
+            str(REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH),
         ]
     )
@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::cerebras
+    - inline::sentence-transformers
     safety:
     - inline::llama-guard
     vector_io:
@@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::cerebras"],
+        "inference": ["remote::cerebras", "inline::sentence-transformers"],
         "safety": ["inline::llama-guard"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "agents": ["inline::meta-reference"],
7  llama_stack/templates/ci-tests/__init__.py  (new file)
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .ci_tests import get_distribution_template  # noqa: F401
33  llama_stack/templates/ci-tests/build.yaml  (new file)
@@ -0,0 +1,33 @@
version: '2'
distribution_spec:
  description: Distribution for running e2e tests in CI
  providers:
    inference:
    - remote::fireworks
    - inline::sentence-transformers
    vector_io:
    - inline::sqlite-vec
    - remote::chromadb
    - remote::pgvector
    safety:
    - inline::llama-guard
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
    eval:
    - inline::meta-reference
    datasetio:
    - remote::huggingface
    - inline::localfs
    scoring:
    - inline::basic
    - inline::llm-as-judge
    - inline::braintrust
    tool_runtime:
    - remote::brave-search
    - remote::tavily-search
    - inline::code-interpreter
    - inline::rag-runtime
    - remote::model-context-protocol
image_type: conda
123  llama_stack/templates/ci-tests/ci_tests.py  (new file)
@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
    ModelInput,
    Provider,
    ShieldInput,
    ToolGroupInput,
)
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.inline.inference.sentence_transformers import (
    SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::fireworks", "inline::sentence-transformers"],
        "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
        "eval": ["inline::meta-reference"],
        "datasetio": ["remote::huggingface", "inline::localfs"],
        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
        "tool_runtime": [
            "remote::brave-search",
            "remote::tavily-search",
            "inline::code-interpreter",
            "inline::rag-runtime",
            "remote::model-context-protocol",
        ],
    }
    name = "ci-tests"
    inference_provider = Provider(
        provider_id="fireworks",
        provider_type="remote::fireworks",
        config=FireworksImplConfig.sample_run_config(),
    )
    vector_io_provider = Provider(
        provider_id="sqlite-vec",
        provider_type="inline::sqlite-vec",
        config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
    )
    embedding_provider = Provider(
        provider_id="sentence-transformers",
        provider_type="inline::sentence-transformers",
        config=SentenceTransformersInferenceConfig.sample_run_config(),
    )

    core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
    default_models = [
        ModelInput(
            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
            provider_model_id=m.provider_model_id,
            provider_id="fireworks",
            metadata=m.metadata,
            model_type=m.model_type,
        )
        for m in MODEL_ENTRIES
    ]
    default_tool_groups = [
        ToolGroupInput(
            toolgroup_id="builtin::websearch",
            provider_id="tavily-search",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::rag",
            provider_id="rag-runtime",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::code_interpreter",
            provider_id="code-interpreter",
        ),
    ]
    embedding_model = ModelInput(
        model_id="all-MiniLM-L6-v2",
        provider_id="sentence-transformers",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 384,
        },
    )

    return DistributionTemplate(
        name=name,
        distro_type="self_hosted",
        description="Distribution for running e2e tests in CI",
        container_image=None,
        template_path=None,
        providers=providers,
        default_models=default_models,
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [inference_provider, embedding_provider],
                    "vector_io": [vector_io_provider],
                },
                default_models=default_models + [embedding_model],
                default_tool_groups=default_tool_groups,
                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
            ),
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
            "FIREWORKS_API_KEY": (
                "",
                "Fireworks API Key",
            ),
        },
    )
169
llama_stack/templates/ci-tests/run.yaml
Normal file
169
llama_stack/templates/ci-tests/run.yaml
Normal file
|
|
@ -0,0 +1,169 @@
|
||||||
|
version: '2'
|
||||||
|
image_name: ci-tests
|
||||||
|
apis:
|
||||||
|
- agents
|
||||||
|
- datasetio
|
||||||
|
- eval
|
||||||
|
- inference
|
||||||
|
- safety
|
||||||
|
- scoring
|
||||||
|
- telemetry
|
||||||
|
- tool_runtime
|
||||||
|
- vector_io
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- provider_id: fireworks
|
||||||
|
provider_type: remote::fireworks
|
||||||
|
config:
|
||||||
|
url: https://api.fireworks.ai/inference/v1
|
||||||
|
api_key: ${env.FIREWORKS_API_KEY}
|
||||||
|
- provider_id: sentence-transformers
|
||||||
|
provider_type: inline::sentence-transformers
|
||||||
|
config: {}
|
||||||
|
vector_io:
|
||||||
|
- provider_id: sqlite-vec
|
||||||
|
provider_type: inline::sqlite-vec
|
||||||
|
config:
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/sqlite_vec.db
|
||||||
|
safety:
|
||||||
|
- provider_id: llama-guard
|
||||||
|
provider_type: inline::llama-guard
|
||||||
|
config: {}
|
||||||
|
agents:
|
||||||
|
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ci-tests/trace_store.db}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config: {}
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db
models:
- metadata: {}
  model_id: meta-llama/Llama-3.1-8B-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.1-70B-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-1B-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-3B-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.3-70B-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-Guard-3-8B
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-guard-3-8b
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-Guard-3-11B-Vision
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
  model_type: llm
- metadata:
    embedding_dimension: 768
    context_length: 8192
  model_id: nomic-ai/nomic-embed-text-v1.5
  provider_id: fireworks
  provider_model_id: nomic-ai/nomic-embed-text-v1.5
  model_type: embedding
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: sentence-transformers
  model_type: embedding
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
server:
  port: 8321

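Throughout the run.yaml above, values written as ${env.VAR:default} are filled in from the environment when the config is loaded. The snippet below is only a rough illustration of that colon-delimited substitution syntax, not the stack's actual resolver; plain ${env.VAR} references without a default are left alone here.

# Illustration only (not llama-stack code): expand "${env.VAR:default}" placeholders
# like the ones in the run.yaml above from the process environment.
import os
import re

_ENV_PLACEHOLDER = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}")

def expand_env_placeholders(value: str) -> str:
    # Replace each placeholder with the environment value, falling back to the default.
    return _ENV_PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

# With SQLITE_STORE_DIR unset this prints the default ci-tests registry path.
print(expand_env_placeholders("${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db"))
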
@@ -5,6 +5,7 @@ distribution_spec:
   providers:
     inference:
     - remote::tgi
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::tgi"],
+        "inference": ["remote::tgi", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::fireworks
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -18,14 +18,14 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
-from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
+from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
 from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::fireworks"],
+        "inference": ["remote::fireworks", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::hf::serverless
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -21,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::hf::serverless"],
+        "inference": ["remote::hf::serverless", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

@@ -136,7 +136,7 @@ models:
   provider_model_id: meta/llama-3.2-90b-vision-instruct
   model_type: llm
 - metadata:
-    embedding_dimensions: 1024
+    embedding_dimension: 1024
     context_length: 8192
   model_id: baai/bge-m3
   provider_id: nvidia

@@ -5,7 +5,7 @@ distribution_spec:
     inference:
     - remote::ollama
     vector_io:
-    - inline::faiss
+    - inline::sqlite-vec
    - remote::chromadb
    - remote::pgvector
     safety:

@@ -13,10 +13,6 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
-from llama_stack.providers.inline.inference.sentence_transformers import (
-    SentenceTransformersInferenceConfig,
-)
-from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

@@ -25,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
 def get_distribution_template() -> DistributionTemplate:
     providers = {
         "inference": ["remote::ollama"],
-        "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
+        "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],
         "telemetry": ["inline::meta-reference"],

@@ -45,19 +41,9 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::ollama",
         config=OllamaImplConfig.sample_run_config(),
     )
-    embedding_provider = Provider(
-        provider_id="sentence-transformers",
-        provider_type="inline::sentence-transformers",
-        config=SentenceTransformersInferenceConfig.sample_run_config(),
-    )
-    vector_io_provider_faiss = Provider(
-        provider_id="faiss",
-        provider_type="inline::faiss",
-        config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
-    )
     vector_io_provider_sqlite = Provider(
-        provider_id="sqlite_vec",
-        provider_type="inline::sqlite_vec",
+        provider_id="sqlite-vec",
+        provider_type="inline::sqlite-vec",
         config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
     )

@@ -104,19 +90,16 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider, embedding_provider],
-                    "vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite],
+                    "inference": [inference_provider],
+                    "vector_io": [vector_io_provider_sqlite],
                 },
-                default_models=[inference_model, embedding_model],
+                default_models=[inference_model],
                 default_tool_groups=default_tool_groups,
             ),
             "run-with-safety.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [
-                        inference_provider,
-                        embedding_provider,
-                    ],
-                    "vector_io": [vector_io_provider_faiss, vector_io_provider_faiss],
+                    "inference": [inference_provider],
+                    "vector_io": [vector_io_provider_sqlite],
                     "safety": [
                         Provider(
                             provider_id="llama-guard",

@@ -16,24 +16,11 @@ providers:
     provider_type: remote::ollama
     config:
       url: ${env.OLLAMA_URL:http://localhost:11434}
-  - provider_id: sentence-transformers
-    provider_type: inline::sentence-transformers
-    config: {}
   vector_io:
-  - provider_id: faiss
-    provider_type: inline::faiss
+  - provider_id: sqlite-vec
+    provider_type: inline::sqlite-vec
     config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
-  - provider_id: faiss
-    provider_type: inline::faiss
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
+      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard

@@ -16,19 +16,9 @@ providers:
     provider_type: remote::ollama
     config:
       url: ${env.OLLAMA_URL:http://localhost:11434}
-  - provider_id: sentence-transformers
-    provider_type: inline::sentence-transformers
-    config: {}
   vector_io:
-  - provider_id: faiss
-    provider_type: inline::faiss
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
-  - provider_id: sqlite_vec
-    provider_type: inline::sqlite_vec
+  - provider_id: sqlite-vec
+    provider_type: inline::sqlite-vec
     config:
       db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
   safety:

@@ -97,12 +87,6 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: ollama
   model_type: llm
-- metadata:
-    embedding_dimension: 384
-  model_id: all-MiniLM-L6-v2
-  provider_id: ollama
-  provider_model_id: all-minilm:latest
-  model_type: embedding
 shields: []
 vector_dbs: []
 datasets: []

@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::vllm
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::vllm"],
+        "inference": ["remote::vllm", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::tgi
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::tgi"],
+        "inference": ["remote::tgi", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::together
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -25,7 +25,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::together"],
+        "inference": ["remote::together", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - inline::vllm
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -20,7 +20,7 @@ from llama_stack.templates.template import (

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["inline::vllm"],
+        "inference": ["inline::vllm", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "llama_stack"
-version = "0.1.3"
+version = "0.1.4"
 authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
 description = "Llama Stack"
 readme = "README.md"

@@ -26,8 +26,8 @@ dependencies = [
     "httpx",
     "huggingface-hub",
     "jsonschema",
-    "llama-models>=0.1.3",
-    "llama-stack-client>=0.1.3",
+    "llama-models>=0.1.4",
+    "llama-stack-client>=0.1.4",
     "prompt-toolkit",
     "python-dotenv",
     "pydantic>=2",

@@ -158,3 +158,26 @@ ignore = [
     "B007",
     "B008",
 ]
+
+[tool.mypy]
+mypy_path = ["llama_stack"]
+packages = ["llama_stack"]
+disable_error_code = []
+warn_return_any = true
+# # honor excludes by not following there through imports
+follow_imports = "silent"
+exclude = [
+    # As we fix more and more of these, we should remove them from the list
+    "llama_stack/providers",
+    "llama_stack/distribution",
+    "llama_stack/apis",
+    "llama_stack/cli",
+    "llama_stack/models",
+    "llama_stack/strong_typing",
+    "llama_stack/templates",
+]
+
+[[tool.mypy.overrides]]
+# packages that lack typing annotations, do not have stubs, or are unavailable.
+module = ["llama_models.*", "yaml", "fire"]
+ignore_missing_imports = true

@@ -16,13 +16,13 @@ fsspec==2025.2.0
 h11==0.14.0
 httpcore==1.0.7
 httpx==0.28.1
-huggingface-hub==0.28.1
+huggingface-hub==0.29.0
 idna==3.10
 jinja2==3.1.5
 jsonschema==4.23.0
 jsonschema-specifications==2024.10.1
-llama-models==0.1.3
-llama-stack-client==0.1.3
+llama-models==0.1.4
+llama-stack-client==0.1.4
 lxml==5.3.1
 markdown-it-py==3.0.0
 markupsafe==3.0.2

@@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
 from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 from llama_stack_client.types.tool_def_param import Parameter

-from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig
-from llama_stack.apis.agents.agents import ToolChoice
+from llama_stack.apis.agents.agents import (
+    AgentConfig as Server__AgentConfig,
+)
+from llama_stack.apis.agents.agents import (
+    ToolChoice,
+)


 class TestClientTool(ClientTool):

@@ -86,7 +90,6 @@ class TestClientTool(ClientTool):
 def agent_config(llama_stack_client, text_model_id):
     available_shields = [shield.identifier for shield in llama_stack_client.shields.list()]
     available_shields = available_shields[:1]
-    print(f"Using shield: {available_shields}")
     agent_config = AgentConfig(
         model=text_model_id,
         instructions="You are a helpful assistant",

@@ -322,17 +325,16 @@ def test_custom_tool(llama_stack_client, agent_config):


 def test_tool_choice(llama_stack_client, agent_config):
-    data = [
-        ("required", '{"type": "function"'),
-        ("none", None),
-        ("get_boiling_point", '{"type": "function", "name": "get_boiling_point"'),
-    ]
-    client_tool = TestClientTool()
-    for tool_choice, expected_tool in data:
-        agent_config["tool_config"] = {"tool_choice": tool_choice}
-        agent_config["client_tools"] = [client_tool.get_tool_definition()]
+    def run_agent(tool_choice):
+        client_tool = TestClientTool()

-        agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+        test_agent_config = {
+            **agent_config,
+            "tool_config": {"tool_choice": tool_choice},
+            "client_tools": [client_tool.get_tool_definition()],
+        }
+
+        agent = Agent(llama_stack_client, test_agent_config, client_tools=(client_tool,))
         session_id = agent.create_session(f"test-session-{uuid4()}")

         response = agent.create_turn(

@@ -343,14 +345,19 @@ def test_tool_choice(llama_stack_client, agent_config):
                 },
             ],
             session_id=session_id,
+            stream=False,
         )

-        logs = [str(log) for log in EventLogger().log(response) if log is not None]
-        logs_str = "".join(logs)
-        if expected_tool:
-            assert expected_tool in logs_str
-        else:
-            assert '{"type": "function"' not in logs_str
+        return [step for step in response.steps if step.step_type == "tool_execution"]
+
+    tool_execution_steps = run_agent("required")
+    assert len(tool_execution_steps) > 0
+
+    tool_execution_steps = run_agent("none")
+    assert len(tool_execution_steps) == 0
+
+    tool_execution_steps = run_agent("get_boiling_point")
+    assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"


 # TODO: fix this flaky test

@@ -378,7 +385,6 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):

     logs = [str(log) for log in EventLogger().log(response) if log is not None]
     logs_str = "".join(logs)
-    print(logs_str)
     # can't tell a joke: "I don't have a function"
     assert "function" in logs_str

@@ -417,7 +423,6 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):

     logs = [str(log) for log in EventLogger().log(response) if log is not None]
     logs_str = "".join(logs)
-    print(logs_str)
     assert "bicycle" in logs_str

     response = agent.create_turn(

@@ -432,7 +437,6 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):

     logs = [str(log) for log in EventLogger().log(response) if log is not None]
     logs_str = "".join(logs)
-    print(logs_str)
     assert "-100" in logs_str
     assert "get_boiling_point" in logs_str

@@ -453,6 +457,7 @@ def test_rag_agent(llama_stack_client, agent_config):
         vector_db_id=vector_db_id,
         embedding_model="all-MiniLM-L6-v2",
         embedding_dimension=384,
+        provider_id="faiss",
     )
     llama_stack_client.tool_runtime.rag_tool.insert(
         documents=documents,

@@ -484,15 +489,17 @@ def test_rag_agent(llama_stack_client, agent_config):
         ),
     ]
     for prompt, expected_kw in user_prompts:
-        print(f"User> {prompt}")
         response = rag_agent.create_turn(
             messages=[{"role": "user", "content": prompt}],
             session_id=session_id,
+            stream=False,
         )
-        logs = [str(log) for log in EventLogger().log(response) if log is not None]
-        logs_str = "".join(logs)
-        assert "Tool:query_from_memory" in logs_str
-        assert expected_kw in logs_str.lower()
+        # rag is called
+        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+        assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory"
+        # document ids are present in metadata
+        assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"]
+        assert expected_kw in response.output_message.content.lower()


 def test_rag_and_code_agent(llama_stack_client, agent_config):

@@ -548,7 +555,6 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
     ]

     for prompt, docs, tool_name in user_prompts:
-        print(f"User> {prompt}")
         session_id = agent.create_session(f"test-session-{uuid4()}")
         response = agent.create_turn(
             messages=[{"role": "user", "content": prompt}],

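The test changes above replace string-matching on EventLogger output with assertions on the structured steps of a non-streaming turn. A minimal sketch of that style is shown below; it assumes a response object shaped like the llama-stack-client turn used in the diff, with a steps list whose items expose step_type and, for tool execution, tool_calls with a tool_name.

# Sketch only: helpers for the step-based assertion style the tests move to.
def tool_execution_steps(response):
    # Collect the tool-execution steps from a non-streaming agent turn.
    return [step for step in response.steps if step.step_type == "tool_execution"]

def assert_tool_called(response, tool_name):
    # Assert that the first tool-execution step invoked the expected tool.
    steps = tool_execution_steps(response)
    assert steps, "expected at least one tool_execution step"
    assert steps[0].tool_calls[0].tool_name == tool_name
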
@@ -42,28 +42,30 @@ def pytest_addoption(parser):
     )
     parser.addoption(
         "--inference-model",
-        action="store",
         default=TEXT_MODEL,
         help="Specify the inference model to use for testing",
     )
     parser.addoption(
         "--vision-inference-model",
-        action="store",
         default=VISION_MODEL,
         help="Specify the vision inference model to use for testing",
     )
     parser.addoption(
         "--safety-shield",
-        action="store",
         default="meta-llama/Llama-Guard-3-1B",
         help="Specify the safety shield model to use for testing",
     )
     parser.addoption(
         "--embedding-model",
-        action="store",
-        default=TEXT_MODEL,
+        default=None,
         help="Specify the embedding model to use for testing",
     )
+    parser.addoption(
+        "--embedding-dimension",
+        type=int,
+        default=384,
+        help="Output dimensionality of the embedding model to use for testing",
+    )


 @pytest.fixture(scope="session")

@@ -78,7 +80,7 @@ def provider_data():


 @pytest.fixture(scope="session")
-def llama_stack_client(provider_data):
+def llama_stack_client(provider_data, text_model_id):
     if os.environ.get("LLAMA_STACK_CONFIG"):
         client = LlamaStackAsLibraryClient(
             get_env_or_fail("LLAMA_STACK_CONFIG"),

@@ -95,25 +97,91 @@ def llama_stack_client(provider_data):
         )
     else:
         raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")

     return client


+@pytest.fixture(scope="session")
+def inference_provider_type(llama_stack_client):
+    providers = llama_stack_client.providers.list()
+    inference_providers = [p for p in providers if p.api == "inference"]
+    assert len(inference_providers) > 0, "No inference providers found"
+    return inference_providers[0].provider_type
+
+
+@pytest.fixture(scope="session")
+def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension):
+    client = llama_stack_client
+
+    providers = [p for p in client.providers.list() if p.api == "inference"]
+    assert len(providers) > 0, "No inference providers found"
+    inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
+    if text_model_id:
+        client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
+    if vision_model_id:
+        client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
+
+    if embedding_model_id and embedding_dimension:
+        # try to find a provider that supports embeddings, if sentence-transformers is not available
+        selected_provider = None
+        for p in providers:
+            if p.provider_type == "inline::sentence-transformers":
+                selected_provider = p
+                break
+
+        selected_provider = selected_provider or providers[0]
+        client.models.register(
+            model_id=embedding_model_id,
+            provider_id=selected_provider.provider_id,
+            model_type="embedding",
+            metadata={"embedding_dimension": embedding_dimension},
+        )
+    return client
+
+
+MODEL_SHORT_IDS = {
+    "meta-llama/Llama-3.1-8B-Instruct": "8B",
+    "meta-llama/Llama-3.2-11B-Vision-Instruct": "11B",
+    "all-MiniLM-L6-v2": "MiniLM",
+}
+
+
+def get_short_id(value):
+    return MODEL_SHORT_IDS.get(value, value)
+
+
 def pytest_generate_tests(metafunc):
+    params = []
+    values = []
+    id_parts = []
+
     if "text_model_id" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "text_model_id",
-            [metafunc.config.getoption("--inference-model")],
-            scope="session",
-        )
+        params.append("text_model_id")
+        val = metafunc.config.getoption("--inference-model")
+        values.append(val)
+        id_parts.append(f"txt={get_short_id(val)}")
+
     if "vision_model_id" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "vision_model_id",
-            [metafunc.config.getoption("--vision-inference-model")],
-            scope="session",
-        )
+        params.append("vision_model_id")
+        val = metafunc.config.getoption("--vision-inference-model")
+        values.append(val)
+        id_parts.append(f"vis={get_short_id(val)}")
+
     if "embedding_model_id" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "embedding_model_id",
-            [metafunc.config.getoption("--embedding-model")],
-            scope="session",
-        )
+        params.append("embedding_model_id")
+        val = metafunc.config.getoption("--embedding-model")
+        values.append(val)
+        if val is not None:
+            id_parts.append(f"emb={get_short_id(val)}")
+
+    if "embedding_dimension" in metafunc.fixturenames:
+        params.append("embedding_dimension")
+        val = metafunc.config.getoption("--embedding-dimension")
+        values.append(val)
+        if val != 384:
+            id_parts.append(f"dim={val}")
+
+    if params:
+        # Create a single test ID string
+        test_id = ":".join(id_parts)
+        metafunc.parametrize(params, [values], scope="session", ids=[test_id])

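With the conftest.py changes above, the model options are gathered into a single parametrize call, so a run can pin the embedding model and dimension from the command line. The snippet below is only a sketch of invoking the suite programmatically with those options; the test path and model values are placeholders, not part of the commit.

# Example invocation only; adjust the test directory and model IDs for your setup.
import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main(
            [
                "tests/client-sdk",
                "--inference-model", "meta-llama/Llama-3.1-8B-Instruct",
                "--embedding-model", "all-MiniLM-L6-v2",
                "--embedding-dimension", "384",
            ]
        )
    )
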
Some files were not shown because too many files have changed in this diff.