Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-13 05:17:26 +00:00)

Commit ca6a12e362: Merge remote-tracking branch 'upstream/main' into add_nvidia_safety_provider
114 changed files with 2100 additions and 685 deletions
.github/workflows/update-readthedocs.yml (31 changes, vendored)
@@ -11,17 +11,42 @@ on:
     branches:
       - main
     paths:
-      - 'docs/source/**'
-      - 'docs/resources/**'
+      - 'docs/**'
+      - '.github/workflows/update-readthedocs.yml'
+  pull_request:
+    branches:
+      - main
+    paths:
+      - 'docs/**'
       - '.github/workflows/update-readthedocs.yml'

 jobs:
   update-readthedocs:
     runs-on: ubuntu-latest
     env:
       TOKEN: ${{ secrets.READTHEDOCS_TOKEN }}
     steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install the latest version of uv
+        uses: astral-sh/setup-uv@v5
+
+      - name: Sync with uv
+        run: uv sync --extra docs
+
+      - name: Build HTML
+        run: |
+          cd docs
+          uv run make html
+
       - name: Trigger ReadTheDocs build
+        if: github.event_name != 'pull_request'
         run: |
           if [ -z "$TOKEN" ]; then
             echo "READTHEDOCS_TOKEN is not set"
@@ -30,6 +30,7 @@ repos:
     rev: v0.9.4
     hooks:
       - id: ruff
+        args: [ --fix ]
         exclude: ^llama_stack/strong_typing/.*$
       - id: ruff-format

@@ -45,23 +46,26 @@ repos:
     hooks:
       - id: uv-export
        args: [
          "--frozen",
          "--no-hashes",
          "--no-emit-project",
          "--output-file=requirements.txt"
        ]
        files: ^pyproject\.toml$
      - id: uv-sync

-# - repo: https://github.com/pre-commit/mirrors-mypy
-#   rev: v1.14.0
-#   hooks:
-#     - id: mypy
-#       additional_dependencies:
-#         - types-requests
-#         - types-setuptools
-#         - pydantic
-#       args: [--ignore-missing-imports]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.15.0
+    hooks:
+      - id: mypy
+        additional_dependencies:
+          - uv==0.6.2
+          - mypy
+          - pytest
+          - rich
+          - types-requests
+          - pydantic
+        pass_filenames: false

 # - repo: https://github.com/jsh9/pydoclint
 #   rev: d88180a8632bb1602a4d81344085cf320f288c5a
@@ -134,9 +134,11 @@ If you are making changes to the documentation at [https://llama-stack.readthedo
 $ cd llama-stack/docs
 $ uv sync --extra docs

+# This rebuilds the documentation pages.
+$ uv run make html
+
 # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
-$ make html
-$ uv run sphinx-autobuild source build/html
+$ uv run sphinx-autobuild source build/html --write-all
 ```

 ### Update API Documentation
@@ -145,7 +147,7 @@ If you modify or add new API endpoints, update the API documentation accordingly

 ```bash
 $ uv sync --extra dev
-$ ./docs/openapi_generator/run_openapi_generator.sh
+$ uv run ./docs/openapi_generator/run_openapi_generator.sh
 ```

 The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
@@ -3,3 +3,4 @@ include distributions/dependencies.json
 include llama_stack/distribution/*.sh
 include llama_stack/cli/scripts/*.sh
 include llama_stack/templates/*/*.yaml
+include llama_stack/providers/tests/test_cases/*.json
README.md (12 changes)
@@ -78,18 +78,14 @@ You have two ways to install this repository:
 ```

 * **Install from source**:
-   If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable).
+   If you prefer to install from the source code, we recommend using [uv](https://github.com/astral-sh/uv).
    Then, run the following commands:
   ```bash
-   mkdir -p ~/local
-   cd ~/local
    git clone git@github.com:meta-llama/llama-stack.git

-   conda create -n stack python=3.10
-   conda activate stack
-
    cd llama-stack
-   pip install -e .
+   uv sync
+   uv pip install -e .
   ```

 ### Documentation
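A quick way to confirm the editable install worked is a tiny import check run inside the synced environment. This is only a sketch; it assumes the `uv sync` / `uv pip install -e .` steps above completed and that you invoke it with `uv run python`.

```python
# Minimal sanity check for the editable install described in the README diff above.
# Run with: uv run python check_install.py
import llama_stack  # installed in editable mode from the cloned repo

# Prints the path inside your clone if the editable install is being picked up.
print("llama-stack imported from:", llama_stack.__file__)
```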
@@ -30,9 +30,7 @@
       "sentencepiece",
       "tqdm",
       "transformers",
-      "uvicorn",
-      "sentence-transformers --no-deps",
-      "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+      "uvicorn"
     ],
     "cerebras": [
       "aiosqlite",
@@ -170,9 +168,7 @@
       "sentencepiece",
       "tqdm",
       "transformers",
-      "uvicorn",
-      "sentence-transformers --no-deps",
-      "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+      "uvicorn"
     ],
     "hf-serverless": [
       "aiohttp",
@@ -247,9 +243,7 @@
       "tqdm",
       "transformers",
       "uvicorn",
-      "zmq",
-      "sentence-transformers --no-deps",
-      "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+      "zmq"
     ],
     "meta-reference-quantized-gpu": [
       "accelerate",
@@ -290,9 +284,7 @@
       "tqdm",
       "transformers",
       "uvicorn",
-      "zmq",
-      "sentence-transformers --no-deps",
-      "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+      "zmq"
     ],
     "nvidia": [
       "aiosqlite",
@@ -323,9 +315,7 @@
       "sentencepiece",
       "tqdm",
       "transformers",
-      "uvicorn",
-      "sentence-transformers --no-deps",
-      "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+      "uvicorn"
     ],
     "ollama": [
       "aiohttp",
@@ -335,7 +325,6 @@
       "chardet",
       "chromadb-client",
       "datasets",
-      "faiss-cpu",
       "fastapi",
       "fire",
       "httpx",
@@ -356,11 +345,10 @@
       "scikit-learn",
       "scipy",
       "sentencepiece",
+      "sqlite-vec",
       "tqdm",
       "transformers",
-      "uvicorn",
-      "sentence-transformers --no-deps",
-      "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+      "uvicorn"
     ],
     "remote-vllm": [
       "aiosqlite",
@@ -423,9 +411,7 @@
       "sentencepiece",
       "tqdm",
       "transformers",
-      "uvicorn",
-      "sentence-transformers --no-deps",
-      "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+      "uvicorn"
     ],
     "tgi": [
       "aiohttp",
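The hunks above edit a plain JSON map from distribution template names to pip package lists (the `distributions/dependencies.json` file referenced in the MANIFEST.in change earlier, assuming this is that file). A small sketch of how one might inspect it after this change; the script name and path are illustrative only.

```python
# inspect_deps.py: look up which pip packages a distribution template pulls in.
import json

with open("distributions/dependencies.json") as f:
    deps = json.load(f)

print(sorted(deps))       # template names, e.g. "cerebras", "ollama", "tgi", ...
print(deps["ollama"])     # after this diff the ollama list includes "sqlite-vec" and no longer "faiss-cpu"
```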
docs/_static/llama-stack-spec.html (232 changes, vendored)

Generated OpenAPI spec (JSON rendering); it carries the same schema changes as docs/_static/llama-stack-spec.yaml below:
- A new POST path /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume ("Resume an agent turn with executed tool call responses. When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready."), with agent_id, session_id and turn_id path parameters, a required ResumeAgentTurnRequest body, and a 200 response that is a Turn (application/json) or an AgentTurnResponseStreamChunk stream (text/event-stream).
- A new allow_turn_resume boolean property on the create-turn request schema.
- A new AgentTurnResponseTurnAwaitingInputPayload schema (event_type const "turn_awaiting_input" plus a Turn), added to the AgentTurnResponseEventPayload oneOf and to its discriminator mapping as turn_awaiting_input.
- A new ResumeAgentTurnRequest schema with required tool_responses (array of ToolResponseMessage) and optional stream.
- A new metadata object property (additionalProperties of null/boolean/number/string/array/object) on two tool-result schemas and on RAGQueryResult, where metadata becomes required.
- The embeddings request's contents property becomes a oneOf (array of strings, or array of InterleavedContentItem), and the request gains optional text_truncation (none/start/end), output_dimension, and task_type (query/document) properties.
docs/_static/llama-stack-spec.yaml (152 changes, vendored)
@@ -1401,6 +1401,53 @@ paths:
             schema:
               $ref: '#/components/schemas/QueryTracesRequest'
       required: true
+  /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
+    post:
+      responses:
+        '200':
+          description: >-
+            A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk
+            objects.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/Turn'
+            text/event-stream:
+              schema:
+                $ref: '#/components/schemas/AgentTurnResponseStreamChunk'
+      tags:
+        - Agents
+      description: >-
+        Resume an agent turn with executed tool call responses.
+
+        When a Turn has the status `awaiting_input` due to pending input from client
+        side tool calls, this endpoint can be used to submit the outputs from the
+        tool calls once they are ready.
+      parameters:
+        - name: agent_id
+          in: path
+          description: The ID of the agent to resume.
+          required: true
+          schema:
+            type: string
+        - name: session_id
+          in: path
+          description: The ID of the session to resume.
+          required: true
+          schema:
+            type: string
+        - name: turn_id
+          in: path
+          description: The ID of the turn to resume.
+          required: true
+          schema:
+            type: string
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/ResumeAgentTurnRequest'
+        required: true
   /v1/eval/benchmarks/{benchmark_id}/jobs:
     post:
       responses:
@@ -2740,6 +2787,8 @@ components:
             $ref: '#/components/schemas/AgentTool'
         tool_config:
           $ref: '#/components/schemas/ToolConfig'
+        allow_turn_resume:
+          type: boolean
       additionalProperties: false
       required:
         - messages
@@ -2896,6 +2945,16 @@ components:
           - type: string
         content:
           $ref: '#/components/schemas/InterleavedContent'
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
       additionalProperties: false
       required:
         - call_id
@@ -2992,6 +3051,7 @@ components:
         - $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload'
         - $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload'
         - $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
+        - $ref: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
       discriminator:
         propertyName: event_type
         mapping:
@@ -3000,6 +3060,7 @@ components:
           step_complete: '#/components/schemas/AgentTurnResponseStepCompletePayload'
           turn_start: '#/components/schemas/AgentTurnResponseTurnStartPayload'
           turn_complete: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
+          turn_awaiting_input: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
     AgentTurnResponseStepCompletePayload:
       type: object
       properties:
@@ -3106,6 +3167,21 @@ components:
         - event
       title: AgentTurnResponseStreamChunk
       description: streamed agent turn completion response.
+    "AgentTurnResponseTurnAwaitingInputPayload":
+      type: object
+      properties:
+        event_type:
+          type: string
+          const: turn_awaiting_input
+          default: turn_awaiting_input
+        turn:
+          $ref: '#/components/schemas/Turn'
+      additionalProperties: false
+      required:
+        - event_type
+        - turn
+      title: >-
+        AgentTurnResponseTurnAwaitingInputPayload
     AgentTurnResponseTurnCompletePayload:
       type: object
       properties:
@@ -3224,13 +3300,39 @@ components:
             The identifier of the model to use. The model must be an embedding model
             registered with Llama Stack and available via the /models endpoint.
         contents:
-          type: array
-          items:
-            $ref: '#/components/schemas/InterleavedContent'
+          oneOf:
+            - type: array
+              items:
+                type: string
+            - type: array
+              items:
+                $ref: '#/components/schemas/InterleavedContentItem'
           description: >-
-            List of contents to generate embeddings for. Note that content can be
-            multimodal. The behavior depends on the model and provider. Some models
-            may only support text.
+            List of contents to generate embeddings for. Each content can be a string
+            or an InterleavedContentItem (and hence can be multimodal). The behavior
+            depends on the model and provider. Some models may only support text.
+        text_truncation:
+          type: string
+          enum:
+            - none
+            - start
+            - end
+          description: >-
+            (Optional) Config for how to truncate text for embedding when text is
+            longer than the model's max sequence length.
+        output_dimension:
+          type: integer
+          description: >-
+            (Optional) Output dimensionality for the embeddings. Only supported by
+            Matryoshka models.
+        task_type:
+          type: string
+          enum:
+            - query
+            - document
+          description: >-
+            (Optional) How is the embedding being used? This is only supported by
+            asymmetric embedding models.
       additionalProperties: false
       required:
         - model_id
@@ -4289,6 +4391,16 @@ components:
           type: string
         error_code:
           type: integer
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
       additionalProperties: false
       required:
         - content
@@ -4862,7 +4974,19 @@ components:
       properties:
         content:
           $ref: '#/components/schemas/InterleavedContent'
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
       additionalProperties: false
+      required:
+        - metadata
       title: RAGQueryResult
     QueryChunksRequest:
       type: object
@@ -5179,6 +5303,22 @@ components:
         - vector_db_id
         - embedding_model
       title: RegisterVectorDbRequest
+    ResumeAgentTurnRequest:
+      type: object
+      properties:
+        tool_responses:
+          type: array
+          items:
+            $ref: '#/components/schemas/ToolResponseMessage'
+          description: >-
+            The tool call responses to resume the turn with.
+        stream:
+          type: boolean
+          description: Whether to stream the response.
+      additionalProperties: false
+      required:
+        - tool_responses
+      title: ResumeAgentTurnRequest
     RunEvalRequest:
       type: object
       properties:
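The embeddings request change above accepts a plain list of strings and adds optional truncation, output-dimension, and task-type knobs. A hedged sketch of a request body that matches the new schema; the endpoint path, server URL, and model id are assumptions for illustration, not taken from this diff.

```python
# Sketch of an embeddings request against the updated schema (contents as list[str],
# plus the new optional fields). Adjust URL, route, and model id for your deployment.
import httpx

payload = {
    "model_id": "baai/bge-m3",             # assumed embedding model id (see the NVIDIA distro doc below)
    "contents": ["What is Llama Stack?"],   # list-of-strings branch of the new oneOf
    "text_truncation": "end",               # none | start | end
    "task_type": "query",                   # query | document (asymmetric models only)
}

resp = httpx.post("http://localhost:8321/v1/inference/embeddings", json=payload)  # assumed local server/route
print(resp.json())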
@@ -86,8 +86,6 @@
     "# NBVAL_SKIP\n",
     "\n",
     "!apt-get install -y bubblewrap\n",
-    "import os\n",
-    "os.environ[\"UV_SYSTEM_PYTHON\"] = \"1\"\n",
     "!pip install uv\n",
     "!uv pip install llama-stack"
    ]
@@ -3632,7 +3630,7 @@
    "provenance": []
   },
   "kernelspec": {
-   "display_name": "master",
+   "display_name": "toolchain",
    "language": "python",
    "name": "python3"
   },
@@ -25,7 +25,7 @@ We are working on adding a few more APIs to complete the application lifecycle.
 ## API Providers

 The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
-- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, etc.),
+- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
 - Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
 - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)

@@ -33,7 +33,7 @@ Providers come in two flavors:
 - **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
 - **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.

-Most importantly, Llama Stack always strives to provide at least one fully "local" provider for each API so you can iterate on a fully featured environment locally.
+Most importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
 ## Resources

 Some of these APIs are associated with a set of **Resources**. Here is the mapping of APIs to resources:
@@ -15,7 +15,7 @@
 from docutils import nodes

 project = "llama-stack"
-copyright = "2024, Meta"
+copyright = "2025, Meta"
 author = "Meta"

 # -- General configuration ---------------------------------------------------
@@ -38,6 +38,7 @@ The following models are available by default:
 - `meta-llama/Llama-3.2-3B-Instruct (meta/llama-3.2-3b-instruct)`
 - `meta-llama/Llama-3.2-11B-Vision-Instruct (meta/llama-3.2-11b-vision-instruct)`
 - `meta-llama/Llama-3.2-90B-Vision-Instruct (meta/llama-3.2-90b-vision-instruct)`
+- `baai/bge-m3 (baai/bge-m3)`


 ### Prerequisite: API Keys
@@ -17,7 +17,7 @@ Which templates / distributions to choose depends on the hardware you have for r
 - {dockerhub}`distribution-nvidia` ([Guide](self_hosted_distro/nvidia))

 - **Are you running on a "regular" desktop or laptop ?** We suggest using the ollama template for quick prototyping and get started without having to worry about needing GPUs.
-  - {dockerhub}`distribution-ollama` ([link](self_hosted_distro/ollama))
+  - {dockerhub}`distribution-ollama` ([Guide](self_hosted_distro/ollama))

 - **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest:
   - {dockerhub}`distribution-together` ([Guide](self_hosted_distro/together))
@@ -28,7 +28,7 @@ Which templates / distributions to choose depends on the hardware you have for r
 - [Android](ondevice_distro/android_sdk)


-- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro).**
+- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro.md).**

 ### Distribution Details
@@ -8,7 +8,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::cerebras` |
+| inference | `remote::cerebras`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -19,7 +19,7 @@ The `llamastack/distribution-dell` distribution consists of the following provid
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::tgi` |
+| inference | `remote::tgi`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -18,7 +18,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::fireworks` |
+| inference | `remote::fireworks`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
 | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
-| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
+| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |


 You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.
@@ -17,7 +17,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::vllm` |
+| inference | `remote::vllm`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -19,7 +19,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::tgi` |
+| inference | `remote::tgi`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -18,7 +18,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::together` |
+| inference | `remote::together`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -67,6 +67,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
 | **Provider** | **Environments** |
 | :----: | :----: |
 | FAISS | Single Node |
+| SQLite-Vec | Single Node |
 | Chroma | Hosted and Single Node |
 | Postgres (PGVector) | Hosted and Single Node |
 | Weaviate | Hosted |
@@ -88,6 +89,7 @@ self
 introduction/index
 getting_started/index
 concepts/index
+providers/index
 distributions/index
 distributions/selection
 building_applications/index
docs/source/providers/index.md (new file, 59 lines)
@@ -0,0 +1,59 @@
+# Providers Overview
+
+The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
+- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
+- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
+- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
+
+Providers come in two flavors:
+- **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
+- **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.
+
+Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
+
+## Agents
+Run multi-step agentic workflows with LLMs, with tool usage, memory (RAG), etc.
+
+## DatasetIO
+Interfaces with datasets and data loaders.
+
+## Eval
+Generates outputs (via Inference or Agents) and performs scoring.
+
+## Inference
+Runs inference with an LLM.
+
+## Post Training
+Fine-tunes a model.
+
+## Safety
+Applies safety policies to the output at a system (not only model) level.
+
+## Scoring
+Evaluates the outputs of the system.
+
+## Telemetry
+Collects telemetry data from the system.
+
+## Tool Runtime
+Is associated with the ToolGroup resources.
+
+## Vector IO
+
+Vector IO refers to operations on vector databases, such as adding documents, searching, and deleting documents.
+Vector IO plays a crucial role in [Retrieval Augmented Generation (RAG)](../../building_applications/rag), where the vector
+IO and database are used to store and retrieve documents for retrieval.
+
+#### Vector IO Providers
+The following providers (i.e., databases) are available for Vector IO:
+
+```{toctree}
+:maxdepth: 1
+
+vector_io/faiss
+vector_io/sqlite-vec
+vector_io/chromadb
+vector_io/pgvector
+vector_io/qdrant
+vector_io/weaviate
+```
docs/source/providers/vector_io/chromadb.md (new file, 36 lines)
@@ -0,0 +1,36 @@
+---
+orphan: true
+---
+# Chroma
+
+[Chroma](https://www.trychroma.com/) is an inline and remote vector
+database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
+That means you're not limited to storing vectors in memory or in a separate service.
+
+## Features
+Chroma supports:
+- Store embeddings and their metadata
+- Vector search
+- Full-text search
+- Document storage
+- Metadata filtering
+- Multi-modal retrieval
+
+## Usage
+
+To use Chroma in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use Chroma.
+3. Start storing and querying vectors.
+
+## Installation
+
+You can install Chroma using pip:
+
+```bash
+pip install chromadb
+```
+
+## Documentation
+See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
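For orientation, here is a minimal sketch of the underlying Chroma client itself (not Llama Stack-specific), assuming `pip install chromadb` as above; the collection name and documents are made up for illustration.

```python
# Minimal Chroma usage sketch: in-process client, one collection, one query.
import chromadb

client = chromadb.Client()  # ephemeral, in-process client
collection = client.create_collection("llama_stack_docs")

collection.add(
    ids=["doc-1", "doc-2"],
    documents=[
        "Llama Stack exposes a vector_io API.",
        "Chroma can back the vector_io API as an inline or remote provider.",
    ],
)

results = collection.query(query_texts=["Which API can Chroma back?"], n_results=1)
print(results["documents"])
```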
docs/source/providers/vector_io/faiss.md (new file, 33 lines)
@@ -0,0 +1,33 @@
+---
+orphan: true
+---
+# Faiss
+
+[Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It
+allows you to store and query vectors directly in memory.
+That means you'll get fast and efficient vector retrieval.
+
+## Features
+
+- Lightweight and easy to use
+- Fully integrated with Llama Stack
+- GPU support
+
+## Usage
+
+To use Faiss in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use Faiss.
+3. Start storing and querying vectors.
+
+## Installation
+
+You can install Faiss using pip:
+
+```bash
+pip install faiss-cpu
+```
+## Documentation
+See [Faiss' documentation](https://faiss.ai/) or the [Faiss Wiki](https://github.com/facebookresearch/faiss/wiki) for
+more details about Faiss in general.
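A quick sketch of the underlying Faiss library (not Llama Stack-specific), assuming `pip install faiss-cpu` and numpy; dimensions and data are illustrative.

```python
# Exact L2 nearest-neighbour search with an in-memory Faiss index.
import faiss
import numpy as np

d = 8                                                 # embedding dimension
xb = np.random.random((100, d)).astype("float32")     # "database" vectors
xq = np.random.random((1, d)).astype("float32")       # query vector

index = faiss.IndexFlatL2(d)                          # flat (exact) L2 index, fully in memory
index.add(xb)
distances, ids = index.search(xq, 5)                  # 5 nearest neighbours
print(ids[0], distances[0])
```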
docs/source/providers/vector_io/pgvector.md (new file, 31 lines)
@@ -0,0 +1,31 @@
+---
+orphan: true
+---
+# Postgres PGVector
+
+[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
+allows you to store and query vectors directly within a Postgres database.
+That means you'll get fast and efficient vector retrieval.
+
+## Features
+
+- Easy to use
+- Fully integrated with Llama Stack
+
+## Usage
+
+To use PGVector in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use PGVector.
+3. Start storing and querying vectors.
+
+## Installation
+
+You can install PGVector using docker:
+
+```bash
+docker pull pgvector/pgvector:pg17
+```
+## Documentation
+See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
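A hedged sketch of the underlying pgvector extension used via psycopg (not Llama Stack-specific), assuming the `pgvector/pgvector:pg17` container above is running locally; the DSN, credentials, and table name are placeholders.

```python
# Store and query vectors in Postgres via the pgvector extension.
import psycopg

with psycopg.connect("postgresql://postgres:postgres@localhost:5432/postgres") as conn:
    conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
    conn.execute("CREATE TABLE IF NOT EXISTS items (id bigserial PRIMARY KEY, embedding vector(3))")
    conn.execute("INSERT INTO items (embedding) VALUES ('[1,2,3]'), ('[4,5,6]')")

    # Nearest neighbours by L2 distance using the <-> operator
    rows = conn.execute(
        "SELECT id, embedding <-> '[2,3,4]' AS distance FROM items ORDER BY distance LIMIT 2"
    ).fetchall()
    print(rows)
```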
docs/source/providers/vector_io/qdrant.md (new file, 31 lines)
@@ -0,0 +1,31 @@
+---
+orphan: true
+---
+# Qdrant
+
+[Qdrant](https://qdrant.tech/documentation/) is a remote vector database provider for Llama Stack. It
+allows you to store and query vectors directly within a Qdrant instance.
+That means you'll get fast and efficient vector retrieval.
+
+## Features
+
+- Easy to use
+- Fully integrated with Llama Stack
+
+## Usage
+
+To use Qdrant in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use Qdrant.
+3. Start storing and querying vectors.
+
+## Installation
+
+You can install Qdrant using docker:
+
+```bash
+docker pull qdrant/qdrant
+```
+## Documentation
+See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
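A hedged sketch of the underlying qdrant-client library (not Llama Stack-specific), assuming `pip install qdrant-client` and the `qdrant/qdrant` container above listening on localhost:6333; collection name, vectors, and payloads are illustrative.

```python
# Create a collection, upsert a couple of points, and run a nearest-neighbour search.
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")

client.create_collection(
    collection_name="demo",
    vectors_config=models.VectorParams(size=4, distance=models.Distance.COSINE),
)
client.upsert(
    collection_name="demo",
    points=[
        models.PointStruct(id=1, vector=[0.1, 0.2, 0.3, 0.4], payload={"doc": "first"}),
        models.PointStruct(id=2, vector=[0.4, 0.3, 0.2, 0.1], payload={"doc": "second"}),
    ],
)

hits = client.search(collection_name="demo", query_vector=[0.1, 0.2, 0.3, 0.4], limit=1)
print(hits)
```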
docs/source/providers/vector_io/sqlite-vec.md (new file, 33 lines)
@@ -0,0 +1,33 @@
+---
+orphan: true
+---
+# SQLite-Vec
+
+[SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It
+allows you to store and query vectors directly within an SQLite database.
+That means you're not limited to storing vectors in memory or in a separate service.
+
+## Features
+
+- Lightweight and easy to use
+- Fully integrated with Llama Stack
+
+## Usage
+
+To use SQLite-Vec in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use SQLite-Vec.
+3. Start storing and querying vectors.
+
+## Installation
+
+You can install SQLite-Vec using pip:
+
+```bash
+pip install sqlite-vec
+```
+
+## Documentation
+
+See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.
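A hedged sketch of the underlying sqlite-vec extension used directly from Python (not Llama Stack-specific), assuming `pip install sqlite-vec` as above; the table layout and query syntax follow the sqlite-vec README and should be treated as illustrative.

```python
# Load the sqlite-vec extension, store a few float32 vectors, and run a KNN query.
import sqlite3
import struct

import sqlite_vec

db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
sqlite_vec.load(db)                 # loads the vec0 extension into this connection
db.enable_load_extension(False)

db.execute("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4])")

def to_blob(vec):
    # sqlite-vec accepts little-endian float32 blobs as vector values
    return struct.pack("%sf" % len(vec), *vec)

db.execute("INSERT INTO vec_items(rowid, embedding) VALUES (?, ?)", (1, to_blob([0.1, 0.2, 0.3, 0.4])))
db.execute("INSERT INTO vec_items(rowid, embedding) VALUES (?, ?)", (2, to_blob([0.9, 0.8, 0.7, 0.6])))

# K-nearest-neighbour query against the stored vectors
rows = db.execute(
    "SELECT rowid, distance FROM vec_items WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
    (to_blob([0.1, 0.2, 0.3, 0.4]),),
).fetchall()
print(rows)
```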
docs/source/providers/vector_io/weaviate.md (new file, 33 lines)
@@ -0,0 +1,33 @@
+---
+orphan: true
+---
+# Weaviate
+
+[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
+It allows you to store and query vectors directly within a Weaviate database.
+That means you're not limited to storing vectors in memory or in a separate service.
+
+## Features
+Weaviate supports:
+- Store embeddings and their metadata
+- Vector search
+- Full-text search
+- Hybrid search
+- Document storage
+- Metadata filtering
+- Multi-modal retrieval
+
+## Usage
+
+To use Weaviate in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use Weaviate.
+3. Start storing and querying vectors.
+
+## Installation
+
+To install Weaviate see the [Weaviate quickstart documentation](https://weaviate.io/developers/weaviate/quickstart).
+
+## Documentation
+See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
@@ -171,7 +171,7 @@ The `llama model` command helps you explore the model’s interface.
 llama model --help
 ```
 ```
-usage: llama model [-h] {download,list,prompt-format,describe} ...
+usage: llama model [-h] {download,list,prompt-format,describe,verify-download,remove} ...

 Work with llama models

@@ -179,15 +179,15 @@ options:
   -h, --help            show this help message and exit

 model_subcommands:
-  {download,list,prompt-format,describe}
+  {download,list,prompt-format,describe,verify-download,remove}
 ```

+### Describe
+
 You can use the describe command to know more about a model:
 ```
 llama model describe -m Llama3.2-3B-Instruct
 ```
-### Describe

 ```
 +-----------------------------+----------------------------------+
 | Model                       | Llama3.2-3B-Instruct             |
@@ -234,3 +234,10 @@ llama model prompt-format -m Llama3.2-3B-Instruct
 You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.

 **NOTE**: Outputs in terminal are color printed to show special tokens.
+
+### Remove model
+You can run `llama model remove` to remove an unnecessary model:
+
+```
+llama model remove -m Llama-Guard-3-8B-int8
+```
@@ -194,6 +194,7 @@ class AgentTurnResponseEventType(Enum):
     turn_start = "turn_start"
     turn_complete = "turn_complete"
+    turn_awaiting_input = "turn_awaiting_input"


 @json_schema_type

@@ -235,6 +236,14 @@ class AgentTurnResponseTurnCompletePayload(BaseModel):
     turn: Turn


+@json_schema_type
+class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
+        AgentTurnResponseEventType.turn_awaiting_input.value
+    )
+    turn: Turn
+
+
 AgentTurnResponseEventPayload = register_schema(
     Annotated[
         Union[

@@ -243,6 +252,7 @@ AgentTurnResponseEventPayload = register_schema(
             AgentTurnResponseStepCompletePayload,
             AgentTurnResponseTurnStartPayload,
             AgentTurnResponseTurnCompletePayload,
+            AgentTurnResponseTurnAwaitingInputPayload,
         ],
         Field(discriminator="event_type"),
     ],

@@ -286,6 +296,18 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     stream: Optional[bool] = False
     tool_config: Optional[ToolConfig] = None

+    # TODO (xiyan): temporary flag, will remove for 0.1.5
+    allow_turn_resume: Optional[bool] = False
+
+
+@json_schema_type
+class AgentTurnResumeRequest(BaseModel):
+    agent_id: str
+    session_id: str
+    turn_id: str
+    tool_responses: List[ToolResponseMessage]
+    stream: Optional[bool] = False
+

 @json_schema_type
 class AgentTurnResponseStreamChunk(BaseModel):

@@ -333,8 +355,34 @@ class Agents(Protocol):
         documents: Optional[List[Document]] = None,
         toolgroups: Optional[List[AgentToolGroup]] = None,
         tool_config: Optional[ToolConfig] = None,
+        allow_turn_resume: Optional[bool] = False,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

+    @webmethod(
+        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
+        method="POST",
+    )
+    async def resume_agent_turn(
+        self,
+        agent_id: str,
+        session_id: str,
+        turn_id: str,
+        tool_responses: List[ToolResponseMessage],
+        stream: Optional[bool] = False,
+    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        """Resume an agent turn with executed tool call responses.
+
+        When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
+
+        :param agent_id: The ID of the agent to resume.
+        :param session_id: The ID of the session to resume.
+        :param turn_id: The ID of the turn to resume.
+        :param tool_responses: The tool call responses to resume the turn with.
+        :param stream: Whether to stream the response.
+        :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
+        """
+        ...
+
     @webmethod(
         route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
         method="GET",
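A minimal sketch of how a caller might drive the new resume flow against the `Agents` protocol above, assuming the agent and session already exist. `agents_impl` and `tool_executor` are placeholders (not part of this change), and the `ToolResponseMessage` construction should be checked against the installed version.

```python
from typing import Callable, List

from llama_stack.apis.inference import ToolResponseMessage


async def run_turn_with_client_tools(
    agents_impl,                 # an object implementing the Agents protocol above
    agent_id: str,
    session_id: str,
    messages: List,              # user messages for the turn
    tool_executor: Callable,     # runs one client-side tool call, returns its output
):
    # Ask for a resumable turn: with pending client-side tool calls the turn
    # ends in `awaiting_input` instead of completing.
    turn = await agents_impl.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=messages,
        stream=False,
        allow_turn_resume=True,
    )

    pending = turn.output_message.tool_calls
    if not pending:
        return turn

    tool_responses = [
        ToolResponseMessage(
            call_id=call.call_id,
            tool_name=call.tool_name,
            content=tool_executor(call),
        )
        for call in pending
    ]

    # Submit the tool outputs; the server resumes the turn and runs it to completion.
    return await agents_impl.resume_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        turn_id=turn.turn_id,
        tool_responses=tool_responses,
        stream=False,
    )
```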
@@ -91,15 +91,18 @@ ParamType = register_schema(
     name="ParamType",
 )

+"""
 # TODO: recursive definition of ParamType in these containers
 # will cause infinite recursion in OpenAPI generation script
 # since we are going with ChatCompletionInputType and CompletionInputType
 # we don't need to worry about ArrayType/ObjectType/UnionType for now
-# ArrayType.model_rebuild()
-# ObjectType.model_rebuild()
-# UnionType.model_rebuild()
+ArrayType.model_rebuild()
+ObjectType.model_rebuild()
+UnionType.model_rebuild()


-# class CustomType(BaseModel):
-#     type: Literal["custom"] = "custom"
-#     validator_class: str
+class CustomType(BaseModel):
+    pylint: disable=syntax-error
+    type: Literal["custom"] = "custom"
+    validator_class: str
+"""
@@ -20,7 +20,7 @@ from typing import (
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated

-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
 from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.models.llama.datatypes import (

@@ -165,6 +165,7 @@ class ToolResponse(BaseModel):
     call_id: str
     tool_name: Union[BuiltinTool, str]
     content: InterleavedContent
+    metadata: Optional[Dict[str, Any]] = None

     @field_validator("tool_name", mode="before")
     @classmethod

@@ -402,6 +403,30 @@ class ModelStore(Protocol):
     def get_model(self, identifier: str) -> Model: ...


+class TextTruncation(Enum):
+    """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
+
+    :cvar none: No truncation (default). If the text is longer than the model's max sequence length, you will get an error.
+    :cvar start: Truncate from the start
+    :cvar end: Truncate from the end
+    """
+
+    none = "none"
+    start = "start"
+    end = "end"
+
+
+class EmbeddingTaskType(Enum):
+    """How is the embedding being used? This is only supported by asymmetric embedding models.
+
+    :cvar query: Used for a query for semantic search.
+    :cvar document: Used at indexing time when ingesting documents.
+    """
+
+    query = "query"
+    document = "document"
+
+
 @runtime_checkable
 @trace_protocol
 class Inference(Protocol):

@@ -481,12 +506,18 @@ class Inference(Protocol):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         """Generate embeddings for content pieces using the specified model.

         :param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
-        :param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
+        :param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
+        :param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
+        :param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
+        :param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
         :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
         """
         ...
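A brief sketch of calling the updated `embeddings` signature defined above. The `inference` handle and the embedding model id are placeholders, not part of this change.

```python
from typing import List

from llama_stack.apis.inference import EmbeddingTaskType, TextTruncation


async def embed_documents(inference, texts: List[str]):
    response = await inference.embeddings(
        model_id="all-MiniLM-L6-v2",           # assumed embedding model id
        contents=texts,                         # plain strings are now accepted
        text_truncation=TextTruncation.end,     # truncate overflow at the end
        output_dimension=None,                  # only Matryoshka models support this
        task_type=EmbeddingTaskType.document,   # asymmetric embedding models only
    )
    return response.embeddings  # one list[float] per input content
```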
@@ -26,6 +26,7 @@ class RAGDocument(BaseModel):
 @json_schema_type
 class RAGQueryResult(BaseModel):
     content: Optional[InterleavedContent] = None
+    metadata: Dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
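Illustrative only: one way a RAG tool implementation might populate the new `metadata` field alongside the retrieved content. The keys placed in `metadata` (such as `document_ids`) are assumptions, not part of the API.

```python
from typing import List

from llama_stack.apis.tools import RAGQueryResult


def build_rag_result(text_chunks: List[str], doc_ids: List[str]) -> RAGQueryResult:
    # content stays what it was before; metadata carries provenance for the caller
    return RAGQueryResult(
        content="\n".join(text_chunks),
        metadata={"document_ids": doc_ids},  # hypothetical key; defaults to {} if omitted
    )
```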
@@ -72,6 +72,7 @@ class ToolInvocationResult(BaseModel):
     content: InterleavedContent
     error_message: Optional[str] = None
     error_code: Optional[int] = None
+    metadata: Optional[Dict[str, Any]] = None


 class ToolStore(Protocol):
@@ -19,6 +19,13 @@ def _get_model_size(model_dir):
     return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())


+def _convert_to_model_descriptor(model):
+    for m in all_registered_models():
+        if model == m.descriptor().replace(":", "-"):
+            return str(m.descriptor())
+    return str(model)
+
+
 def _run_model_list_downloaded_cmd() -> None:
     headers = ["Model", "Size", "Modified Time"]

@@ -30,7 +37,7 @@ def _run_model_list_downloaded_cmd() -> None:
         modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path)))
         rows.append(
             [
-                model,
+                _convert_to_model_descriptor(model),
                 model_size,
                 modified_time,
             ]

@@ -68,6 +75,13 @@ class ModelList(Subcommand):
             action="store_true",
             help="List the downloaded models",
         )
+        self.parser.add_argument(
+            "-s",
+            "--search",
+            type=str,
+            required=False,
+            help="Search for the input string as a substring in the model descriptor(ID)",
+        )

     def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
         from .safety_models import prompt_guard_model_sku

@@ -87,15 +101,19 @@ class ModelList(Subcommand):
                 continue

             descriptor = model.descriptor()
-            rows.append(
-                [
-                    descriptor,
-                    model.huggingface_repo,
-                    f"{model.max_seq_length // 1024}K",
-                ]
-            )
-        print_table(
-            rows,
-            headers,
-            separate_rows=True,
-        )
+            if not args.search or args.search.lower() in descriptor.lower():
+                rows.append(
+                    [
+                        descriptor,
+                        model.huggingface_repo,
+                        f"{model.max_seq_length // 1024}K",
+                    ]
+                )
+        if len(rows) == 0:
+            print(f"Did not find any model matching `{args.search}`.")
+        else:
+            print_table(
+                rows,
+                headers,
+                separate_rows=True,
+            )
@@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe
 from llama_stack.cli.model.download import ModelDownload
 from llama_stack.cli.model.list import ModelList
 from llama_stack.cli.model.prompt_format import ModelPromptFormat
+from llama_stack.cli.model.remove import ModelRemove
 from llama_stack.cli.model.verify_download import ModelVerifyDownload
 from llama_stack.cli.subcommand import Subcommand

@@ -35,3 +36,4 @@ class ModelParser(Subcommand):
         ModelPromptFormat.create(subparsers)
         ModelDescribe.create(subparsers)
         ModelVerifyDownload.create(subparsers)
+        ModelRemove.create(subparsers)
llama_stack/cli/model/remove.py (new file, 67 lines)
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import os
import shutil

from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.models.llama.sku_list import resolve_model


class ModelRemove(Subcommand):
    """Remove the downloaded llama model"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "remove",
            prog="llama model remove",
            description="Remove the downloaded llama model",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_model_remove_cmd)

    def _add_arguments(self):
        self.parser.add_argument(
            "-m",
            "--model",
            required=True,
            help="Specify the llama downloaded model name, see `llama model list --downloaded`",
        )
        self.parser.add_argument(
            "-f",
            "--force",
            action="store_true",
            help="Used to forcefully remove the llama model from the storage without further confirmation",
        )

    def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
        from .safety_models import prompt_guard_model_sku

        prompt_guard = prompt_guard_model_sku()
        if args.model == prompt_guard.model_id:
            model = prompt_guard
        else:
            model = resolve_model(args.model)

        model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-"))

        if model is None or not os.path.isdir(model_path):
            print(f"'{args.model}' is not a valid llama model or does not exist.")
            return

        if args.force:
            shutil.rmtree(model_path)
            print(f"{args.model} removed.")
        else:
            if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y":
                shutil.rmtree(model_path)
                print(f"{args.model} removed.")
            else:
                print("Removal aborted.")
@@ -9,6 +9,7 @@ import importlib.resources
 import json
 import os
 import shutil
+import sys
 import textwrap
 from functools import lru_cache
 from pathlib import Path

@@ -23,10 +24,10 @@ from termcolor import cprint
 from llama_stack.cli.table import print_table
 from llama_stack.distribution.build import (
     SERVER_DEPENDENCIES,
-    ImageType,
     build_image,
     get_provider_dependencies,
 )
+from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import (
     BuildConfig,
     DistributionSpec,

@@ -37,6 +38,8 @@ from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
+from llama_stack.distribution.utils.exec import formulate_run_args, in_notebook, run_with_pty
+from llama_stack.distribution.utils.image_types import ImageType
 from llama_stack.providers.datatypes import Api

 TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"

@@ -59,8 +62,16 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
     if args.list_templates:
         return _run_template_list_cmd()

-    current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
-    image_name = args.image_name or current_conda_env
+    if args.image_type == "venv":
+        current_venv = os.environ.get("VIRTUAL_ENV")
+        image_name = args.image_name or current_venv
+        if not image_name and in_notebook():
+            image_name = "__system__"
+    elif args.image_type == "conda":
+        current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
+        image_name = args.image_name or current_conda_env
+    else:
+        image_name = args.image_name

     if args.template:
         available_templates = available_templates_specs()

@@ -69,7 +80,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
                 color="red",
             )
-            return
+            sys.exit(1)
         build_config = available_templates[args.template]
         if args.image_type:
             build_config.image_type = args.image_type

@@ -78,7 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 f"Please specify a image-type (container | conda | venv) for {args.template}",
                 color="red",
             )
-            return
+            sys.exit(1)
     elif not args.config and not args.template:
         name = prompt(
             "> Enter a name for your Llama Stack (e.g. my-local-stack): ",

@@ -159,14 +170,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 f"Could not parse config file {args.config}: {e}",
                 color="red",
             )
-            return
+            sys.exit(1)

         if build_config.image_type == ImageType.container.value and not args.image_name:
             cprint(
                 "Please specify --image-name when building a container from a config file",
                 color="red",
             )
-            return
+            sys.exit(1)

     if args.print_deps_only:
         print(f"# Dependencies for {args.template or args.config or image_name}")

@@ -177,19 +188,41 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             print(f"uv pip install {special_dep}")
         return

-    _run_stack_build_command_from_build_config(
-        build_config,
-        image_name=image_name,
-        config_path=args.config,
-        template_name=args.template,
-    )
+    try:
+        run_config = _run_stack_build_command_from_build_config(
+            build_config,
+            image_name=image_name,
+            config_path=args.config,
+            template_name=args.template,
+        )
+
+    except (Exception, RuntimeError) as exc:
+        cprint(
+            f"Error building stack: {exc}",
+            color="red",
+        )
+        sys.exit(1)
+    if run_config is None:
+        cprint(
+            "Run config path is empty",
+            color="red",
+        )
+        sys.exit(1)
+
+    if args.run:
+        run_config = Path(run_config)
+        config_dict = yaml.safe_load(run_config.read_text())
+        config = parse_and_maybe_upgrade_config(config_dict)
+        run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
+        run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
+        run_with_pty(run_args)


 def _generate_run_config(
     build_config: BuildConfig,
     build_dir: Path,
     image_name: str,
-) -> None:
+) -> str:
     """
     Generate a run.yaml template file for user to edit from a build.yaml file
     """

@@ -239,6 +272,7 @@ def _generate_run_config(
         f"You can now run your stack with `llama stack run {run_config_file}`",
         color="green",
     )
+    return run_config_file


 def _run_stack_build_command_from_build_config(

@@ -246,7 +280,7 @@ def _run_stack_build_command_from_build_config(
     image_name: Optional[str] = None,
     template_name: Optional[str] = None,
     config_path: Optional[str] = None,
-) -> None:
+) -> str:
     if build_config.image_type == ImageType.container.value:
         if template_name:
             image_name = f"distribution-{template_name}"

@@ -256,6 +290,9 @@ def _run_stack_build_command_from_build_config(
     elif build_config.image_type == ImageType.conda.value:
         if not image_name:
             raise ValueError("Please specify an image name when building a conda image")
+    elif build_config.image_type == ImageType.venv.value:
+        if not image_name:
+            raise ValueError("Please specify an image name when building a venv image")

     if template_name:
         build_dir = DISTRIBS_BASE_DIR / template_name

@@ -276,7 +313,7 @@ def _run_stack_build_command_from_build_config(
         template_or_config=template_name or config_path,
     )
     if return_code != 0:
-        return
+        raise RuntimeError(f"Failed to build image {image_name}")

     if template_name:
         # copy run.yaml from template to build_dir instead of generating it again

@@ -286,8 +323,9 @@ def _run_stack_build_command_from_build_config(
         shutil.copy(path, run_config_file)

         cprint("Build Successful!", color="green")
+        return template_path
     else:
-        _generate_run_config(build_config, build_dir, image_name)
+        return _generate_run_config(build_config, build_dir, image_name)


 def _run_template_list_cmd() -> None:
@@ -68,6 +68,13 @@ the build. If not specified, currently active Conda environment will be used if
             help="Print the dependencies for the stack only, without building the stack",
         )

+        self.parser.add_argument(
+            "--run",
+            action="store_true",
+            default=False,
+            help="Run the stack after building using the same image type, name, and other applicable arguments",
+        )
+
     def _run_stack_build_command(self, args: argparse.Namespace) -> None:
         # always keep implementation completely silo-ed away from CLI so CLI
         # can be fast to load and reduces dependencies
@@ -1,46 +0,0 @@ (entire file deleted)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_stack.cli.subcommand import Subcommand


class StackConfigure(Subcommand):
    """Llama cli for configuring llama toolchain configs"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "configure",
            prog="llama stack configure",
            description="Configure a llama stack distribution",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_stack_configure_cmd)

    def _add_arguments(self):
        self.parser.add_argument(
            "config",
            type=str,
            help="Path to the build config file (e.g. ~/.llama/builds/<image_type>/<name>-build.yaml). For container, this could also be the name of the container image. ",
        )

        self.parser.add_argument(
            "--output-dir",
            type=str,
            help="Path to the output directory to store generated run.yaml config file. If not specified, will use ~/.llama/build/<image_type>/<name>-run.yaml",
        )

    def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
        self.parser.error(
            """
            DEPRECATED! llama stack configure has been deprecated.
            Please use llama stack run <path/to/run.yaml> instead.
            Please see example run.yaml in /distributions folder.
            """
        )
@@ -74,10 +74,6 @@ class StackRun(Subcommand):
         )

     def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
-        import importlib.resources
-        import json
-        import subprocess
-
         import yaml
         from termcolor import cprint

@@ -87,7 +83,7 @@ class StackRun(Subcommand):
             BUILDS_BASE_DIR,
             DISTRIBS_BASE_DIR,
         )
-        from llama_stack.distribution.utils.exec import run_with_pty
+        from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty

         if not args.config:
             self.parser.error("Must specify a config file to run")

@@ -125,64 +121,7 @@ class StackRun(Subcommand):
         config_dict = yaml.safe_load(config_file.read_text())
         config = parse_and_maybe_upgrade_config(config_dict)

-        if args.image_type == ImageType.container.value or config.container_image:
-            script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
-            image_name = f"distribution-{template_name}" if template_name else config.container_image
-            run_args = [script, image_name]
-        elif args.image_type == ImageType.conda.value:
-            current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
-            image_name = args.image_name or current_conda_env
-            if not image_name:
-                cprint(
-                    "No current conda environment detected, please specify a conda environment name with --image-name",
-                    color="red",
-                )
-                return
-
-            def get_conda_prefix(env_name):
-                # Conda "base" environment does not end with "base" in the
-                # prefix, so should be handled separately.
-                if env_name == "base":
-                    return os.environ.get("CONDA_PREFIX")
-                # Get conda environments info
-                conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
-                envs = conda_env_info["envs"]
-                for envpath in envs:
-                    if envpath.endswith(env_name):
-                        return envpath
-                return None
-
-            print(f"Using conda environment: {image_name}")
-            conda_prefix = get_conda_prefix(image_name)
-            if not conda_prefix:
-                cprint(
-                    f"Conda environment {image_name} does not exist.",
-                    color="red",
-                )
-                return
-
-            build_file = Path(conda_prefix) / "llamastack-build.yaml"
-            if not build_file.exists():
-                cprint(
-                    f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
-                    color="red",
-                )
-                return
-
-            script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
-            run_args = [
-                script,
-                image_name,
-            ]
-        else:
-            # else must be venv since that is the only valid option left.
-            current_venv = os.environ.get("VIRTUAL_ENV")
-            venv = args.image_name or current_venv
-            script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
-            run_args = [
-                script,
-                venv,
-            ]
+        run_args = formulate_run_args(args.image_type, args.image_name, config, template_name)

         run_args.extend([str(config_file), str(args.port)])
         if args.disable_ipv6:

@@ -206,5 +145,4 @@ class StackRun(Subcommand):

         if args.tls_keyfile and args.tls_certfile:
             run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
-
         run_with_pty(run_args)
@@ -10,7 +10,6 @@ from importlib.metadata import version
 from llama_stack.cli.subcommand import Subcommand

 from .build import StackBuild
-from .configure import StackConfigure
 from .list_apis import StackListApis
 from .list_providers import StackListProviders
 from .run import StackRun

@@ -37,7 +36,6 @@ class StackParser(Subcommand):

         # Add sub-commands
         StackBuild.create(subparsers)
-        StackConfigure.create(subparsers)
         StackListApis.create(subparsers)
         StackListProviders.create(subparsers)
         StackRun.create(subparsers)
@@ -7,7 +7,6 @@
 import importlib.resources
 import logging
 import sys
-from enum import Enum
 from pathlib import Path
 from typing import Dict, List

@@ -18,6 +17,7 @@ from llama_stack.distribution.datatypes import BuildConfig, Provider
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
 from llama_stack.distribution.utils.exec import run_command, run_with_pty
+from llama_stack.distribution.utils.image_types import ImageType
 from llama_stack.providers.datatypes import Api

 log = logging.getLogger(__name__)

@@ -33,12 +33,6 @@ SERVER_DEPENDENCIES = [
 ]


-class ImageType(Enum):
-    container = "container"
-    conda = "conda"
-    venv = "venv"
-
-
 class ApiInput(BaseModel):
     api: Api
     provider: str
@@ -8,6 +8,8 @@
 LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
 LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
+LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
+
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
 PYPI_VERSION=${PYPI_VERSION:-}
 BUILD_PLATFORM=${BUILD_PLATFORM:-}

@@ -32,7 +34,7 @@ container_base="$3"
 build_file_path="$4"
 host_build_dir="$5"
 pip_dependencies="$6"
-special_pip_deps="$7"
+special_pip_deps="${7:-}"


 # Define color codes

@@ -106,26 +108,39 @@ fi
 stack_mount="/app/llama-stack-source"
 models_mount="/app/llama-models-source"
+client_mount="/app/llama-stack-client-source"

-if [ -n "$LLAMA_STACK_DIR" ]; then
-  if [ ! -d "$LLAMA_STACK_DIR" ]; then
-    echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2
+install_local_package() {
+  local dir="$1"
+  local mount_point="$2"
+  local name="$3"
+
+  if [ ! -d "$dir" ]; then
+    echo "${RED}Warning: $name is set but directory does not exist: $dir${NC}" >&2
     exit 1
   fi

-  # Install in editable format. We will mount the source code into the container
-  # so that changes will be reflected in the container without having to do a
-  # rebuild. This is just for development convenience.
   if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
     add_to_container << EOF
-COPY $LLAMA_STACK_DIR $stack_mount
+COPY $dir $mount_point
 EOF
   fi

   add_to_container << EOF
-RUN uv pip install --no-cache -e $stack_mount
+RUN uv pip install --no-cache -e $mount_point
 EOF
+}
+
+if [ -n "$LLAMA_MODELS_DIR" ]; then
+  install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR"
+fi
+
+if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
+  install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
+fi
+
+if [ -n "$LLAMA_STACK_DIR" ]; then
+  install_local_package "$LLAMA_STACK_DIR" "$stack_mount" "LLAMA_STACK_DIR"
 else
   if [ -n "$TEST_PYPI_VERSION" ]; then
     # these packages are damaged in test-pypi, so install them first

@@ -134,6 +149,7 @@ RUN uv pip install fastapi libcst
 EOF
   add_to_container << EOF
 RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
+  --index-strategy unsafe-best-match \
  llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION

 EOF

@@ -149,23 +165,6 @@ EOF
   fi
 fi

-if [ -n "$LLAMA_MODELS_DIR" ]; then
-  if [ ! -d "$LLAMA_MODELS_DIR" ]; then
-    echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
-    exit 1
-  fi
-
-  if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
-    add_to_container << EOF
-COPY $LLAMA_MODELS_DIR $models_mount
-EOF
-  fi
-  add_to_container << EOF
-RUN uv pip uninstall llama-models
-RUN uv pip install --no-cache $models_mount
-EOF
-fi
-
 # if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
 if [[ "$template_or_config" != *.yaml ]]; then
   add_to_container << EOF

@@ -177,6 +176,15 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
 EOF
 fi

+# Add other required commands generic to all containers
+add_to_container << EOF
+
+# Allows running as non-root user
+RUN mkdir -p /.llama /.cache
+
+RUN chmod -R g+rw /app /.llama /.cache
+EOF
+
 printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n"
 cat $TEMP_DIR/Containerfile
 printf "\n"

@@ -189,6 +197,9 @@ if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
   if [ -n "$LLAMA_MODELS_DIR" ]; then
     mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
   fi
+  if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
+    mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
+  fi
 fi

 if command -v selinuxenabled &>/dev/null && selinuxenabled; then
@@ -16,6 +16,7 @@ TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
 # Reference: https://github.com/astral-sh/uv/pull/1694
 UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
 UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
+VIRTUAL_ENV=${VIRTUAL_ENV:-}

 if [ -n "$LLAMA_STACK_DIR" ]; then
   echo "Using llama-stack-dir=$LLAMA_STACK_DIR"

@@ -24,8 +25,8 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
   echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
 fi

-if [ "$#" -lt 3 ]; then
-  echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies> [<special_pip_deps>]" >&2
+if [ "$#" -lt 2 ]; then
+  echo "Usage: $0 <distribution_type> <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
   echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
   exit 1
 fi

@@ -34,8 +35,7 @@ special_pip_deps="$3"

 set -euo pipefail

-build_name="$1"
-env_name="llamastack-$build_name"
+env_name="$1"
 pip_dependencies="$2"

 # Define color codes

@@ -74,9 +74,13 @@ run() {
   local env_name="$1"
   local pip_dependencies="$2"
   local special_pip_deps="$3"

-  if [ -n "$UV_SYSTEM_PYTHON" ]; then
+  if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
     echo "Installing dependencies in system Python environment"
+    # if env == __system__, ensure we set UV_SYSTEM_PYTHON
+    export UV_SYSTEM_PYTHON=1
+  elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
+    echo "Virtual environment $env_name is already active"
   else
     echo "Using virtual environment $env_name"
     uv venv "$env_name"

@@ -90,6 +94,7 @@ run() {
     # shellcheck disable=SC2086
     # we are building a command line so word splitting is expected
     uv pip install --extra-index-url https://test.pypi.org/simple/ \
+      --index-strategy unsafe-best-match \
       llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
       $pip_dependencies
     if [ -n "$special_pip_deps" ]; then
@@ -41,6 +41,7 @@ from llama_stack.distribution.stack import (
     redact_sensitive_fields,
     replace_env_vars,
 )
+from llama_stack.distribution.utils.exec import in_notebook
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
     setup_logger,

@@ -52,19 +53,6 @@ logger = logging.getLogger(__name__)
 T = TypeVar("T")


-def in_notebook():
-    try:
-        from IPython import get_ipython
-
-        if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
-            return False
-    except ImportError:
-        return False
-    except AttributeError:
-        return False
-    return True
-
-
 def convert_pydantic_to_json_value(value: Any) -> Any:
     if isinstance(value, Enum):
         return value.value

@@ -230,12 +218,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if Api.telemetry in self.impls:
             setup_logger(self.impls[Api.telemetry])

-        console = Console()
-        console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
-        # Redact sensitive information before printing
-        safe_config = redact_sensitive_fields(self.config.model_dump())
-        console.print(yaml.dump(safe_config, indent=2))
+        if not os.environ.get("PYTEST_CURRENT_TEST"):
+            console = Console()
+            console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
+            safe_config = redact_sensitive_fields(self.config.model_dump())
+            console.print(yaml.dump(safe_config, indent=2))

         endpoints = get_all_api_endpoints()
         endpoint_impls = {}
@@ -6,7 +6,11 @@
 from typing import Any, AsyncGenerator, Dict, List, Optional

-from llama_stack.apis.common.content_types import URL, InterleavedContent
+from llama_stack.apis.common.content_types import (
+    URL,
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.eval import (
     BenchmarkConfig,

@@ -17,11 +21,13 @@ from llama_stack.apis.eval import (
 )
 from llama_stack.apis.inference import (
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,

@@ -214,7 +220,10 @@ class InferenceRouter(Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.routing_table.get_model(model_id)
         if model is None:

@@ -224,6 +233,9 @@ class InferenceRouter(Inference):
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
+            text_truncation=text_truncation,
+            output_dimension=output_dimension,
+            task_type=task_type,
         )
@@ -55,6 +55,7 @@ while [[ $# -gt 0 ]]; do
   esac
 done

+echo "Using virtual environment: $venv_path"
 # Activate virtual environment
 if [ ! -d "$venv_path" ]; then
   echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
@@ -12,8 +12,78 @@ import signal
 import subprocess
 import sys

+from termcolor import cprint
+
 log = logging.getLogger(__name__)

+import importlib
+import json
+from pathlib import Path
+
+from llama_stack.distribution.utils.image_types import ImageType
+
+
+def formulate_run_args(image_type, image_name, config, template_name) -> list:
+    if image_type == ImageType.container.value or config.container_image:
+        script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
+        image_name = f"distribution-{template_name}" if template_name else config.container_image
+        run_args = [script, image_name]
+    elif image_type == ImageType.conda.value:
+        current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
+        image_name = image_name or current_conda_env
+        if not image_name:
+            cprint(
+                "No current conda environment detected, please specify a conda environment name with --image-name",
+                color="red",
+            )
+            return
+
+        def get_conda_prefix(env_name):
+            # Conda "base" environment does not end with "base" in the
+            # prefix, so should be handled separately.
+            if env_name == "base":
+                return os.environ.get("CONDA_PREFIX")
+            # Get conda environments info
+            conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
+            envs = conda_env_info["envs"]
+            for envpath in envs:
+                if envpath.endswith(env_name):
+                    return envpath
+            return None
+
+        print(f"Using conda environment: {image_name}")
+        conda_prefix = get_conda_prefix(image_name)
+        if not conda_prefix:
+            cprint(
+                f"Conda environment {image_name} does not exist.",
+                color="red",
+            )
+            return
+
+        build_file = Path(conda_prefix) / "llamastack-build.yaml"
+        if not build_file.exists():
+            cprint(
+                f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
+                color="red",
+            )
+            return
+
+        script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
+        run_args = [
+            script,
+            image_name,
+        ]
+    else:
+        # else must be venv since that is the only valid option left.
+        current_venv = os.environ.get("VIRTUAL_ENV")
+        venv = image_name or current_venv
+        script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
+        run_args = [
+            script,
+            venv,
+        ]
+    return run_args
+
+
 def run_with_pty(command):
     if sys.platform.startswith("win"):

@@ -22,6 +92,19 @@ def run_with_pty(command):
     return _run_with_pty_unix(command)


+def in_notebook():
+    try:
+        from IPython import get_ipython
+
+        if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
+            return False
+    except ImportError:
+        return False
+    except AttributeError:
+        return False
+    return True
+
+
 # run a command in a pseudo-terminal, with interrupt handling,
 # useful when you want to run interactive things
 def _run_with_pty_unix(command):
llama_stack/distribution/utils/image_types.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum


class ImageType(Enum):
    container = "container"
    conda = "conda"
    venv = "venv"
@@ -30,8 +30,10 @@ from llama_stack.apis.agents import (
     AgentTurnResponseStepProgressPayload,
     AgentTurnResponseStepStartPayload,
     AgentTurnResponseStreamChunk,
+    AgentTurnResponseTurnAwaitingInputPayload,
     AgentTurnResponseTurnCompletePayload,
     AgentTurnResponseTurnStartPayload,
+    AgentTurnResumeRequest,
     Attachment,
     Document,
     InferenceStep,

@@ -60,9 +62,13 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
+from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
-from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
+from llama_stack.models.llama.datatypes import (
+    BuiltinTool,
+    ToolCall,
+    ToolParamDefinition,
+)
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing

@@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)

+    async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
+        messages = []
+        if self.agent_config.instructions != "":
+            messages.append(SystemMessage(content=self.agent_config.instructions))
+
+        for turn in turns:
+            messages.extend(self.turn_to_messages(turn))
+        return messages
+
     async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
         with tracing.span("create_and_execute_turn") as span:
             span.set_attribute("session_id", request.session_id)

@@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
                 raise ValueError(f"Session {request.session_id} not found")

             turns = await self.storage.get_session_turns(request.session_id)
-            messages = []
-            if self.agent_config.instructions != "":
-                messages.append(SystemMessage(content=self.agent_config.instructions))
-
-            for i, turn in enumerate(turns):
-                messages.extend(self.turn_to_messages(turn))
-
+            messages = await self.get_messages_from_turns(turns)
             messages.extend(request.messages)

             turn_id = str(uuid.uuid4())

@@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin):
             )
             await self.storage.add_turn_to_session(request.session_id, turn)

-            chunk = AgentTurnResponseStreamChunk(
+            if output_message.tool_calls and request.allow_turn_resume:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnAwaitingInputPayload(
+                            turn=turn,
+                        )
+                    )
+                )
+            else:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnCompletePayload(
+                            turn=turn,
+                        )
+                    )
+                )
+
+            yield chunk
+
+    async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
+        with tracing.span("resume_turn") as span:
+            span.set_attribute("agent_id", self.agent_id)
+            span.set_attribute("session_id", request.session_id)
+            span.set_attribute("turn_id", request.turn_id)
+            span.set_attribute("request", request.model_dump_json())
+            assert request.stream is True, "Non-streaming not supported"
+
+            session_info = await self.storage.get_session_info(request.session_id)
+            if session_info is None:
+                raise ValueError(f"Session {request.session_id} not found")
+
+            turns = await self.storage.get_session_turns(request.session_id)
+            messages = await self.get_messages_from_turns(turns)
+            messages.extend(request.tool_responses)
+
+            last_turn_messages = [
+                x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
+            ]
+
+            # get the steps from the turn id
+            steps = []
+            if len(turns) > 0:
+                steps = turns[-1].steps
+
+            # mark tool execution step as complete
+            # if there's no tool execution in progress step (due to storage, or tool call parsing on client),
+            # we'll create a new tool execution step with current time
+            in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
+                request.session_id, request.turn_id
+            )
+            now = datetime.now()
+            tool_execution_step = ToolExecutionStep(
+                step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
+                turn_id=request.turn_id,
+                tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
+                tool_responses=[
+                    ToolResponse(
+                        call_id=x.call_id,
+                        tool_name=x.tool_name,
+                        content=x.content,
+                    )
+                    for x in request.tool_responses
+                ],
+                completed_at=now,
+                started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
+            )
+            steps.append(tool_execution_step)
+            yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseTurnCompletePayload(
-                        turn=turn,
+                    payload=AgentTurnResponseStepCompletePayload(
+                        step_type=StepType.tool_execution.value,
+                        step_id=tool_execution_step.step_id,
+                        step_details=tool_execution_step,
                     )
                 )
             )
+
+            output_message = None
+            async for chunk in self.run(
+                session_id=request.session_id,
+                turn_id=request.turn_id,
+                input_messages=messages,
+                sampling_params=self.agent_config.sampling_params,
+                stream=request.stream,
+            ):
+                if isinstance(chunk, CompletionMessage):
+                    output_message = chunk
+                    continue
+
+                assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
+                event = chunk.event
+                if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
+                    steps.append(event.payload.step_details)
+
+                yield chunk
+
+            assert output_message is not None
+
+            last_turn_start_time = datetime.now()
+            if len(turns) > 0:
+                last_turn_start_time = turns[-1].started_at
+
+            turn = Turn(
+                turn_id=request.turn_id,
+                session_id=request.session_id,
+                input_messages=last_turn_messages,
+                output_message=output_message,
+                started_at=last_turn_start_time,
+                completed_at=datetime.now(),
+                steps=steps,
+            )
+            await self.storage.add_turn_to_session(request.session_id, turn)
+
+            if output_message.tool_calls:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnAwaitingInputPayload(
+                            turn=turn,
+                        )
+                    )
+                )
+            else:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnCompletePayload(
+                            turn=turn,
+                        )
+                    )
+                )

             yield chunk

     async def run(

@@ -456,6 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
                         call_id="",
                         tool_name=MEMORY_QUERY_TOOL,
                         content=retrieved_context or [],
+                        metadata=result.metadata,
                     )
                 ],
             ),

@@ -611,11 +743,7 @@ class ChatAgent(ShieldRunnerMixin):
                 input_messages = input_messages + [message]
             else:
                 log.info(f"{str(message)}")
-                tool_call = message.tool_calls[0]
-                if tool_call.tool_name in client_tools:
-                    yield message
-                    return
-
+                # 1. Start the tool execution step and progress
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(

@@ -625,6 +753,7 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
+                tool_call = message.tool_calls[0]
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(

@@ -639,6 +768,23 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
+
+                # If tool is a client tool, yield CompletionMessage and return
+                if tool_call.tool_name in client_tools:
+                    await self.storage.set_in_progress_tool_call_step(
+                        session_id,
+                        turn_id,
+                        ToolExecutionStep(
+                            step_id=step_id,
+                            turn_id=turn_id,
+                            tool_calls=[tool_call],
+                            tool_responses=[],
+                            started_at=datetime.now(),
+                        ),
+                    )
+                    yield message
+                    return
+
+                # If tool is a builtin server tool, execute it
                 tool_name = tool_call.tool_name
                 if isinstance(tool_name, BuiltinTool):
                     tool_name = tool_name.value

@@ -650,13 +796,21 @@ class ChatAgent(ShieldRunnerMixin):
                     },
                 ) as span:
                     tool_execution_start_time = datetime.now()
-                    result_messages = await execute_tool_call_maybe(
+                    tool_call = message.tool_calls[0]
+                    tool_result = await execute_tool_call_maybe(
                         self.tool_runtime_api,
                         session_id,
-                        [message],
+                        tool_call,
                         toolgroup_args,
                         tool_to_group,
                     )
+                    result_messages = [
+                        ToolResponseMessage(
+                            call_id=tool_call.call_id,
+                            tool_name=tool_call.tool_name,
+                            content=tool_result.content,
+                        )
+                    ]
                     assert len(result_messages) == 1, "Currently not supporting multiple messages"
                     result_message = result_messages[0]
                     span.set_attribute("output", result_message.model_dump_json())

@@ -675,6 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
                                 call_id=result_message.call_id,
                                 tool_name=result_message.tool_name,
                                 content=result_message.content,
+                                metadata=tool_result.metadata,
                             )
                         ],
                         started_at=tool_execution_start_time,

@@ -913,19 +1068,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
 async def execute_tool_call_maybe(
     tool_runtime_api: ToolRuntime,
     session_id: str,
-    messages: List[CompletionMessage],
+    tool_call: ToolCall,
     toolgroup_args: Dict[str, Dict[str, Any]],
     tool_to_group: Dict[str, str],
-) -> List[ToolResponseMessage]:
-    # While Tools.run interface takes a list of messages,
-    # All tools currently only run on a single message
-    # When this changes, we can drop this assert
-    # Whether to call tools on each message and aggregate
-    # or aggregate and call tool once, reamins to be seen.
-    assert len(messages) == 1, "Expected single message"
-    message = messages[0]
-
-    tool_call = message.tool_calls[0]
+) -> ToolInvocationResult:
     name = tool_call.tool_name
     group_name = tool_to_group.get(name, None)
     if group_name is None:

@@ -946,14 +1092,7 @@ async def execute_tool_call_maybe(
             **tool_call_args,
         ),
     )
-    return [
-        ToolResponseMessage(
-            call_id=tool_call.call_id,
-            tool_name=tool_call.tool_name,
-            content=result.content,
-        )
-    ]
+    return result


 def _interpret_content_as_attachment(
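With the change above, a streamed turn can now end in either of two terminal payloads. A minimal sketch, assuming the stream yields AgentTurnResponseStreamChunk objects as in create_and_execute_turn above; drain_turn_stream is a hypothetical helper, not part of this change:

from llama_stack.apis.agents import (
    AgentTurnResponseTurnAwaitingInputPayload,
    AgentTurnResponseTurnCompletePayload,
)


async def drain_turn_stream(stream):
    """Consume an agent turn stream; return (turn, needs_client_tools)."""
    async for chunk in stream:
        payload = chunk.event.payload
        if isinstance(payload, AgentTurnResponseTurnAwaitingInputPayload):
            # The turn stopped at a client tool call; the caller must run the
            # tool(s) and resume the turn with the tool responses.
            return payload.turn, True
        if isinstance(payload, AgentTurnResponseTurnCompletePayload):
            return payload.turn, False
    raise RuntimeError("stream ended without a terminal payload")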
@@ -11,8 +11,6 @@ import tempfile
 import uuid
 from typing import AsyncGenerator, List, Optional, Union

-from termcolor import colored
-
 from llama_stack.apis.agents import (
     AgentConfig,
     AgentCreateResponse,

@@ -21,6 +19,7 @@ from llama_stack.apis.agents import (
     AgentStepResponse,
     AgentToolGroup,
     AgentTurnCreateRequest,
+    AgentTurnResumeRequest,
     Document,
     Session,
     Turn,

@@ -68,12 +67,7 @@ class MetaReferenceAgentsImpl(Agents):

         # check if "bwrap" is available
         if not shutil.which("bwrap"):
-            print(
-                colored(
-                    "Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
-                    "yellow",
-                )
-            )
+            logger.warning("Warning: `bwrap` is not available. Code interpreter tool will not work correctly.")

     async def create_agent(
         self,

@@ -146,6 +140,7 @@ class MetaReferenceAgentsImpl(Agents):
         documents: Optional[List[Document]] = None,
         stream: Optional[bool] = False,
         tool_config: Optional[ToolConfig] = None,
+        allow_turn_resume: Optional[bool] = False,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
             agent_id=agent_id,

@@ -155,6 +150,7 @@ class MetaReferenceAgentsImpl(Agents):
             toolgroups=toolgroups,
             documents=documents,
             tool_config=tool_config,
+            allow_turn_resume=allow_turn_resume,
         )
         if stream:
             return self._create_agent_turn_streaming(request)

@@ -169,6 +165,34 @@ class MetaReferenceAgentsImpl(Agents):
             async for event in agent.create_and_execute_turn(request):
                 yield event

+    async def resume_agent_turn(
+        self,
+        agent_id: str,
+        session_id: str,
+        turn_id: str,
+        tool_responses: List[ToolResponseMessage],
+        stream: Optional[bool] = False,
+    ) -> AsyncGenerator:
+        request = AgentTurnResumeRequest(
+            agent_id=agent_id,
+            session_id=session_id,
+            turn_id=turn_id,
+            tool_responses=tool_responses,
+            stream=stream,
+        )
+        if stream:
+            return self._continue_agent_turn_streaming(request)
+        else:
+            raise NotImplementedError("Non-streaming agent turns not yet implemented")
+
+    async def _continue_agent_turn_streaming(
+        self,
+        request: AgentTurnResumeRequest,
+    ) -> AsyncGenerator:
+        agent = await self.get_agent(request.agent_id)
+        async for event in agent.resume_turn(request):
+            yield event
+
     async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
         turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
         turn = json.loads(turn)
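A hedged sketch of driving the new resume path end to end against this implementation. The resume_agent_turn signature and ToolResponseMessage fields come from the diff; run_client_tool is a hypothetical stand-in for whatever executes the client-side tool:

from llama_stack.apis.inference import ToolResponseMessage


async def complete_turn_with_client_tools(agents_impl, agent_id, session_id, turn, run_client_tool):
    # `turn` is the Turn carried by the turn_awaiting_input payload.
    tool_responses = []
    for tool_call in turn.output_message.tool_calls:
        content = await run_client_tool(tool_call)  # caller-supplied, hypothetical
        tool_responses.append(
            ToolResponseMessage(
                call_id=tool_call.call_id,
                tool_name=tool_call.tool_name,
                content=content,
            )
        )

    # Only the streaming resume path is implemented above.
    stream = await agents_impl.resume_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        turn_id=turn.turn_id,
        tool_responses=tool_responses,
        stream=True,
    )
    async for event in stream:
        yield event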
@@ -12,7 +12,7 @@ from typing import List, Optional

 from pydantic import BaseModel

-from llama_stack.apis.agents import Turn
+from llama_stack.apis.agents import ToolExecutionStep, Turn
 from llama_stack.providers.utils.kvstore import KVStore

 log = logging.getLogger(__name__)

@@ -84,3 +84,15 @@ class AgentPersistence:
                 continue
         turns.sort(key=lambda x: (x.completed_at or datetime.min))
         return turns
+
+    async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
+        await self.kvstore.set(
+            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
+            value=step.model_dump_json(),
+        )
+
+    async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
+        value = await self.kvstore.get(
+            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
+        )
+        return ToolExecutionStep(**json.loads(value)) if value else None
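The two helpers above round-trip the paused step through the agent's KVStore under a composite key, which resume_turn() later reads to recover step_id, tool_calls and started_at. A small sketch of the key layout and decoding, with purely illustrative id values:

import json

from llama_stack.apis.agents import ToolExecutionStep


def in_progress_step_key(agent_id: str, session_id: str, turn_id: str) -> str:
    # Mirrors the key layout used above, e.g.
    #   in_progress_tool_call_step:agent-1:session-42:turn-7
    return f"in_progress_tool_call_step:{agent_id}:{session_id}:{turn_id}"


def decode_step(value: str | None) -> ToolExecutionStep | None:
    # Same decoding the getter above performs on the stored JSON.
    return ToolExecutionStep(**json.loads(value)) if value else None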
@@ -44,7 +44,6 @@ class SentenceTransformersInferenceImpl(
         pass

     async def register_model(self, model: Model) -> None:
-        _ = self._load_sentence_transformer_model(model.provider_resource_id)
         return model

     async def unregister_model(self, model_id: str) -> None:
@@ -22,11 +22,14 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
+    InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,

@@ -230,5 +233,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

-    async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
+    ) -> EmbeddingsResponse:
         raise NotImplementedError()
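This is the widened embeddings signature that the provider changes below all converge on. A hedged usage sketch against any Inference implementation; the model id is a placeholder and the task-type member name is an assumption, so check the API definition before relying on it:

from llama_stack.apis.inference import EmbeddingTaskType, TextTruncation


async def embed_documents(inference, texts: list[str]):
    # Plain strings are now accepted; the new keyword arguments are optional.
    response = await inference.embeddings(
        model_id="example/embedding-model",  # placeholder, not a real registered model
        contents=texts,
        text_truncation=TextTruncation.none,
        output_dimension=None,
        task_type=EmbeddingTaskType.document,  # assumed enum member
    )
    return response.embeddings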
@@ -119,10 +119,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):

         # sort by score
         chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
+        chunks = chunks[: query_config.max_chunks]
         tokens = 0
         picked = []
-        for c in chunks[: query_config.max_chunks]:
+        for c in chunks:
             metadata = c.metadata
             tokens += metadata["token_count"]
             if tokens > query_config.max_tokens_in_context:

@@ -146,6 +146,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                     text="\n=== END-RETRIEVED-CONTEXT ===\n",
                 ),
             ],
+            metadata={
+                "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
+            },
         )

     async def list_runtime_tools(
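A short sketch of how a caller might surface the new document_ids metadata; it assumes the RAG tool runtime's query(content, vector_db_ids, query_config) entry point and that the result object exposes .content and .metadata as constructed above:

async def query_with_citations(rag_tool, content, vector_db_ids, query_config=None):
    result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids, query_config=query_config)
    # document_ids lines up with the chunks that were actually picked for the context.
    cited = (result.metadata or {}).get("document_ids", [])
    return result.content, cited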
@@ -4,26 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# config.py
 from typing import Any, Dict

 from pydantic import BaseModel

-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
-

 class SQLiteVectorIOConfig(BaseModel):
     db_path: str
-    kvstore: KVStoreConfig

     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
         return {
-            "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="sqlite_vec.db",
-            )
+            "db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + "sqlite_vec.db",
         }
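Concretely, the simplified sample config now renders to a single db_path entry. For example, with an illustrative distro dir of "distributions/ollama":

from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig

# Illustrative value; the env-var fallback keeps the path overridable at run time.
config = SQLiteVectorIOConfig.sample_run_config(__distro_dir__="distributions/ollama")
# -> {"db_path": "${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db"}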
@@ -61,7 +61,10 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.inference,
             provider_type="inline::sentence-transformers",
-            pip_packages=["sentence-transformers"],
+            pip_packages=[
+                "torch torchvision --index-url https://download.pytorch.org/whl/cpu",
+                "sentence-transformers --no-deps",
+            ],
             module="llama_stack.providers.inline.inference.sentence_transformers",
             config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
         ),
@@ -20,7 +20,18 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.tool_runtime,
             provider_type="inline::rag-runtime",
-            pip_packages=[],
+            pip_packages=[
+                "blobfile",
+                "chardet",
+                "pypdf",
+                "tqdm",
+                "numpy",
+                "scikit-learn",
+                "scipy",
+                "nltk",
+                "sentencepiece",
+                "transformers",
+            ],
             module="llama_stack.providers.inline.tool_runtime.rag",
             config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
             api_dependencies=[Api.vector_io, Api.inference],
@@ -14,33 +14,13 @@ from llama_stack.providers.datatypes import (
     remote_provider_spec,
 )

-EMBEDDING_DEPS = [
-    "blobfile",
-    "chardet",
-    "pypdf",
-    "tqdm",
-    "numpy",
-    "scikit-learn",
-    "scipy",
-    "nltk",
-    "sentencepiece",
-    "transformers",
-    # this happens to work because special dependencies are always installed last
-    # so if there was a regular torch installed first, this would be ignored
-    # we need a better way to do this to identify potential conflicts, etc.
-    # for now, this lets us significantly reduce the size of the container which
-    # does not have any "local" inference code (and hence does not need GPU-enabled torch)
-    "torch torchvision --index-url https://download.pytorch.org/whl/cpu",
-    "sentence-transformers --no-deps",
-]
-

 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::meta-reference",
-            pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
+            pip_packages=["faiss-cpu"],
             module="llama_stack.providers.inline.vector_io.faiss",
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",

@@ -49,24 +29,33 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::faiss",
-            pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
+            pip_packages=["faiss-cpu"],
             module="llama_stack.providers.inline.vector_io.faiss",
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
             api=Api.vector_io,
-            provider_type="inline::sqlite_vec",
-            pip_packages=EMBEDDING_DEPS + ["sqlite-vec"],
+            provider_type="inline::sqlite-vec",
+            pip_packages=["sqlite-vec"],
             module="llama_stack.providers.inline.vector_io.sqlite_vec",
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             api_dependencies=[Api.inference],
         ),
+        InlineProviderSpec(
+            api=Api.vector_io,
+            provider_type="inline::sqlite_vec",
+            pip_packages=["sqlite-vec"],
+            module="llama_stack.providers.inline.vector_io.sqlite_vec",
+            config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
+            deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
+            api_dependencies=[Api.inference],
+        ),
         remote_provider_spec(
             Api.vector_io,
             AdapterSpec(
                 adapter_type="chromadb",
-                pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
+                pip_packages=["chromadb-client"],
                 module="llama_stack.providers.remote.vector_io.chroma",
                 config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
             ),

@@ -75,7 +64,7 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::chromadb",
-            pip_packages=EMBEDDING_DEPS + ["chromadb"],
+            pip_packages=["chromadb"],
             module="llama_stack.providers.inline.vector_io.chroma",
             config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
             api_dependencies=[Api.inference],

@@ -84,7 +73,7 @@ def available_providers() -> List[ProviderSpec]:
             Api.vector_io,
             AdapterSpec(
                 adapter_type="pgvector",
-                pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
+                pip_packages=["psycopg2-binary"],
                 module="llama_stack.providers.remote.vector_io.pgvector",
                 config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
             ),

@@ -94,7 +83,7 @@ def available_providers() -> List[ProviderSpec]:
             Api.vector_io,
             AdapterSpec(
                 adapter_type="weaviate",
-                pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
+                pip_packages=["weaviate-client"],
                 module="llama_stack.providers.remote.vector_io.weaviate",
                 config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
                 provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",

@@ -115,7 +104,7 @@ def available_providers() -> List[ProviderSpec]:
             Api.vector_io,
             AdapterSpec(
                 adapter_type="qdrant",
-                pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
+                pip_packages=["qdrant-client"],
                 module="llama_stack.providers.remote.vector_io.qdrant",
                 config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
             ),
@@ -9,17 +9,22 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from botocore.client import BaseClient

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,

@@ -162,7 +167,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embeddings = []
@@ -8,17 +8,22 @@ from typing import AsyncGenerator, List, Optional, Union

 from cerebras.cloud.sdk import AsyncCerebras

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
     CompletionResponse,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,

@@ -172,6 +177,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
@@ -8,16 +8,21 @@ from typing import AsyncGenerator, List, Optional

 from openai import OpenAI

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolDefinition,
     ToolPromptFormat,

@@ -130,7 +135,10 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):

     async def embeddings(
         self,
-        model: str,
-        contents: List[InterleavedContent],
+        model_id: str,
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
@@ -8,19 +8,24 @@ from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     CompletionRequest,
     CompletionResponse,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,

@@ -204,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         input_dict = {}
         media_present = request_has_media(request)

+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)

@@ -232,7 +236,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

@@ -17,16 +17,23 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     InterleavedContent,
+    InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import (
+    SamplingParams,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.models.llama.sku_list import CoreModelId
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (

@@ -140,7 +147,10 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -4,8 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from llama_stack.apis.models import ModelType
 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
     build_hf_repo_model_entry,
 )

@@ -46,6 +48,14 @@ _MODEL_ENTRIES = [
         "meta/llama-3.2-90b-vision-instruct",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
+    ProviderModelEntry(
+        provider_model_id="baai/bge-m3",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 1024,
+            "context_length": 8192,
+        },
+    ),
     # TODO(mf): how do we handle Nemotron models?
     # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
 ]
@@ -10,6 +10,11 @@ from typing import AsyncIterator, List, Optional, Union

 from openai import APIConnectionError, AsyncOpenAI

+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+    TextContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,

@@ -18,15 +23,20 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
-    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
 )
-from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import (
+    SamplingParams,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )

@@ -117,9 +127,41 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        if any(content_has_media(content) for content in contents):
+            raise NotImplementedError("Media is not supported")
+
+        #
+        # Llama Stack: contents = List[str] | List[InterleavedContentItem]
+        #  ->
+        # OpenAI: input = str | List[str]
+        #
+        # we can ignore str and always pass List[str] to OpenAI
+        #
+        flat_contents = [
+            item.text if isinstance(item, TextContentItem) else item
+            for content in contents
+            for item in (content if isinstance(content, list) else [content])
+        ]
+        input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
+        model = self.get_provider_model_id(model_id)
+
+        response = await self._client.embeddings.create(
+            model=model,
+            input=input,
+            # extra_body={"input_type": "passage"|"query"}, # TODO(mf): how to tell caller's intent?
+        )
+
+        #
+        # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
+        #  ->
+        # Llama Stack: EmbeddingsResponse(embeddings=List[List[float]])
+        #
+        return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])

     async def chat_completion(
         self,
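To make the flattening above concrete, a small worked example of what ends up as the `input` to the OpenAI-compatible embeddings endpoint; TextContentItem follows the import in this file, and the mixed input values are purely illustrative:

from llama_stack.apis.common.content_types import TextContentItem

contents = ["plain string", [TextContentItem(text="item one"), TextContentItem(text="item two")]]

flat_contents = [
    item.text if isinstance(item, TextContentItem) else item
    for content in contents
    for item in (content if isinstance(content, list) else [content])
]
# flat_contents == ["plain string", "item one", "item two"]
# This list of strings is then passed as `input` to client.embeddings.create(...).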
@@ -13,6 +13,7 @@ from ollama import AsyncClient
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
+    InterleavedContentItem,
     TextContentItem,
 )
 from llama_stack.apis.inference import (

@@ -20,11 +21,13 @@ from llama_stack.apis.inference import (
     ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,

@@ -175,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                 # flatten the list of lists
                 input_dict["messages"] = [item for sublist in contents for item in sublist]

@@ -184,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["raw"] = True
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
-                    self.register_helper.get_llama_model(request.model),
+                    llama_model,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"

@@ -258,7 +262,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

@@ -274,7 +281,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return EmbeddingsResponse(embeddings=embeddings)

     async def register_model(self, model: Model) -> Model:
+        model = await self.register_helper.register_model(model)
         if model.model_type == ModelType.embedding:
+            log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
+            await self.client.pull(model.provider_resource_id)
             response = await self.client.list()
         else:
             response = await self.client.ps()

@@ -284,7 +294,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
             )

-        return await self.register_helper.register_model(model)
+        return model


 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
@@ -11,11 +11,13 @@ from llama_stack_client import LlamaStackClient
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,

@@ -138,6 +140,9 @@ class PassthroughInferenceAdapter(Inference):
         self,
         model_id: str,
         contents: List[InterleavedContent],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         client = self._get_client()
         model = await self.model_store.get_model(model_id)

@@ -145,4 +150,7 @@ class PassthroughInferenceAdapter(Inference):
         return client.inference.embeddings(
             model_id=model.provider_resource_id,
             contents=contents,
+            text_truncation=text_truncation,
+            output_dimension=output_dimension,
+            task_type=task_type,
         )
@@ -69,9 +69,10 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
             model=model,

@@ -119,6 +120,9 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
@@ -5,16 +5,38 @@
 # the root directory of this source tree.

 import json
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional

 from openai import OpenAI

 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
+    InterleavedContentItem,
     TextContentItem,
 )
-from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionMessage,
+    EmbeddingsResponse,
+    EmbeddingTaskType,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    StopReason,
+    SystemMessage,
+    TextTruncation,
+    ToolCall,
+    ToolChoice,
+    ToolConfig,
+    ToolDefinition,
+    ToolPromptFormat,
+    ToolResponseMessage,
+    UserMessage,
+)
 from llama_stack.models.llama.datatypes import (
     GreedySamplingStrategy,
     TopKSamplingStrategy,
@@ -119,7 +141,10 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -10,18 +10,23 @@ from typing import AsyncGenerator, List, Optional

 from huggingface_hub import AsyncInferenceClient, HfApi

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -268,7 +273,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

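The same widening of the `embeddings` signature repeats across these provider adapters: callers may now pass plain strings or `InterleavedContentItem`s, plus optional truncation, output-dimension, and task-type hints. A minimal sketch of a conforming override, using only the types shown in the hunks above (the `MyAdapter` class is hypothetical and not part of this diff):

from typing import List, Optional

from llama_stack.apis.common.content_types import InterleavedContentItem
from llama_stack.apis.inference import (
    EmbeddingsResponse,
    EmbeddingTaskType,
    TextTruncation,
)


class MyAdapter:  # hypothetical adapter, for illustration only
    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        # Adapters that do not support the new knobs can ignore them or raise,
        # as several of the adapters in this diff do.
        raise NotImplementedError()
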
@@ -8,18 +8,23 @@ from typing import AsyncGenerator, List, Optional, Union

 from together import Together

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -198,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
@@ -219,7 +223,10 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         assert all(not content_has_media(content) for content in contents), (

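The Together adapter now chooses between OpenAI-style messages and a raw prompt based on whether a Llama model mapping exists for the requested model. A reduced sketch of that decision, with simplified names that are not part of this diff:

# Sketch of the fallback added in _get_params: if the request carries media OR
# the model has no known Llama mapping, send OpenAI-style "messages";
# otherwise render a raw prompt string.
def choose_payload(media_present: bool, llama_model: str | None) -> str:
    if media_present or not llama_model:
        return "messages"
    return "prompt"


assert choose_payload(False, None) == "messages"
assert choose_payload(False, "Llama-3.1-8B-Instruct") == "prompt"
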
@@ -10,7 +10,13 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_models.datatypes import StopReason, ToolCall
 from openai import OpenAI

-from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -22,18 +28,21 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
@@ -112,10 +121,16 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
             if tool_param.required:
                 compat_required.append(tool_key)

+        # The tool.tool_name can be a str or a BuiltinTool enum. If
+        # it's the latter, convert to a string.
+        tool_name = tool.tool_name
+        if isinstance(tool_name, BuiltinTool):
+            tool_name = tool_name.value
+
         compat_tool = {
             "type": "function",
             "function": {
-                "name": tool.tool_name,
+                "name": tool_name,
                 "description": tool.description,
                 "parameters": {
                     "type": "object",
@@ -369,7 +384,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

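The vLLM tool conversion now normalizes tool names that may arrive as a BuiltinTool enum member instead of a plain string. A standalone sketch of the same idea, using a stand-in enum rather than the real llama_stack.models.llama.datatypes.BuiltinTool:

from enum import Enum


class BuiltinTool(Enum):  # stand-in for the real enum, for illustration only
    brave_search = "brave_search"


def normalize_tool_name(tool_name) -> str:
    # Mirrors the change in _convert_to_vllm_tools_in_request: accept either a
    # plain string or an enum member and always emit a string.
    if isinstance(tool_name, BuiltinTool):
        return tool_name.value
    return tool_name


assert normalize_tool_name("custom1") == "custom1"
assert normalize_tool_name(BuiltinTool.brave_search) == "brave_search"
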
@@ -8,7 +8,10 @@ import unittest

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
+    CompletionMessage,
+    StopReason,
     SystemMessage,
+    ToolCall,
     ToolConfig,
     UserMessage,
 )
@@ -20,6 +23,7 @@ from llama_stack.models.llama.datatypes import (
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
+    chat_completion_request_to_prompt,
 )

 MODEL = "Llama3.1-8B-Instruct"
@@ -119,6 +123,46 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         self.assertTrue("Return function calls in JSON format" in messages[1].content)
         self.assertEqual(messages[-1].content, content)

+    async def test_completion_message_encoding(self):
+        request = ChatCompletionRequest(
+            model=MODEL3_2,
+            messages=[
+                UserMessage(content="hello"),
+                CompletionMessage(
+                    content="",
+                    stop_reason=StopReason.end_of_turn,
+                    tool_calls=[
+                        ToolCall(
+                            tool_name="custom1",
+                            arguments={"param1": "value1"},
+                            call_id="123",
+                        )
+                    ],
+                ),
+            ],
+            tools=[
+                ToolDefinition(
+                    tool_name="custom1",
+                    description="custom1 tool",
+                    parameters={
+                        "param1": ToolParamDefinition(
+                            param_type="str",
+                            description="param1 description",
+                            required=True,
+                        ),
+                    },
+                ),
+            ],
+            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
+        )
+        prompt = await chat_completion_request_to_prompt(request, request.model)
+        self.assertIn('[custom1(param1="value1")]', prompt)
+
+        request.model = MODEL
+        request.tool_config.tool_prompt_format = ToolPromptFormat.json
+        prompt = await chat_completion_request_to_prompt(request, request.model)
+        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+
     async def test_user_provided_system_message(self):
         content = "Hello !"
         system_prompt = "You are a pirate"

@@ -61,7 +61,7 @@ def vector_io_sqlite_vec() -> ProviderFixture:
         providers=[
             Provider(
                 provider_id="sqlite_vec",
-                provider_type="inline::sqlite_vec",
+                provider_type="inline::sqlite-vec",
                 config=SQLiteVectorIOConfig(
                     kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
                 ).model_dump(),

@@ -5,13 +5,16 @@
 # the root directory of this source tree.

 import logging
-from typing import List
+from typing import List, Optional

 from llama_stack.apis.inference import (
     EmbeddingsResponse,
-    InterleavedContent,
+    EmbeddingTaskType,
+    InterleavedContentItem,
     ModelStore,
+    TextTruncation,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str

 EMBEDDING_MODELS = {}

@@ -25,11 +28,16 @@ class SentenceTransformerEmbeddingMixin:
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
-        embeddings = embedding_model.encode(contents)
+        embeddings = embedding_model.encode(
+            [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
+        )
         return EmbeddingsResponse(embeddings=embeddings)

     def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":

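The sentence-transformers mixin now flattens interleaved content to plain strings before encoding and disables the progress bar. A minimal sketch of the underlying encode call, assuming sentence-transformers is installed and the inputs have already been flattened to strings:

from sentence_transformers import SentenceTransformer

# Example texts; in the adapter these come from interleaved_content_as_str().
texts = ["hello world", "llama stack embeddings"]
model = SentenceTransformer("all-MiniLM-L6-v2")
vectors = model.encode(texts, show_progress_bar=False)
print(len(vectors), len(vectors[0]))  # number of embeddings, embedding dimension
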
@@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             provider_resource_id = model.provider_resource_id
         else:
             provider_resource_id = self.get_provider_model_id(model.provider_resource_id)

         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
-            if model.metadata.get("llama_model") is None:
-                raise ValueError(
-                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
-                    "Please specify a llama_model in metadata or use a supported model identifier"
-                )
+            llama_model = model.metadata.get("llama_model")
+            if llama_model is None:
+                return model
             existing_llama_model = self.get_llama_model(model.provider_resource_id)
             if existing_llama_model:
-                if existing_llama_model != model.metadata["llama_model"]:
+                if existing_llama_model != llama_model:
                     raise ValueError(
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
-                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Invalid llama_model '{llama_model}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                 )

         return model

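ModelRegistryHelper.register_model no longer raises when a provider model id is unknown and no llama_model metadata is supplied; the model is simply accepted as-is. A reduced sketch of the new branching, with simplified names and a dict standing in for the real descriptor table (not part of this diff):

def resolve_llama_model(provider_resource_id: str, metadata: dict, known_models: dict):
    # Reduced control flow: None metadata now means "register as-is" instead of ValueError.
    llama_model = metadata.get("llama_model")
    if llama_model is None:
        return None
    if llama_model not in known_models:
        raise ValueError(f"Invalid llama_model '{llama_model}' specified in metadata.")
    return known_models[llama_model]


assert resolve_llama_model("my-custom-model", {}, {}) is None
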
@@ -252,7 +252,9 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
     request = await convert_request_to_raw(request)

     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
-    model_input = formatter.encode_dialog_prompt(request.messages)
+    model_input = formatter.encode_dialog_prompt(
+        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+    )
     return formatter.tokenizer.decode(model_input.tokens)


@@ -264,7 +266,9 @@ async def chat_completion_request_to_model_input_info(
     request = await convert_request_to_raw(request)

     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
-    model_input = formatter.encode_dialog_prompt(request.messages)
+    model_input = formatter.encode_dialog_prompt(
+        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+    )
     return (
         formatter.tokenizer.decode(model_input.tokens),
         len(model_input.tokens),

@@ -5,12 +5,10 @@
 # the root directory of this source tree.

 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, TypeVar
+from typing import Any, Callable, List, Optional, Protocol, TypeVar

 from .strong_typing.schema import json_schema_type, register_schema  # noqa: F401

-T = TypeVar("T")
-

 @dataclass
 class WebMethod:
@@ -22,6 +20,13 @@ class WebMethod:
     raw_bytes_request_body: Optional[bool] = False


+class HasWebMethod(Protocol):
+    __webmethod__: WebMethod
+
+
+T = TypeVar("T", bound=HasWebMethod)  # Bound T to classes that match this protocol
+
+
 def webmethod(
     route: Optional[str] = None,
     method: Optional[str] = None,

@@ -11,7 +11,7 @@ import subprocess
 import sys
 from functools import partial
 from pathlib import Path
-from typing import Iterator
+from typing import Iterable

 from rich.progress import Progress, SpinnerColumn, TextColumn

@@ -39,7 +39,7 @@ class ChangedPathTracker:
         return self._changed_paths


-def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
+def find_template_dirs(templates_dir: Path) -> Iterable[Path]:
     """Find immediate subdirectories in the templates folder."""
     if not templates_dir.exists():
         raise FileNotFoundError(f"Templates directory not found: {templates_dir}")
@@ -90,7 +90,7 @@ def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
     return has_changes


-def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
+def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[str]]:
     try:
         module_name = f"llama_stack.templates.{template_dir.name}"
         module = importlib.import_module(module_name)

@@ -52,7 +52,7 @@ def main(parser: argparse.ArgumentParser):
             pytest_args,
             "-s",
             "-v",
-            REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH,
+            str(REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH),
         ]
     )

@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::cerebras
+    - inline::sentence-transformers
     safety:
     - inline::llama-guard
     vector_io:

@@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::cerebras"],
+        "inference": ["remote::cerebras", "inline::sentence-transformers"],
         "safety": ["inline::llama-guard"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "agents": ["inline::meta-reference"],

@@ -5,6 +5,7 @@ distribution_spec:
   providers:
     inference:
     - remote::tgi
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::tgi"],
+        "inference": ["remote::tgi", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::fireworks
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -25,7 +25,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::fireworks"],
+        "inference": ["remote::fireworks", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::hf::serverless
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -21,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::hf::serverless"],
+        "inference": ["remote::hf::serverless", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

@@ -55,9 +55,11 @@ def get_distribution_template() -> DistributionTemplate:
     core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
     default_models = [
         ModelInput(
-            model_id=core_model_to_hf_repo[m.llama_model],
+            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
             provider_model_id=m.provider_model_id,
             provider_id="nvidia",
+            model_type=m.model_type,
+            metadata=m.metadata,
         )
         for m in _MODEL_ENTRIES
     ]

@@ -135,6 +135,13 @@ models:
   provider_id: nvidia
   provider_model_id: meta/llama-3.2-90b-vision-instruct
   model_type: llm
+- metadata:
+    embedding_dimension: 1024
+    context_length: 8192
+  model_id: baai/bge-m3
+  provider_id: nvidia
+  provider_model_id: baai/bge-m3
+  model_type: embedding
 shields: []
 vector_dbs: []
 datasets: []

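The NVIDIA run config now ships an embedding model entry (baai/bge-m3). A hedged sketch of registering an equivalent model at runtime with llama-stack-client, assuming a stack running locally on port 8321 and that models.register accepts these fields:

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local endpoint
client.models.register(
    model_id="baai/bge-m3",
    provider_id="nvidia",
    provider_model_id="baai/bge-m3",
    model_type="embedding",
    metadata={"embedding_dimension": 1024, "context_length": 8192},
)
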
@@ -5,7 +5,7 @@ distribution_spec:
     inference:
     - remote::ollama
     vector_io:
-    - inline::faiss
+    - inline::sqlite-vec
     - remote::chromadb
     - remote::pgvector
     safety:

@@ -13,10 +13,6 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
-from llama_stack.providers.inline.inference.sentence_transformers import (
-    SentenceTransformersInferenceConfig,
-)
-from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -25,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
 def get_distribution_template() -> DistributionTemplate:
     providers = {
         "inference": ["remote::ollama"],
-        "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
+        "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],
         "telemetry": ["inline::meta-reference"],
@@ -45,19 +41,9 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::ollama",
         config=OllamaImplConfig.sample_run_config(),
     )
-    embedding_provider = Provider(
-        provider_id="sentence-transformers",
-        provider_type="inline::sentence-transformers",
-        config=SentenceTransformersInferenceConfig.sample_run_config(),
-    )
-    vector_io_provider_faiss = Provider(
-        provider_id="faiss",
-        provider_type="inline::faiss",
-        config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
-    )
     vector_io_provider_sqlite = Provider(
-        provider_id="sqlite_vec",
-        provider_type="inline::sqlite_vec",
+        provider_id="sqlite-vec",
+        provider_type="inline::sqlite-vec",
         config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
     )

@@ -104,19 +90,16 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider, embedding_provider],
-                    "vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite],
+                    "inference": [inference_provider],
+                    "vector_io": [vector_io_provider_sqlite],
                 },
-                default_models=[inference_model, embedding_model],
+                default_models=[inference_model],
                 default_tool_groups=default_tool_groups,
             ),
             "run-with-safety.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [
-                        inference_provider,
-                        embedding_provider,
-                    ],
-                    "vector_io": [vector_io_provider_faiss, vector_io_provider_faiss],
+                    "inference": [inference_provider],
+                    "vector_io": [vector_io_provider_sqlite],
                     "safety": [
                         Provider(
                             provider_id="llama-guard",

@@ -16,24 +16,11 @@ providers:
     provider_type: remote::ollama
     config:
       url: ${env.OLLAMA_URL:http://localhost:11434}
-  - provider_id: sentence-transformers
-    provider_type: inline::sentence-transformers
-    config: {}
   vector_io:
-  - provider_id: faiss
-    provider_type: inline::faiss
+  - provider_id: sqlite-vec
+    provider_type: inline::sqlite-vec
     config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
-  - provider_id: faiss
-    provider_type: inline::faiss
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
+      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard

@@ -16,24 +16,11 @@ providers:
     provider_type: remote::ollama
     config:
      url: ${env.OLLAMA_URL:http://localhost:11434}
-  - provider_id: sentence-transformers
-    provider_type: inline::sentence-transformers
-    config: {}
   vector_io:
-  - provider_id: faiss
-    provider_type: inline::faiss
+  - provider_id: sqlite-vec
+    provider_type: inline::sqlite-vec
     config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
-  - provider_id: sqlite_vec
-    provider_type: inline::sqlite_vec
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
+      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -100,11 +87,14 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: ollama
   model_type: llm
+<<<<<<< HEAD
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
   model_type: embedding
+=======
+>>>>>>> upstream/main
 shields: []
 vector_dbs: []
 datasets: []

@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::vllm
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::vllm"],
+        "inference": ["remote::vllm", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

@@ -4,6 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::tgi
+    - inline::sentence-transformers
     vector_io:
     - inline::faiss
     - remote::chromadb

@@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["remote::tgi"],
+        "inference": ["remote::tgi", "inline::sentence-transformers"],
         "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],

Some files were not shown because too many files have changed in this diff.