mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
Merge remote-tracking branch 'upstream/main' into add_nvidia_safety_provider
This commit is contained in:
commit
ca6a12e362
114 changed files with 2100 additions and 685 deletions
31
.github/workflows/update-readthedocs.yml
vendored
31
.github/workflows/update-readthedocs.yml
vendored
|
@ -11,17 +11,42 @@ on:
|
|||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'docs/source/**'
|
||||
- 'docs/resources/**'
|
||||
- 'docs/**'
|
||||
- '.github/workflows/update-readthedocs.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'docs/**'
|
||||
- '.github/workflows/update-readthedocs.yml'
|
||||
|
||||
jobs:
|
||||
update-readthedocs:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
TOKEN: ${{ secrets.READTHEDOCS_TOKEN }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
|
||||
- name: Sync with uv
|
||||
run: uv sync --extra docs
|
||||
|
||||
- name: Build HTML
|
||||
run: |
|
||||
cd docs
|
||||
uv run make html
|
||||
|
||||
- name: Trigger ReadTheDocs build
|
||||
if: github.event_name != 'pull_request'
|
||||
run: |
|
||||
if [ -z "$TOKEN" ]; then
|
||||
echo "READTHEDOCS_TOKEN is not set"
|
||||
|
|
|
@ -30,6 +30,7 @@ repos:
|
|||
rev: v0.9.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [ --fix ]
|
||||
exclude: ^llama_stack/strong_typing/.*$
|
||||
- id: ruff-format
|
||||
|
||||
|
@ -45,23 +46,26 @@ repos:
|
|||
hooks:
|
||||
- id: uv-export
|
||||
args: [
|
||||
"--frozen",
|
||||
"--no-hashes",
|
||||
"--no-emit-project",
|
||||
"--frozen",
|
||||
"--no-hashes",
|
||||
"--no-emit-project",
|
||||
"--output-file=requirements.txt"
|
||||
]
|
||||
files: ^pyproject\.toml$
|
||||
- id: uv-sync
|
||||
|
||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||
# rev: v1.14.0
|
||||
# hooks:
|
||||
# - id: mypy
|
||||
# additional_dependencies:
|
||||
# - types-requests
|
||||
# - types-setuptools
|
||||
# - pydantic
|
||||
# args: [--ignore-missing-imports]
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.15.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
- uv==0.6.2
|
||||
- mypy
|
||||
- pytest
|
||||
- rich
|
||||
- types-requests
|
||||
- pydantic
|
||||
pass_filenames: false
|
||||
|
||||
# - repo: https://github.com/jsh9/pydoclint
|
||||
# rev: d88180a8632bb1602a4d81344085cf320f288c5a
|
||||
|
|
|
@ -134,9 +134,11 @@ If you are making changes to the documentation at [https://llama-stack.readthedo
|
|||
$ cd llama-stack/docs
|
||||
$ uv sync --extra docs
|
||||
|
||||
# This rebuilds the documentation pages.
|
||||
$ uv run make html
|
||||
|
||||
# This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
|
||||
$ make html
|
||||
$ uv run sphinx-autobuild source build/html
|
||||
$ uv run sphinx-autobuild source build/html --write-all
|
||||
```
|
||||
|
||||
### Update API Documentation
|
||||
|
@ -145,7 +147,7 @@ If you modify or add new API endpoints, update the API documentation accordingly
|
|||
|
||||
```bash
|
||||
$ uv sync --extra dev
|
||||
$ ./docs/openapi_generator/run_openapi_generator.sh
|
||||
$ uv run ./docs/openapi_generator/run_openapi_generator.sh
|
||||
```
|
||||
|
||||
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
|
||||
|
|
|
@ -3,3 +3,4 @@ include distributions/dependencies.json
|
|||
include llama_stack/distribution/*.sh
|
||||
include llama_stack/cli/scripts/*.sh
|
||||
include llama_stack/templates/*/*.yaml
|
||||
include llama_stack/providers/tests/test_cases/*.json
|
||||
|
|
12
README.md
12
README.md
|
@ -78,18 +78,14 @@ You have two ways to install this repository:
|
|||
```
|
||||
|
||||
* **Install from source**:
|
||||
If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable).
|
||||
If you prefer to install from the source code, we recommend using [uv](https://github.com/astral-sh/uv).
|
||||
Then, run the following commands:
|
||||
```bash
|
||||
mkdir -p ~/local
|
||||
cd ~/local
|
||||
git clone git@github.com:meta-llama/llama-stack.git
|
||||
|
||||
conda create -n stack python=3.10
|
||||
conda activate stack
|
||||
|
||||
cd llama-stack
|
||||
pip install -e .
|
||||
|
||||
uv sync
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
### Documentation
|
||||
|
|
|
@ -30,9 +30,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
"uvicorn"
|
||||
],
|
||||
"cerebras": [
|
||||
"aiosqlite",
|
||||
|
@ -170,9 +168,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
"uvicorn"
|
||||
],
|
||||
"hf-serverless": [
|
||||
"aiohttp",
|
||||
|
@ -247,9 +243,7 @@
|
|||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"zmq",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
"zmq"
|
||||
],
|
||||
"meta-reference-quantized-gpu": [
|
||||
"accelerate",
|
||||
|
@ -290,9 +284,7 @@
|
|||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"zmq",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
"zmq"
|
||||
],
|
||||
"nvidia": [
|
||||
"aiosqlite",
|
||||
|
@ -323,9 +315,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
"uvicorn"
|
||||
],
|
||||
"ollama": [
|
||||
"aiohttp",
|
||||
|
@ -335,7 +325,6 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
|
@ -356,11 +345,10 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlite-vec",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
"uvicorn"
|
||||
],
|
||||
"remote-vllm": [
|
||||
"aiosqlite",
|
||||
|
@ -423,9 +411,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
"uvicorn"
|
||||
],
|
||||
"tgi": [
|
||||
"aiohttp",
|
||||
|
|
232
docs/_static/llama-stack-spec.html
vendored
232
docs/_static/llama-stack-spec.html
vendored
|
@ -2315,6 +2315,70 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Turn"
|
||||
}
|
||||
},
|
||||
"text/event-stream": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Agents"
|
||||
],
|
||||
"description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "agent_id",
|
||||
"in": "path",
|
||||
"description": "The ID of the agent to resume.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"description": "The ID of the session to resume.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "turn_id",
|
||||
"in": "path",
|
||||
"description": "The ID of the turn to resume.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ResumeAgentTurnRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/eval/benchmarks/{benchmark_id}/jobs": {
|
||||
"post": {
|
||||
"responses": {
|
||||
|
@ -4226,6 +4290,9 @@
|
|||
},
|
||||
"tool_config": {
|
||||
"$ref": "#/components/schemas/ToolConfig"
|
||||
},
|
||||
"allow_turn_resume": {
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -4454,6 +4521,31 @@
|
|||
},
|
||||
"content": {
|
||||
"$ref": "#/components/schemas/InterleavedContent"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -4612,6 +4704,9 @@
|
|||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
|
@ -4621,7 +4716,8 @@
|
|||
"step_progress": "#/components/schemas/AgentTurnResponseStepProgressPayload",
|
||||
"step_complete": "#/components/schemas/AgentTurnResponseStepCompletePayload",
|
||||
"turn_start": "#/components/schemas/AgentTurnResponseTurnStartPayload",
|
||||
"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
|
||||
"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload",
|
||||
"turn_awaiting_input": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -4784,6 +4880,25 @@
|
|||
"title": "AgentTurnResponseStreamChunk",
|
||||
"description": "streamed agent turn completion response."
|
||||
},
|
||||
"AgentTurnResponseTurnAwaitingInputPayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"event_type": {
|
||||
"type": "string",
|
||||
"const": "turn_awaiting_input",
|
||||
"default": "turn_awaiting_input"
|
||||
},
|
||||
"turn": {
|
||||
"$ref": "#/components/schemas/Turn"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"event_type",
|
||||
"turn"
|
||||
],
|
||||
"title": "AgentTurnResponseTurnAwaitingInputPayload"
|
||||
},
|
||||
"AgentTurnResponseTurnCompletePayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -4929,11 +5044,42 @@
|
|||
"description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
|
||||
},
|
||||
"contents": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/InterleavedContent"
|
||||
},
|
||||
"description": "List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text."
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/InterleavedContentItem"
|
||||
}
|
||||
}
|
||||
],
|
||||
"description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
|
||||
},
|
||||
"text_truncation": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"none",
|
||||
"start",
|
||||
"end"
|
||||
],
|
||||
"description": "(Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length."
|
||||
},
|
||||
"output_dimension": {
|
||||
"type": "integer",
|
||||
"description": "(Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models."
|
||||
},
|
||||
"task_type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"query",
|
||||
"document"
|
||||
],
|
||||
"description": "(Optional) How is the embedding being used? This is only supported by asymmetric embedding models."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -6625,6 +6771,31 @@
|
|||
},
|
||||
"error_code": {
|
||||
"type": "integer"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -7474,9 +7645,37 @@
|
|||
"properties": {
|
||||
"content": {
|
||||
"$ref": "#/components/schemas/InterleavedContent"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"metadata"
|
||||
],
|
||||
"title": "RAGQueryResult"
|
||||
},
|
||||
"QueryChunksRequest": {
|
||||
|
@ -8015,6 +8214,27 @@
|
|||
],
|
||||
"title": "RegisterVectorDbRequest"
|
||||
},
|
||||
"ResumeAgentTurnRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tool_responses": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ToolResponseMessage"
|
||||
},
|
||||
"description": "The tool call responses to resume the turn with."
|
||||
},
|
||||
"stream": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to stream the response."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"tool_responses"
|
||||
],
|
||||
"title": "ResumeAgentTurnRequest"
|
||||
},
|
||||
"RunEvalRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
152
docs/_static/llama-stack-spec.yaml
vendored
152
docs/_static/llama-stack-spec.yaml
vendored
|
@ -1401,6 +1401,53 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/QueryTracesRequest'
|
||||
required: true
|
||||
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk
|
||||
objects.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Turn'
|
||||
text/event-stream:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AgentTurnResponseStreamChunk'
|
||||
tags:
|
||||
- Agents
|
||||
description: >-
|
||||
Resume an agent turn with executed tool call responses.
|
||||
|
||||
When a Turn has the status `awaiting_input` due to pending input from client
|
||||
side tool calls, this endpoint can be used to submit the outputs from the
|
||||
tool calls once they are ready.
|
||||
parameters:
|
||||
- name: agent_id
|
||||
in: path
|
||||
description: The ID of the agent to resume.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: session_id
|
||||
in: path
|
||||
description: The ID of the session to resume.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: turn_id
|
||||
in: path
|
||||
description: The ID of the turn to resume.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ResumeAgentTurnRequest'
|
||||
required: true
|
||||
/v1/eval/benchmarks/{benchmark_id}/jobs:
|
||||
post:
|
||||
responses:
|
||||
|
@ -2740,6 +2787,8 @@ components:
|
|||
$ref: '#/components/schemas/AgentTool'
|
||||
tool_config:
|
||||
$ref: '#/components/schemas/ToolConfig'
|
||||
allow_turn_resume:
|
||||
type: boolean
|
||||
additionalProperties: false
|
||||
required:
|
||||
- messages
|
||||
|
@ -2896,6 +2945,16 @@ components:
|
|||
- type: string
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- call_id
|
||||
|
@ -2992,6 +3051,7 @@ components:
|
|||
- $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload'
|
||||
- $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload'
|
||||
- $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
|
||||
- $ref: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
|
||||
discriminator:
|
||||
propertyName: event_type
|
||||
mapping:
|
||||
|
@ -3000,6 +3060,7 @@ components:
|
|||
step_complete: '#/components/schemas/AgentTurnResponseStepCompletePayload'
|
||||
turn_start: '#/components/schemas/AgentTurnResponseTurnStartPayload'
|
||||
turn_complete: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
|
||||
turn_awaiting_input: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
|
||||
AgentTurnResponseStepCompletePayload:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -3106,6 +3167,21 @@ components:
|
|||
- event
|
||||
title: AgentTurnResponseStreamChunk
|
||||
description: streamed agent turn completion response.
|
||||
"AgentTurnResponseTurnAwaitingInputPayload":
|
||||
type: object
|
||||
properties:
|
||||
event_type:
|
||||
type: string
|
||||
const: turn_awaiting_input
|
||||
default: turn_awaiting_input
|
||||
turn:
|
||||
$ref: '#/components/schemas/Turn'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- event_type
|
||||
- turn
|
||||
title: >-
|
||||
AgentTurnResponseTurnAwaitingInputPayload
|
||||
AgentTurnResponseTurnCompletePayload:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -3224,13 +3300,39 @@ components:
|
|||
The identifier of the model to use. The model must be an embedding model
|
||||
registered with Llama Stack and available via the /models endpoint.
|
||||
contents:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
oneOf:
|
||||
- type: array
|
||||
items:
|
||||
type: string
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/InterleavedContentItem'
|
||||
description: >-
|
||||
List of contents to generate embeddings for. Note that content can be
|
||||
multimodal. The behavior depends on the model and provider. Some models
|
||||
may only support text.
|
||||
List of contents to generate embeddings for. Each content can be a string
|
||||
or an InterleavedContentItem (and hence can be multimodal). The behavior
|
||||
depends on the model and provider. Some models may only support text.
|
||||
text_truncation:
|
||||
type: string
|
||||
enum:
|
||||
- none
|
||||
- start
|
||||
- end
|
||||
description: >-
|
||||
(Optional) Config for how to truncate text for embedding when text is
|
||||
longer than the model's max sequence length.
|
||||
output_dimension:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Output dimensionality for the embeddings. Only supported by
|
||||
Matryoshka models.
|
||||
task_type:
|
||||
type: string
|
||||
enum:
|
||||
- query
|
||||
- document
|
||||
description: >-
|
||||
(Optional) How is the embedding being used? This is only supported by
|
||||
asymmetric embedding models.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- model_id
|
||||
|
@ -4289,6 +4391,16 @@ components:
|
|||
type: string
|
||||
error_code:
|
||||
type: integer
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- content
|
||||
|
@ -4862,7 +4974,19 @@ components:
|
|||
properties:
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- metadata
|
||||
title: RAGQueryResult
|
||||
QueryChunksRequest:
|
||||
type: object
|
||||
|
@ -5179,6 +5303,22 @@ components:
|
|||
- vector_db_id
|
||||
- embedding_model
|
||||
title: RegisterVectorDbRequest
|
||||
ResumeAgentTurnRequest:
|
||||
type: object
|
||||
properties:
|
||||
tool_responses:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ToolResponseMessage'
|
||||
description: >-
|
||||
The tool call responses to resume the turn with.
|
||||
stream:
|
||||
type: boolean
|
||||
description: Whether to stream the response.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- tool_responses
|
||||
title: ResumeAgentTurnRequest
|
||||
RunEvalRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -86,8 +86,6 @@
|
|||
"# NBVAL_SKIP\n",
|
||||
"\n",
|
||||
"!apt-get install -y bubblewrap\n",
|
||||
"import os\n",
|
||||
"os.environ[\"UV_SYSTEM_PYTHON\"] = \"1\"\n",
|
||||
"!pip install uv\n",
|
||||
"!uv pip install llama-stack"
|
||||
]
|
||||
|
@ -3632,7 +3630,7 @@
|
|||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "master",
|
||||
"display_name": "toolchain",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
|
|
@ -25,7 +25,7 @@ We are working on adding a few more APIs to complete the application lifecycle.
|
|||
## API Providers
|
||||
|
||||
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
|
||||
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, etc.),
|
||||
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
|
||||
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
|
||||
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
|
||||
|
||||
|
@ -33,7 +33,7 @@ Providers come in two flavors:
|
|||
- **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
|
||||
- **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.
|
||||
|
||||
Most importantly, Llama Stack always strives to provide at least one fully "local" provider for each API so you can iterate on a fully featured environment locally.
|
||||
Most importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
|
||||
## Resources
|
||||
|
||||
Some of these APIs are associated with a set of **Resources**. Here is the mapping of APIs to resources:
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
from docutils import nodes
|
||||
|
||||
project = "llama-stack"
|
||||
copyright = "2024, Meta"
|
||||
copyright = "2025, Meta"
|
||||
author = "Meta"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
|
|
@ -38,6 +38,7 @@ The following models are available by default:
|
|||
- `meta-llama/Llama-3.2-3B-Instruct (meta/llama-3.2-3b-instruct)`
|
||||
- `meta-llama/Llama-3.2-11B-Vision-Instruct (meta/llama-3.2-11b-vision-instruct)`
|
||||
- `meta-llama/Llama-3.2-90B-Vision-Instruct (meta/llama-3.2-90b-vision-instruct)`
|
||||
- `baai/bge-m3 (baai/bge-m3)`
|
||||
|
||||
|
||||
### Prerequisite: API Keys
|
||||
|
|
|
@ -17,7 +17,7 @@ Which templates / distributions to choose depends on the hardware you have for r
|
|||
- {dockerhub}`distribution-nvidia` ([Guide](self_hosted_distro/nvidia))
|
||||
|
||||
- **Are you running on a "regular" desktop or laptop ?** We suggest using the ollama template for quick prototyping and get started without having to worry about needing GPUs.
|
||||
- {dockerhub}`distribution-ollama` ([link](self_hosted_distro/ollama))
|
||||
- {dockerhub}`distribution-ollama` ([Guide](self_hosted_distro/ollama))
|
||||
|
||||
- **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest:
|
||||
- {dockerhub}`distribution-together` ([Guide](self_hosted_distro/together))
|
||||
|
@ -28,7 +28,7 @@ Which templates / distributions to choose depends on the hardware you have for r
|
|||
- [Android](ondevice_distro/android_sdk)
|
||||
|
||||
|
||||
- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro).**
|
||||
- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro.md).**
|
||||
|
||||
### Distribution Details
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
|
|||
| agents | `inline::meta-reference` |
|
||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| inference | `remote::cerebras` |
|
||||
| inference | `remote::cerebras`, `inline::sentence-transformers` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
|
|
|
@ -19,7 +19,7 @@ The `llamastack/distribution-dell` distribution consists of the following provid
|
|||
| agents | `inline::meta-reference` |
|
||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| inference | `remote::tgi` |
|
||||
| inference | `remote::tgi`, `inline::sentence-transformers` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
|
|
|
@ -18,7 +18,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
|
|||
| agents | `inline::meta-reference` |
|
||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| inference | `remote::fireworks` |
|
||||
| inference | `remote::fireworks`, `inline::sentence-transformers` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
|
|
|
@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
|
|||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
|
||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
|
||||
|
||||
|
||||
You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.
|
||||
|
|
|
@ -17,7 +17,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
|
|||
| agents | `inline::meta-reference` |
|
||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| inference | `remote::vllm` |
|
||||
| inference | `remote::vllm`, `inline::sentence-transformers` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
|
|
|
@ -19,7 +19,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
|
|||
| agents | `inline::meta-reference` |
|
||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| inference | `remote::tgi` |
|
||||
| inference | `remote::tgi`, `inline::sentence-transformers` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
|
|
|
@ -18,7 +18,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
|
|||
| agents | `inline::meta-reference` |
|
||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| inference | `remote::together` |
|
||||
| inference | `remote::together`, `inline::sentence-transformers` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
|
|
|
@ -67,6 +67,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
|
|||
| **Provider** | **Environments** |
|
||||
| :----: | :----: |
|
||||
| FAISS | Single Node |
|
||||
| SQLite-Vec| Single Node |
|
||||
| Chroma | Hosted and Single Node |
|
||||
| Postgres (PGVector) | Hosted and Single Node |
|
||||
| Weaviate | Hosted |
|
||||
|
@ -88,6 +89,7 @@ self
|
|||
introduction/index
|
||||
getting_started/index
|
||||
concepts/index
|
||||
providers/index
|
||||
distributions/index
|
||||
distributions/selection
|
||||
building_applications/index
|
||||
|
|
59
docs/source/providers/index.md
Normal file
59
docs/source/providers/index.md
Normal file
|
@ -0,0 +1,59 @@
|
|||
# Providers Overview
|
||||
|
||||
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
|
||||
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
|
||||
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
|
||||
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
|
||||
|
||||
Providers come in two flavors:
|
||||
- **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
|
||||
- **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.
|
||||
|
||||
Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
|
||||
|
||||
## Agents
|
||||
Run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc.
|
||||
|
||||
## DatasetIO
|
||||
Interfaces with datasets and data loaders.
|
||||
|
||||
## Eval
|
||||
Generates outputs (via Inference or Agents) and perform scoring.
|
||||
|
||||
## Inference
|
||||
Runs inference with an LLM.
|
||||
|
||||
## Post Training
|
||||
Fine-tunes a model.
|
||||
|
||||
## Safety
|
||||
Applies safety policies to the output at a Systems (not only model) level.
|
||||
|
||||
## Scoring
|
||||
Evaluates the outputs of the system.
|
||||
|
||||
## Telemetry
|
||||
Collects telemetry data from the system.
|
||||
|
||||
## Tool Runtime
|
||||
Is associated with the ToolGroup resouces.
|
||||
|
||||
## Vector IO
|
||||
|
||||
Vector IO refers to operations on vector databases, such as adding documents, searching, and deleting documents.
|
||||
Vector IO plays a crucial role in [Retreival Augmented Generation (RAG)](../..//building_applications/rag), where the vector
|
||||
io and database are used to store and retrieve documents for retrieval.
|
||||
|
||||
#### Vector IO Providers
|
||||
The following providers (i.e., databases) are available for Vector IO:
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
vector_io/faiss
|
||||
vector_io/sqlite-vec
|
||||
vector_io/chromadb
|
||||
vector_io/pgvector
|
||||
vector_io/qdrant
|
||||
vector_io/weaviate
|
||||
```
|
36
docs/source/providers/vector_io/chromadb.md
Normal file
36
docs/source/providers/vector_io/chromadb.md
Normal file
|
@ -0,0 +1,36 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# Chroma
|
||||
|
||||
[Chroma](https://www.trychroma.com/) is an inline and remote vector
|
||||
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
||||
## Features
|
||||
Chroma supports:
|
||||
- Store embeddings and their metadata
|
||||
- Vector search
|
||||
- Full-text search
|
||||
- Document storage
|
||||
- Metadata filtering
|
||||
- Multi-modal retrieval
|
||||
|
||||
## Usage
|
||||
|
||||
To use Chrome in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
|
||||
2. Configure your Llama Stack project to use chroma.
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
||||
You can install chroma using pip:
|
||||
|
||||
```bash
|
||||
pip install chromadb
|
||||
```
|
||||
|
||||
## Documentation
|
||||
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
|
33
docs/source/providers/vector_io/faiss.md
Normal file
33
docs/source/providers/vector_io/faiss.md
Normal file
|
@ -0,0 +1,33 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# Faiss
|
||||
|
||||
[Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly in memory.
|
||||
That means you'll get fast and efficient vector retrieval.
|
||||
|
||||
## Features
|
||||
|
||||
- Lightweight and easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
- GPU support
|
||||
|
||||
## Usage
|
||||
|
||||
To use Faiss in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
|
||||
2. Configure your Llama Stack project to use Faiss.
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
||||
You can install Faiss using pip:
|
||||
|
||||
```bash
|
||||
pip install faiss-cpu
|
||||
```
|
||||
## Documentation
|
||||
See [Faiss' documentation](https://faiss.ai/) or the [Faiss Wiki](https://github.com/facebookresearch/faiss/wiki) for
|
||||
more details about Faiss in general.
|
31
docs/source/providers/vector_io/pgvector.md
Normal file
31
docs/source/providers/vector_io/pgvector.md
Normal file
|
@ -0,0 +1,31 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# Postgres PGVector
|
||||
|
||||
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly in memory.
|
||||
That means you'll get fast and efficient vector retrieval.
|
||||
|
||||
## Features
|
||||
|
||||
- Easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
|
||||
## Usage
|
||||
|
||||
To use PGVector in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
|
||||
2. Configure your Llama Stack project to use Faiss.
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
||||
You can install PGVector using docker:
|
||||
|
||||
```bash
|
||||
docker pull pgvector/pgvector:pg17
|
||||
```
|
||||
## Documentation
|
||||
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
|
31
docs/source/providers/vector_io/qdrant.md
Normal file
31
docs/source/providers/vector_io/qdrant.md
Normal file
|
@ -0,0 +1,31 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# Qdrant
|
||||
|
||||
[Qdrant](https://qdrant.tech/documentation/) is a remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly in memory.
|
||||
That means you'll get fast and efficient vector retrieval.
|
||||
|
||||
## Features
|
||||
|
||||
- Easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
|
||||
## Usage
|
||||
|
||||
To use Qdrant in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
|
||||
2. Configure your Llama Stack project to use Faiss.
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
||||
You can install Qdrant using docker:
|
||||
|
||||
```bash
|
||||
docker pull qdrant/qdrant
|
||||
```
|
||||
## Documentation
|
||||
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
|
33
docs/source/providers/vector_io/sqlite-vec.md
Normal file
33
docs/source/providers/vector_io/sqlite-vec.md
Normal file
|
@ -0,0 +1,33 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# SQLite-Vec
|
||||
|
||||
[SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly within an SQLite database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
||||
## Features
|
||||
|
||||
- Lightweight and easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
|
||||
## Usage
|
||||
|
||||
To use SQLite-Vec in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
|
||||
2. Configure your Llama Stack project to use SQLite-Vec.
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
||||
You can install SQLite-Vec using pip:
|
||||
|
||||
```bash
|
||||
pip install sqlite-vec
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.
|
33
docs/source/providers/vector_io/weaviate.md
Normal file
33
docs/source/providers/vector_io/weaviate.md
Normal file
|
@ -0,0 +1,33 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# Weaviate
|
||||
|
||||
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
|
||||
It allows you to store and query vectors directly within a Weaviate database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
||||
## Features
|
||||
Weaviate supports:
|
||||
- Store embeddings and their metadata
|
||||
- Vector search
|
||||
- Full-text search
|
||||
- Hybrid search
|
||||
- Document storage
|
||||
- Metadata filtering
|
||||
- Multi-modal retrieval
|
||||
|
||||
## Usage
|
||||
|
||||
To use Weaviate in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
|
||||
2. Configure your Llama Stack project to use chroma.
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
||||
To install Weaviate see the [Weaviate quickstart documentation](https://weaviate.io/developers/weaviate/quickstart).
|
||||
|
||||
## Documentation
|
||||
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
|
|
@ -171,7 +171,7 @@ The `llama model` command helps you explore the model’s interface.
|
|||
llama model --help
|
||||
```
|
||||
```
|
||||
usage: llama model [-h] {download,list,prompt-format,describe} ...
|
||||
usage: llama model [-h] {download,list,prompt-format,describe,verify-download,remove} ...
|
||||
|
||||
Work with llama models
|
||||
|
||||
|
@ -179,15 +179,15 @@ options:
|
|||
-h, --help show this help message and exit
|
||||
|
||||
model_subcommands:
|
||||
{download,list,prompt-format,describe}
|
||||
{download,list,prompt-format,describe,verify-download,remove}
|
||||
```
|
||||
|
||||
### Describe
|
||||
|
||||
You can use the describe command to know more about a model:
|
||||
```
|
||||
llama model describe -m Llama3.2-3B-Instruct
|
||||
```
|
||||
### Describe
|
||||
|
||||
```
|
||||
+-----------------------------+----------------------------------+
|
||||
| Model | Llama3.2-3B-Instruct |
|
||||
|
@ -234,3 +234,10 @@ llama model prompt-format -m Llama3.2-3B-Instruct
|
|||
You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
|
||||
|
||||
**NOTE**: Outputs in terminal are color printed to show special tokens.
|
||||
|
||||
### Remove model
|
||||
You can run `llama model remove` to remove unecessary model:
|
||||
|
||||
```
|
||||
llama model remove -m Llama-Guard-3-8B-int8
|
||||
```
|
||||
|
|
|
@ -194,6 +194,7 @@ class AgentTurnResponseEventType(Enum):
|
|||
|
||||
turn_start = "turn_start"
|
||||
turn_complete = "turn_complete"
|
||||
turn_awaiting_input = "turn_awaiting_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -235,6 +236,14 @@ class AgentTurnResponseTurnCompletePayload(BaseModel):
|
|||
turn: Turn
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
|
||||
AgentTurnResponseEventType.turn_awaiting_input.value
|
||||
)
|
||||
turn: Turn
|
||||
|
||||
|
||||
AgentTurnResponseEventPayload = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
|
@ -243,6 +252,7 @@ AgentTurnResponseEventPayload = register_schema(
|
|||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
],
|
||||
Field(discriminator="event_type"),
|
||||
],
|
||||
|
@ -286,6 +296,18 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
stream: Optional[bool] = False
|
||||
tool_config: Optional[ToolConfig] = None
|
||||
|
||||
# TODO (xiyan): temporary flag, will remove for 0.1.5
|
||||
allow_turn_resume: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResumeRequest(BaseModel):
|
||||
agent_id: str
|
||||
session_id: str
|
||||
turn_id: str
|
||||
tool_responses: List[ToolResponseMessage]
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStreamChunk(BaseModel):
|
||||
|
@ -333,8 +355,34 @@ class Agents(Protocol):
|
|||
documents: Optional[List[Document]] = None,
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
allow_turn_resume: Optional[bool] = False,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
)
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: List[ToolResponseMessage],
|
||||
stream: Optional[bool] = False,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||
"""Resume an agent turn with executed tool call responses.
|
||||
|
||||
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
|
||||
|
||||
:param agent_id: The ID of the agent to resume.
|
||||
:param session_id: The ID of the session to resume.
|
||||
:param turn_id: The ID of the turn to resume.
|
||||
:param tool_responses: The tool call responses to resume the turn with.
|
||||
:param stream: Whether to stream the response.
|
||||
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||
method="GET",
|
||||
|
|
|
@ -91,15 +91,18 @@ ParamType = register_schema(
|
|||
name="ParamType",
|
||||
)
|
||||
|
||||
"""
|
||||
# TODO: recursive definition of ParamType in these containers
|
||||
# will cause infinite recursion in OpenAPI generation script
|
||||
# since we are going with ChatCompletionInputType and CompletionInputType
|
||||
# we don't need to worry about ArrayType/ObjectType/UnionType for now
|
||||
# ArrayType.model_rebuild()
|
||||
# ObjectType.model_rebuild()
|
||||
# UnionType.model_rebuild()
|
||||
ArrayType.model_rebuild()
|
||||
ObjectType.model_rebuild()
|
||||
UnionType.model_rebuild()
|
||||
|
||||
|
||||
# class CustomType(BaseModel):
|
||||
# type: Literal["custom"] = "custom"
|
||||
# validator_class: str
|
||||
class CustomType(BaseModel):
|
||||
pylint: disable=syntax-error
|
||||
type: Literal["custom"] = "custom"
|
||||
validator_class: str
|
||||
"""
|
||||
|
|
|
@ -20,7 +20,7 @@ from typing import (
|
|||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
|
@ -165,6 +165,7 @@ class ToolResponse(BaseModel):
|
|||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
content: InterleavedContent
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
@ -402,6 +403,30 @@ class ModelStore(Protocol):
|
|||
def get_model(self, identifier: str) -> Model: ...
|
||||
|
||||
|
||||
class TextTruncation(Enum):
|
||||
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
|
||||
|
||||
:cvar none: No truncation (default). If the text is longer than the model's max sequence length, you will get an error.
|
||||
:cvar start: Truncate from the start
|
||||
:cvar end: Truncate from the end
|
||||
"""
|
||||
|
||||
none = "none"
|
||||
start = "start"
|
||||
end = "end"
|
||||
|
||||
|
||||
class EmbeddingTaskType(Enum):
|
||||
"""How is the embedding being used? This is only supported by asymmetric embedding models.
|
||||
|
||||
:cvar query: Used for a query for semantic search.
|
||||
:cvar document: Used at indexing time when ingesting documents.
|
||||
"""
|
||||
|
||||
query = "query"
|
||||
document = "document"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Inference(Protocol):
|
||||
|
@ -481,12 +506,18 @@ class Inference(Protocol):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
"""Generate embeddings for content pieces using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
||||
:param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
|
||||
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
|
||||
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
|
||||
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
|
||||
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
|
||||
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -26,6 +26,7 @@ class RAGDocument(BaseModel):
|
|||
@json_schema_type
|
||||
class RAGQueryResult(BaseModel):
|
||||
content: Optional[InterleavedContent] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -72,6 +72,7 @@ class ToolInvocationResult(BaseModel):
|
|||
content: InterleavedContent
|
||||
error_message: Optional[str] = None
|
||||
error_code: Optional[int] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ToolStore(Protocol):
|
||||
|
|
|
@ -19,6 +19,13 @@ def _get_model_size(model_dir):
|
|||
return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())
|
||||
|
||||
|
||||
def _convert_to_model_descriptor(model):
|
||||
for m in all_registered_models():
|
||||
if model == m.descriptor().replace(":", "-"):
|
||||
return str(m.descriptor())
|
||||
return str(model)
|
||||
|
||||
|
||||
def _run_model_list_downloaded_cmd() -> None:
|
||||
headers = ["Model", "Size", "Modified Time"]
|
||||
|
||||
|
@ -30,7 +37,7 @@ def _run_model_list_downloaded_cmd() -> None:
|
|||
modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path)))
|
||||
rows.append(
|
||||
[
|
||||
model,
|
||||
_convert_to_model_descriptor(model),
|
||||
model_size,
|
||||
modified_time,
|
||||
]
|
||||
|
@ -68,6 +75,13 @@ class ModelList(Subcommand):
|
|||
action="store_true",
|
||||
help="List the downloaded models",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-s",
|
||||
"--search",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Search for the input string as a substring in the model descriptor(ID)",
|
||||
)
|
||||
|
||||
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_sku
|
||||
|
@ -87,15 +101,19 @@ class ModelList(Subcommand):
|
|||
continue
|
||||
|
||||
descriptor = model.descriptor()
|
||||
rows.append(
|
||||
[
|
||||
descriptor,
|
||||
model.huggingface_repo,
|
||||
f"{model.max_seq_length // 1024}K",
|
||||
]
|
||||
if not args.search or args.search.lower() in descriptor.lower():
|
||||
rows.append(
|
||||
[
|
||||
descriptor,
|
||||
model.huggingface_repo,
|
||||
f"{model.max_seq_length // 1024}K",
|
||||
]
|
||||
)
|
||||
if len(rows) == 0:
|
||||
print(f"Did not find any model matching `{args.search}`.")
|
||||
else:
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
|
|
|
@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe
|
|||
from llama_stack.cli.model.download import ModelDownload
|
||||
from llama_stack.cli.model.list import ModelList
|
||||
from llama_stack.cli.model.prompt_format import ModelPromptFormat
|
||||
from llama_stack.cli.model.remove import ModelRemove
|
||||
from llama_stack.cli.model.verify_download import ModelVerifyDownload
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
@ -35,3 +36,4 @@ class ModelParser(Subcommand):
|
|||
ModelPromptFormat.create(subparsers)
|
||||
ModelDescribe.create(subparsers)
|
||||
ModelVerifyDownload.create(subparsers)
|
||||
ModelRemove.create(subparsers)
|
||||
|
|
67
llama_stack/cli/model/remove.py
Normal file
67
llama_stack/cli/model/remove.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
|
||||
class ModelRemove(Subcommand):
|
||||
"""Remove the downloaded llama model"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"remove",
|
||||
prog="llama model remove",
|
||||
description="Remove the downloaded llama model",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_model_remove_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
required=True,
|
||||
help="Specify the llama downloaded model name, see `llama model list --downloaded`",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-f",
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Used to forcefully remove the llama model from the storage without further confirmation",
|
||||
)
|
||||
|
||||
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_sku
|
||||
|
||||
prompt_guard = prompt_guard_model_sku()
|
||||
if args.model == prompt_guard.model_id:
|
||||
model = prompt_guard
|
||||
else:
|
||||
model = resolve_model(args.model)
|
||||
|
||||
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-"))
|
||||
|
||||
if model is None or not os.path.isdir(model_path):
|
||||
print(f"'{args.model}' is not a valid llama model or does not exist.")
|
||||
return
|
||||
|
||||
if args.force:
|
||||
shutil.rmtree(model_path)
|
||||
print(f"{args.model} removed.")
|
||||
else:
|
||||
if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y":
|
||||
shutil.rmtree(model_path)
|
||||
print(f"{args.model} removed.")
|
||||
else:
|
||||
print("Removal aborted.")
|
|
@ -9,6 +9,7 @@ import importlib.resources
|
|||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import textwrap
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
@ -23,10 +24,10 @@ from termcolor import cprint
|
|||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.distribution.build import (
|
||||
SERVER_DEPENDENCIES,
|
||||
ImageType,
|
||||
build_image,
|
||||
get_provider_dependencies,
|
||||
)
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildConfig,
|
||||
DistributionSpec,
|
||||
|
@ -37,6 +38,8 @@ from llama_stack.distribution.distribution import get_provider_registry
|
|||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, in_notebook, run_with_pty
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
@ -59,8 +62,16 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
    if args.list_templates:
        return _run_template_list_cmd()

    current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
    image_name = args.image_name or current_conda_env
    if args.image_type == "venv":
        current_venv = os.environ.get("VIRTUAL_ENV")
        image_name = args.image_name or current_venv
        if not image_name and in_notebook():
            image_name = "__system__"
    elif args.image_type == "conda":
        current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
        image_name = args.image_name or current_conda_env
    else:
        image_name = args.image_name

    if args.template:
        available_templates = available_templates_specs()
@ -69,7 +80,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
                color="red",
            )
            return
            sys.exit(1)
        build_config = available_templates[args.template]
        if args.image_type:
            build_config.image_type = args.image_type

@ -78,7 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                f"Please specify a image-type (container | conda | venv) for {args.template}",
                color="red",
            )
            return
            sys.exit(1)
    elif not args.config and not args.template:
        name = prompt(
            "> Enter a name for your Llama Stack (e.g. my-local-stack): ",

@ -159,14 +170,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                f"Could not parse config file {args.config}: {e}",
                color="red",
            )
            return
            sys.exit(1)

        if build_config.image_type == ImageType.container.value and not args.image_name:
            cprint(
                "Please specify --image-name when building a container from a config file",
                color="red",
            )
            return
            sys.exit(1)

    if args.print_deps_only:
        print(f"# Dependencies for {args.template or args.config or image_name}")
@ -177,19 +188,41 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
            print(f"uv pip install {special_dep}")
        return

    _run_stack_build_command_from_build_config(
        build_config,
        image_name=image_name,
        config_path=args.config,
        template_name=args.template,
    )
    try:
        run_config = _run_stack_build_command_from_build_config(
            build_config,
            image_name=image_name,
            config_path=args.config,
            template_name=args.template,
        )

    except (Exception, RuntimeError) as exc:
        cprint(
            f"Error building stack: {exc}",
            color="red",
        )
        sys.exit(1)
    if run_config is None:
        cprint(
            "Run config path is empty",
            color="red",
        )
        sys.exit(1)

    if args.run:
        run_config = Path(run_config)
        config_dict = yaml.safe_load(run_config.read_text())
        config = parse_and_maybe_upgrade_config(config_dict)
        run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
        run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
        run_with_pty(run_args)
def _generate_run_config(
    build_config: BuildConfig,
    build_dir: Path,
    image_name: str,
) -> None:
) -> str:
    """
    Generate a run.yaml template file for user to edit from a build.yaml file
    """

@ -239,6 +272,7 @@ def _generate_run_config(
        f"You can now run your stack with `llama stack run {run_config_file}`",
        color="green",
    )
    return run_config_file
def _run_stack_build_command_from_build_config(

@ -246,7 +280,7 @@ def _run_stack_build_command_from_build_config(
    image_name: Optional[str] = None,
    template_name: Optional[str] = None,
    config_path: Optional[str] = None,
) -> None:
) -> str:
    if build_config.image_type == ImageType.container.value:
        if template_name:
            image_name = f"distribution-{template_name}"

@ -256,6 +290,9 @@ def _run_stack_build_command_from_build_config(
    elif build_config.image_type == ImageType.conda.value:
        if not image_name:
            raise ValueError("Please specify an image name when building a conda image")
    elif build_config.image_type == ImageType.venv.value:
        if not image_name:
            raise ValueError("Please specify an image name when building a venv image")

    if template_name:
        build_dir = DISTRIBS_BASE_DIR / template_name

@ -276,7 +313,7 @@ def _run_stack_build_command_from_build_config(
        template_or_config=template_name or config_path,
    )
    if return_code != 0:
        return
        raise RuntimeError(f"Failed to build image {image_name}")

    if template_name:
        # copy run.yaml from template to build_dir instead of generating it again

@ -286,8 +323,9 @@ def _run_stack_build_command_from_build_config(
        shutil.copy(path, run_config_file)

        cprint("Build Successful!", color="green")
        return template_path
    else:
        _generate_run_config(build_config, build_dir, image_name)
        return _generate_run_config(build_config, build_dir, image_name)


def _run_template_list_cmd() -> None:
@ -68,6 +68,13 @@ the build. If not specified, currently active Conda environment will be used if
            help="Print the dependencies for the stack only, without building the stack",
        )

        self.parser.add_argument(
            "--run",
            action="store_true",
            default=False,
            help="Run the stack after building using the same image type, name, and other applicable arguments",
        )

    def _run_stack_build_command(self, args: argparse.Namespace) -> None:
        # always keep implementation completely silo-ed away from CLI so CLI
        # can be fast to load and reduces dependencies
@ -1,46 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_stack.cli.subcommand import Subcommand


class StackConfigure(Subcommand):
    """Llama cli for configuring llama toolchain configs"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "configure",
            prog="llama stack configure",
            description="Configure a llama stack distribution",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_stack_configure_cmd)

    def _add_arguments(self):
        self.parser.add_argument(
            "config",
            type=str,
            help="Path to the build config file (e.g. ~/.llama/builds/<image_type>/<name>-build.yaml). For container, this could also be the name of the container image. ",
        )

        self.parser.add_argument(
            "--output-dir",
            type=str,
            help="Path to the output directory to store generated run.yaml config file. If not specified, will use ~/.llama/build/<image_type>/<name>-run.yaml",
        )

    def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
        self.parser.error(
            """
            DEPRECATED! llama stack configure has been deprecated.
            Please use llama stack run <path/to/run.yaml> instead.
            Please see example run.yaml in /distributions folder.
            """
        )
@ -74,10 +74,6 @@ class StackRun(Subcommand):
        )

    def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
        import importlib.resources
        import json
        import subprocess

        import yaml
        from termcolor import cprint

@ -87,7 +83,7 @@ class StackRun(Subcommand):
            BUILDS_BASE_DIR,
            DISTRIBS_BASE_DIR,
        )
        from llama_stack.distribution.utils.exec import run_with_pty
        from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty

        if not args.config:
            self.parser.error("Must specify a config file to run")
@ -125,64 +121,7 @@ class StackRun(Subcommand):
|
|||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
|
||||
if args.image_type == ImageType.container.value or config.container_image:
|
||||
script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
|
||||
image_name = f"distribution-{template_name}" if template_name else config.container_image
|
||||
run_args = [script, image_name]
|
||||
elif args.image_type == ImageType.conda.value:
|
||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
image_name = args.image_name or current_conda_env
|
||||
if not image_name:
|
||||
cprint(
|
||||
"No current conda environment detected, please specify a conda environment name with --image-name",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
|
||||
def get_conda_prefix(env_name):
|
||||
# Conda "base" environment does not end with "base" in the
|
||||
# prefix, so should be handled separately.
|
||||
if env_name == "base":
|
||||
return os.environ.get("CONDA_PREFIX")
|
||||
# Get conda environments info
|
||||
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
|
||||
envs = conda_env_info["envs"]
|
||||
for envpath in envs:
|
||||
if envpath.endswith(env_name):
|
||||
return envpath
|
||||
return None
|
||||
|
||||
print(f"Using conda environment: {image_name}")
|
||||
conda_prefix = get_conda_prefix(image_name)
|
||||
if not conda_prefix:
|
||||
cprint(
|
||||
f"Conda environment {image_name} does not exist.",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
|
||||
build_file = Path(conda_prefix) / "llamastack-build.yaml"
|
||||
if not build_file.exists():
|
||||
cprint(
|
||||
f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
|
||||
script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
|
||||
run_args = [
|
||||
script,
|
||||
image_name,
|
||||
]
|
||||
else:
|
||||
# else must be venv since that is the only valid option left.
|
||||
current_venv = os.environ.get("VIRTUAL_ENV")
|
||||
venv = args.image_name or current_venv
|
||||
script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
|
||||
run_args = [
|
||||
script,
|
||||
venv,
|
||||
]
|
||||
run_args = formulate_run_args(args.image_type, args.image_name, config, template_name)
|
||||
|
||||
run_args.extend([str(config_file), str(args.port)])
|
||||
if args.disable_ipv6:
|
||||
|
@ -206,5 +145,4 @@ class StackRun(Subcommand):
|
|||
|
||||
if args.tls_keyfile and args.tls_certfile:
|
||||
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
|
||||
|
||||
run_with_pty(run_args)
|
||||
|
|
|
@ -10,7 +10,6 @@ from importlib.metadata import version
from llama_stack.cli.subcommand import Subcommand

from .build import StackBuild
from .configure import StackConfigure
from .list_apis import StackListApis
from .list_providers import StackListProviders
from .run import StackRun

@ -37,7 +36,6 @@ class StackParser(Subcommand):

        # Add sub-commands
        StackBuild.create(subparsers)
        StackConfigure.create(subparsers)
        StackListApis.create(subparsers)
        StackListProviders.create(subparsers)
        StackRun.create(subparsers)
@ -7,7 +7,6 @@
import importlib.resources
import logging
import sys
from enum import Enum
from pathlib import Path
from typing import Dict, List

@ -18,6 +17,7 @@ from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api

log = logging.getLogger(__name__)

@ -33,12 +33,6 @@ SERVER_DEPENDENCIES = [
]


class ImageType(Enum):
    container = "container"
    conda = "conda"
    venv = "venv"


class ApiInput(BaseModel):
    api: Api
    provider: str
@ -8,6 +8,8 @@
|
|||
|
||||
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
PYPI_VERSION=${PYPI_VERSION:-}
|
||||
BUILD_PLATFORM=${BUILD_PLATFORM:-}
|
||||
|
@ -32,7 +34,7 @@ container_base="$3"
|
|||
build_file_path="$4"
|
||||
host_build_dir="$5"
|
||||
pip_dependencies="$6"
|
||||
special_pip_deps="$7"
|
||||
special_pip_deps="${7:-}"
|
||||
|
||||
|
||||
# Define color codes
|
||||
|
@ -106,26 +108,39 @@ fi
|
|||
|
||||
stack_mount="/app/llama-stack-source"
|
||||
models_mount="/app/llama-models-source"
|
||||
client_mount="/app/llama-stack-client-source"
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||
echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2
|
||||
install_local_package() {
|
||||
local dir="$1"
|
||||
local mount_point="$2"
|
||||
local name="$3"
|
||||
|
||||
if [ ! -d "$dir" ]; then
|
||||
echo "${RED}Warning: $name is set but directory does not exist: $dir${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install in editable format. We will mount the source code into the container
|
||||
# so that changes will be reflected in the container without having to do a
|
||||
# rebuild. This is just for development convenience.
|
||||
|
||||
if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
|
||||
add_to_container << EOF
|
||||
COPY $LLAMA_STACK_DIR $stack_mount
|
||||
COPY $dir $mount_point
|
||||
EOF
|
||||
fi
|
||||
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache -e $stack_mount
|
||||
RUN uv pip install --no-cache -e $mount_point
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR"
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
install_local_package "$LLAMA_STACK_DIR" "$stack_mount" "LLAMA_STACK_DIR"
|
||||
else
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
|
@ -134,6 +149,7 @@ RUN uv pip install fastapi libcst
|
|||
EOF
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
|
||||
--index-strategy unsafe-best-match \
|
||||
llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
|
||||
|
||||
EOF
|
||||
|
@ -149,23 +165,6 @@ EOF
|
|||
fi
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
|
||||
echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
|
||||
add_to_container << EOF
|
||||
COPY $LLAMA_MODELS_DIR $models_mount
|
||||
EOF
|
||||
fi
|
||||
add_to_container << EOF
|
||||
RUN uv pip uninstall llama-models
|
||||
RUN uv pip install --no-cache $models_mount
|
||||
EOF
|
||||
fi
|
||||
|
||||
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
|
||||
if [[ "$template_or_config" != *.yaml ]]; then
|
||||
add_to_container << EOF
|
||||
|
@ -177,6 +176,15 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
|
|||
EOF
|
||||
fi
|
||||
|
||||
# Add other require item commands genearic to all containers
|
||||
add_to_container << EOF
|
||||
|
||||
# Allows running as non-root user
|
||||
RUN mkdir -p /.llama /.cache
|
||||
|
||||
RUN chmod -R g+rw /app /.llama /.cache
|
||||
EOF
|
||||
|
||||
printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n"
|
||||
cat $TEMP_DIR/Containerfile
|
||||
printf "\n"
|
||||
|
@ -189,6 +197,9 @@ if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
|
|||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
|
||||
fi
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
|
||||
fi
|
||||
fi
|
||||
|
||||
if command -v selinuxenabled &>/dev/null && selinuxenabled; then
|
||||
|
|
|
@ -16,6 +16,7 @@ TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
|||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
||||
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
|
||||
VIRTUAL_ENV=${VIRTUAL_ENV:-}
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
|
@ -24,8 +25,8 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
|
|||
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
|
||||
fi
|
||||
|
||||
if [ "$#" -lt 3 ]; then
|
||||
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
if [ "$#" -lt 2 ]; then
|
||||
echo "Usage: $0 <distribution_type> <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
@ -34,8 +35,7 @@ special_pip_deps="$3"
|
|||
|
||||
set -euo pipefail
|
||||
|
||||
build_name="$1"
|
||||
env_name="llamastack-$build_name"
|
||||
env_name="$1"
|
||||
pip_dependencies="$2"
|
||||
|
||||
# Define color codes
|
||||
|
@ -74,9 +74,13 @@ run() {
|
|||
local env_name="$1"
|
||||
local pip_dependencies="$2"
|
||||
local special_pip_deps="$3"
|
||||
|
||||
if [ -n "$UV_SYSTEM_PYTHON" ]; then
|
||||
|
||||
if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
|
||||
echo "Installing dependencies in system Python environment"
|
||||
# if env == __system__, ensure we set UV_SYSTEM_PYTHON
|
||||
export UV_SYSTEM_PYTHON=1
|
||||
elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
|
||||
echo "Virtual environment $env_name is already active"
|
||||
else
|
||||
echo "Using virtual environment $env_name"
|
||||
uv venv "$env_name"
|
||||
|
@ -90,6 +94,7 @@ run() {
|
|||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
--index-strategy unsafe-best-match \
|
||||
llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
|
||||
$pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
|
|
|
@ -41,6 +41,7 @@ from llama_stack.distribution.stack import (
|
|||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.distribution.utils.exec import in_notebook
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
|
@ -52,19 +53,6 @@ logger = logging.getLogger(__name__)
|
|||
T = TypeVar("T")
|
||||
|
||||
|
||||
def in_notebook():
|
||||
try:
|
||||
from IPython import get_ipython
|
||||
|
||||
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
|
||||
return False
|
||||
except ImportError:
|
||||
return False
|
||||
except AttributeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def convert_pydantic_to_json_value(value: Any) -> Any:
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
|
@ -230,12 +218,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
if Api.telemetry in self.impls:
|
||||
setup_logger(self.impls[Api.telemetry])
|
||||
|
||||
console = Console()
|
||||
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
|
||||
|
||||
# Redact sensitive information before printing
|
||||
safe_config = redact_sensitive_fields(self.config.model_dump())
|
||||
console.print(yaml.dump(safe_config, indent=2))
|
||||
if not os.environ.get("PYTEST_CURRENT_TEST"):
|
||||
console = Console()
|
||||
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
|
||||
safe_config = redact_sensitive_fields(self.config.model_dump())
|
||||
console.print(yaml.dump(safe_config, indent=2))
|
||||
|
||||
endpoints = get_all_api_endpoints()
|
||||
endpoint_impls = {}
|
||||
|
|
|
@ -6,7 +6,11 @@
|
|||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.eval import (
|
||||
BenchmarkConfig,
|
||||
|
@ -17,11 +21,13 @@ from llama_stack.apis.eval import (
|
|||
)
|
||||
from llama_stack.apis.inference import (
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -214,7 +220,10 @@ class InferenceRouter(Inference):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
|
@ -224,6 +233,9 @@ class InferenceRouter(Inference):
|
|||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
output_dimension=output_dimension,
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ while [[ $# -gt 0 ]]; do
|
|||
esac
|
||||
done
|
||||
|
||||
echo "Using virtual environment: $venv_path"
|
||||
# Activate virtual environment
|
||||
if [ ! -d "$venv_path" ]; then
|
||||
echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
|
||||
|
|
|
@ -12,8 +12,78 @@ import signal
|
|||
import subprocess
|
||||
import sys
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
import importlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
|
||||
|
||||
def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||
if image_type == ImageType.container.value or config.container_image:
|
||||
script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
|
||||
image_name = f"distribution-{template_name}" if template_name else config.container_image
|
||||
run_args = [script, image_name]
|
||||
elif image_type == ImageType.conda.value:
|
||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
image_name = image_name or current_conda_env
|
||||
if not image_name:
|
||||
cprint(
|
||||
"No current conda environment detected, please specify a conda environment name with --image-name",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
|
||||
def get_conda_prefix(env_name):
|
||||
# Conda "base" environment does not end with "base" in the
|
||||
# prefix, so should be handled separately.
|
||||
if env_name == "base":
|
||||
return os.environ.get("CONDA_PREFIX")
|
||||
# Get conda environments info
|
||||
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
|
||||
envs = conda_env_info["envs"]
|
||||
for envpath in envs:
|
||||
if envpath.endswith(env_name):
|
||||
return envpath
|
||||
return None
|
||||
|
||||
print(f"Using conda environment: {image_name}")
|
||||
conda_prefix = get_conda_prefix(image_name)
|
||||
if not conda_prefix:
|
||||
cprint(
|
||||
f"Conda environment {image_name} does not exist.",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
|
||||
build_file = Path(conda_prefix) / "llamastack-build.yaml"
|
||||
if not build_file.exists():
|
||||
cprint(
|
||||
f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
|
||||
script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
|
||||
run_args = [
|
||||
script,
|
||||
image_name,
|
||||
]
|
||||
else:
|
||||
# else must be venv since that is the only valid option left.
|
||||
current_venv = os.environ.get("VIRTUAL_ENV")
|
||||
venv = image_name or current_venv
|
||||
script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
|
||||
run_args = [
|
||||
script,
|
||||
venv,
|
||||
]
|
||||
return run_args
|
||||
|
||||
|
||||
def run_with_pty(command):
|
||||
if sys.platform.startswith("win"):
|
||||
|
@ -22,6 +92,19 @@ def run_with_pty(command):
|
|||
return _run_with_pty_unix(command)
|
||||
|
||||
|
||||
def in_notebook():
|
||||
try:
|
||||
from IPython import get_ipython
|
||||
|
||||
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
|
||||
return False
|
||||
except ImportError:
|
||||
return False
|
||||
except AttributeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# run a command in a pseudo-terminal, with interrupt handling,
|
||||
# useful when you want to run interactive things
|
||||
def _run_with_pty_unix(command):
|
||||
|
|
13
llama_stack/distribution/utils/image_types.py
Normal file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum


class ImageType(Enum):
    container = "container"
    conda = "conda"
    venv = "venv"
@ -30,8 +30,10 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStreamChunk,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResumeRequest,
|
||||
Attachment,
|
||||
Document,
|
||||
InferenceStep,
|
||||
|
@ -60,9 +62,13 @@ from llama_stack.apis.inference import (
|
|||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolCall,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_session(self, name: str) -> str:
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
||||
for turn in turns:
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
|
@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
||||
for i, turn in enumerate(turns):
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
messages.extend(request.messages)
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
|
@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
if output_message.tool_calls and request.allow_turn_resume:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
with tracing.span("resume_turn") as span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
messages.extend(request.tool_responses)
|
||||
|
||||
last_turn_messages = [
|
||||
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||
]
|
||||
|
||||
# get the steps from the turn id
|
||||
steps = []
|
||||
if len(turns) > 0:
|
||||
steps = turns[-1].steps
|
||||
|
||||
# mark tool execution step as complete
|
||||
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||
# we'll create a new tool execution step with current time
|
||||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||
request.session_id, request.turn_id
|
||||
)
|
||||
now = datetime.now()
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=x.call_id,
|
||||
tool_name=x.tool_name,
|
||||
content=x.content,
|
||||
)
|
||||
for x in request.tool_responses
|
||||
],
|
||||
completed_at=now,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||
)
|
||||
steps.append(tool_execution_step)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=tool_execution_step.step_id,
|
||||
step_details=tool_execution_step,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=request.turn_id,
|
||||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||
steps.append(event.payload.step_details)
|
||||
|
||||
yield chunk
|
||||
|
||||
assert output_message is not None
|
||||
|
||||
last_turn_start_time = datetime.now()
|
||||
if len(turns) > 0:
|
||||
last_turn_start_time = turns[-1].started_at
|
||||
|
||||
turn = Turn(
|
||||
turn_id=request.turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=last_turn_messages,
|
||||
output_message=output_message,
|
||||
started_at=last_turn_start_time,
|
||||
completed_at=datetime.now(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
if output_message.tool_calls:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
async def run(
|
||||
|
@ -456,6 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
call_id="",
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
content=retrieved_context or [],
|
||||
metadata=result.metadata,
|
||||
)
|
||||
],
|
||||
),
|
||||
|
@ -611,11 +743,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages = input_messages + [message]
|
||||
else:
|
||||
log.info(f"{str(message)}")
|
||||
tool_call = message.tool_calls[0]
|
||||
if tool_call.tool_name in client_tools:
|
||||
yield message
|
||||
return
|
||||
|
||||
# 1. Start the tool execution step and progress
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -625,6 +753,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
tool_call = message.tool_calls[0]
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
|
@ -639,6 +768,23 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
# If tool is a client tool, yield CompletionMessage and return
|
||||
if tool_call.tool_name in client_tools:
|
||||
await self.storage.set_in_progress_tool_call_step(
|
||||
session_id,
|
||||
turn_id,
|
||||
ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[],
|
||||
started_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
yield message
|
||||
return
|
||||
|
||||
# If tool is a builtin server tool, execute it
|
||||
tool_name = tool_call.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
|
@ -650,13 +796,21 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now()
|
||||
result_messages = await execute_tool_call_maybe(
|
||||
tool_call = message.tool_calls[0]
|
||||
tool_result = await execute_tool_call_maybe(
|
||||
self.tool_runtime_api,
|
||||
session_id,
|
||||
[message],
|
||||
tool_call,
|
||||
toolgroup_args,
|
||||
tool_to_group,
|
||||
)
|
||||
result_messages = [
|
||||
ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=tool_result.content,
|
||||
)
|
||||
]
|
||||
assert len(result_messages) == 1, "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
@ -675,6 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
call_id=result_message.call_id,
|
||||
tool_name=result_message.tool_name,
|
||||
content=result_message.content,
|
||||
metadata=tool_result.metadata,
|
||||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
|
@ -913,19 +1068,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
async def execute_tool_call_maybe(
|
||||
tool_runtime_api: ToolRuntime,
|
||||
session_id: str,
|
||||
messages: List[CompletionMessage],
|
||||
tool_call: ToolCall,
|
||||
toolgroup_args: Dict[str, Dict[str, Any]],
|
||||
tool_to_group: Dict[str, str],
|
||||
) -> List[ToolResponseMessage]:
|
||||
# While Tools.run interface takes a list of messages,
|
||||
# All tools currently only run on a single message
|
||||
# When this changes, we can drop this assert
|
||||
# Whether to call tools on each message and aggregate
|
||||
# or aggregate and call tool once, reamins to be seen.
|
||||
assert len(messages) == 1, "Expected single message"
|
||||
message = messages[0]
|
||||
|
||||
tool_call = message.tool_calls[0]
|
||||
) -> ToolInvocationResult:
|
||||
name = tool_call.tool_name
|
||||
group_name = tool_to_group.get(name, None)
|
||||
if group_name is None:
|
||||
|
@ -946,14 +1092,7 @@ async def execute_tool_call_maybe(
|
|||
**tool_call_args,
|
||||
),
|
||||
)
|
||||
|
||||
return [
|
||||
ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=result.content,
|
||||
)
|
||||
]
|
||||
return result
|
||||
|
||||
|
||||
def _interpret_content_as_attachment(
|
||||
|
|
|
@ -11,8 +11,6 @@ import tempfile
|
|||
import uuid
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
|
@ -21,6 +19,7 @@ from llama_stack.apis.agents import (
|
|||
AgentStepResponse,
|
||||
AgentToolGroup,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
Session,
|
||||
Turn,
|
||||
|
@ -68,12 +67,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
# check if "bwrap" is available
|
||||
if not shutil.which("bwrap"):
|
||||
print(
|
||||
colored(
|
||||
"Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
|
||||
"yellow",
|
||||
)
|
||||
)
|
||||
logger.warning("Warning: `bwrap` is not available. Code interpreter tool will not work correctly.")
|
||||
|
||||
async def create_agent(
|
||||
self,
|
||||
|
@ -146,6 +140,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
allow_turn_resume: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
|
@ -155,6 +150,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
toolgroups=toolgroups,
|
||||
documents=documents,
|
||||
tool_config=tool_config,
|
||||
allow_turn_resume=allow_turn_resume,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
|
@ -169,6 +165,34 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: List[ToolResponseMessage],
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
tool_responses=tool_responses,
|
||||
stream=stream,
|
||||
)
|
||||
if stream:
|
||||
return self._continue_agent_turn_streaming(request)
|
||||
else:
|
||||
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||
|
||||
async def _continue_agent_turn_streaming(
|
||||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||
turn = json.loads(turn)
|
||||
|
|
|
@ -12,7 +12,7 @@ from typing import List, Optional

from pydantic import BaseModel

from llama_stack.apis.agents import Turn
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.providers.utils.kvstore import KVStore

log = logging.getLogger(__name__)

@ -84,3 +84,15 @@ class AgentPersistence:
                continue
        turns.sort(key=lambda x: (x.completed_at or datetime.min))
        return turns

    async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
        await self.kvstore.set(
            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
            value=step.model_dump_json(),
        )

    async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
        value = await self.kvstore.get(
            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
        )
        return ToolExecutionStep(**json.loads(value)) if value else None
@ -44,7 +44,6 @@ class SentenceTransformersInferenceImpl(
        pass

    async def register_model(self, model: Model) -> None:
        _ = self._load_sentence_transformer_model(model.provider_resource_id)
        return model

    async def unregister_model(self, model_id: str) -> None:
@ -22,11 +22,14 @@ from llama_stack.apis.inference import (
|
|||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
InterleavedContentItem,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -230,5 +233,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -119,10 +119,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
|
||||
# sort by score
|
||||
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
|
||||
|
||||
chunks = chunks[: query_config.max_chunks]
|
||||
tokens = 0
|
||||
picked = []
|
||||
for c in chunks[: query_config.max_chunks]:
|
||||
for c in chunks:
|
||||
metadata = c.metadata
|
||||
tokens += metadata["token_count"]
|
||||
if tokens > query_config.max_tokens_in_context:
|
||||
|
@ -146,6 +146,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
text="\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
),
|
||||
],
|
||||
metadata={
|
||||
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
|
||||
},
|
||||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
|
|
|
@ -4,26 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# config.py
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
class SQLiteVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="sqlite_vec.db",
|
||||
)
|
||||
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + "sqlite_vec.db",
|
||||
}
|
||||
|
|
|
@ -61,7 +61,10 @@ def available_providers() -> List[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::sentence-transformers",
|
||||
pip_packages=["sentence-transformers"],
|
||||
pip_packages=[
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu",
|
||||
"sentence-transformers --no-deps",
|
||||
],
|
||||
module="llama_stack.providers.inline.inference.sentence_transformers",
|
||||
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
|
||||
),
|
||||
|
|
|
@ -20,7 +20,18 @@ def available_providers() -> List[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
provider_type="inline::rag-runtime",
|
||||
pip_packages=[],
|
||||
pip_packages=[
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"pypdf",
|
||||
"tqdm",
|
||||
"numpy",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"nltk",
|
||||
"sentencepiece",
|
||||
"transformers",
|
||||
],
|
||||
module="llama_stack.providers.inline.tool_runtime.rag",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
|
||||
api_dependencies=[Api.vector_io, Api.inference],
|
||||
|
|
|
@ -14,33 +14,13 @@ from llama_stack.providers.datatypes import (
|
|||
remote_provider_spec,
|
||||
)
|
||||
|
||||
EMBEDDING_DEPS = [
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"pypdf",
|
||||
"tqdm",
|
||||
"numpy",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"nltk",
|
||||
"sentencepiece",
|
||||
"transformers",
|
||||
# this happens to work because special dependencies are always installed last
|
||||
# so if there was a regular torch installed first, this would be ignored
|
||||
# we need a better way to do this to identify potential conflicts, etc.
|
||||
# for now, this lets us significantly reduce the size of the container which
|
||||
# does not have any "local" inference code (and hence does not need GPU-enabled torch)
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu",
|
||||
"sentence-transformers --no-deps",
|
||||
]
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
pip_packages=["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.vector_io.faiss",
|
||||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
deprecation_warning="Please use the `inline::faiss` provider instead.",
|
||||
|
@ -49,24 +29,33 @@ def available_providers() -> List[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::faiss",
|
||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
pip_packages=["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.vector_io.faiss",
|
||||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::sqlite_vec",
|
||||
pip_packages=EMBEDDING_DEPS + ["sqlite-vec"],
|
||||
provider_type="inline::sqlite-vec",
|
||||
pip_packages=["sqlite-vec"],
|
||||
module="llama_stack.providers.inline.vector_io.sqlite_vec",
|
||||
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::sqlite_vec",
|
||||
pip_packages=["sqlite-vec"],
|
||||
module="llama_stack.providers.inline.vector_io.sqlite_vec",
|
||||
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
|
||||
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="chromadb",
|
||||
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
||||
pip_packages=["chromadb-client"],
|
||||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||
),
|
||||
|
@ -75,7 +64,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::chromadb",
|
||||
pip_packages=EMBEDDING_DEPS + ["chromadb"],
|
||||
pip_packages=["chromadb"],
|
||||
module="llama_stack.providers.inline.vector_io.chroma",
|
||||
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
@ -84,7 +73,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="pgvector",
|
||||
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
|
||||
pip_packages=["psycopg2-binary"],
|
||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||
),
|
||||
|
@ -94,7 +83,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="weaviate",
|
||||
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
|
||||
pip_packages=["weaviate-client"],
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
|
@ -115,7 +104,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="qdrant",
|
||||
pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
|
||||
pip_packages=["qdrant-client"],
|
||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
),
|
||||
|
|
|
@ -9,17 +9,22 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
|||
|
||||
from botocore.client import BaseClient
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -162,7 +167,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embeddings = []
|
||||
|
|
|
@ -8,17 +8,22 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
|
||||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -172,6 +177,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -8,16 +8,21 @@ from typing import AsyncGenerator, List, Optional
|
|||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
|
@ -130,7 +135,10 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedContent],
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@@ -8,19 +8,24 @@ from typing import AsyncGenerator, List, Optional, Union

from fireworks.client import Fireworks

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    ResponseFormatType,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,

@@ -204,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
        input_dict = {}
        media_present = request_has_media(request)

        llama_model = self.get_llama_model(request.model)
        if isinstance(request, ChatCompletionRequest):
            if media_present:
            if media_present or not llama_model:
                input_dict["messages"] = [
                    await convert_message_to_openai_dict(m, download=True) for m in request.messages
                ]
            else:
                input_dict["prompt"] = await chat_completion_request_to_prompt(
                    request, self.get_llama_model(request.model)
                )
                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
        else:
            assert not media_present, "Fireworks does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(request)

@@ -232,7 +236,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        model = await self.model_store.get_model(model_id)
@@ -17,16 +17,23 @@ from llama_stack.apis.inference import (
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    InterleavedContent,
    InterleavedContentItem,
    LogProbConfig,
    Message,
    ResponseFormat,
    TextTruncation,
    ToolChoice,
    ToolConfig,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.models.llama.datatypes import (
    SamplingParams,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (

@@ -140,7 +147,10 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
@@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
    build_hf_repo_model_entry,
)

@@ -46,6 +48,14 @@ _MODEL_ENTRIES = [
        "meta/llama-3.2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    ProviderModelEntry(
        provider_model_id="baai/bge-m3",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 1024,
            "context_length": 8192,
        },
    ),
    # TODO(mf): how do we handle Nemotron models?
    # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]
@@ -10,6 +10,11 @@ from typing import AsyncIterator, List, Optional, Union

from openai import APIConnectionError, AsyncOpenAI

from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,

@@ -18,15 +23,20 @@ from llama_stack.apis.inference import (
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    TextTruncation,
    ToolChoice,
    ToolConfig,
)
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.models.llama.datatypes import (
    SamplingParams,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)

@@ -117,9 +127,41 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
        if any(content_has_media(content) for content in contents):
            raise NotImplementedError("Media is not supported")

        #
        # Llama Stack: contents = List[str] | List[InterleavedContentItem]
        # ->
        # OpenAI: input = str | List[str]
        #
        # we can ignore str and always pass List[str] to OpenAI
        #
        flat_contents = [
            item.text if isinstance(item, TextContentItem) else item
            for content in contents
            for item in (content if isinstance(content, list) else [content])
        ]
        input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
        model = self.get_provider_model_id(model_id)

        response = await self._client.embeddings.create(
            model=model,
            input=input,
            # extra_body={"input_type": "passage"|"query"}, # TODO(mf): how to tell caller's intent?
        )

        #
        # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
        # ->
        # Llama Stack: EmbeddingsResponse(embeddings=List[List[float]])
        #
        return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])

    async def chat_completion(
        self,
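As the comments in the NVIDIA hunk above spell out, Llama Stack hands the adapter List[str] | List[InterleavedContentItem] while the OpenAI-compatible embeddings endpoint expects str | List[str]. A self-contained sketch of that flattening step follows; flatten_for_openai is a hypothetical helper name, and only TextContentItem comes from the diff.

    from llama_stack.apis.common.content_types import TextContentItem

    def flatten_for_openai(contents):
        # Keep plain strings as-is, unwrap TextContentItem into its .text payload,
        # and always hand the OpenAI client a flat list of strings.
        return [
            item.text if isinstance(item, TextContentItem) else item
            for content in contents
            for item in (content if isinstance(content, list) else [content])
        ]

    print(flatten_for_openai(["hello", TextContentItem(text="world")]))
    # -> ['hello', 'world']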
@@ -13,6 +13,7 @@ from ollama import AsyncClient
from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import (

@@ -20,11 +21,13 @@ from llama_stack.apis.inference import (
    ChatCompletionResponse,
    CompletionRequest,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,

@@ -175,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

        input_dict = {}
        media_present = request_has_media(request)
        llama_model = self.register_helper.get_llama_model(request.model)
        if isinstance(request, ChatCompletionRequest):
            if media_present:
            if media_present or not llama_model:
                contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                # flatten the list of lists
                input_dict["messages"] = [item for sublist in contents for item in sublist]

@@ -184,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                input_dict["raw"] = True
                input_dict["prompt"] = await chat_completion_request_to_prompt(
                    request,
                    self.register_helper.get_llama_model(request.model),
                    llama_model,
                )
        else:
            assert not media_present, "Ollama does not support media for Completion requests"

@@ -258,7 +262,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        model = await self.model_store.get_model(model_id)

@@ -274,7 +281,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
        return EmbeddingsResponse(embeddings=embeddings)

    async def register_model(self, model: Model) -> Model:
        model = await self.register_helper.register_model(model)
        if model.model_type == ModelType.embedding:
            log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
            await self.client.pull(model.provider_resource_id)
            response = await self.client.list()
        else:
            response = await self.client.ps()

@@ -284,7 +294,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
            )

        return await self.register_helper.register_model(model)
        return model


async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
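With the Ollama changes above, registering an embedding model pulls it on demand (and checks `ollama list`), while an LLM must already be loaded (checked via `ollama ps`). A hedged client-side sketch of registering an embedding model against an Ollama-backed stack; the model id, dimension, and exact models.register keyword set are assumptions, not taken from this diff.

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")

    # Registering an embedding model lets the server run `ollama pull` if needed;
    # an LLM registered the same way must already be running in Ollama.
    client.models.register(
        model_id="all-minilm",
        provider_id="ollama",
        model_type="embedding",
        metadata={"embedding_dimension": 384},
    )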
@@ -11,11 +11,13 @@ from llama_stack_client import LlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,

@@ -138,6 +140,9 @@ class PassthroughInferenceAdapter(Inference):
        self,
        model_id: str,
        contents: List[InterleavedContent],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        client = self._get_client()
        model = await self.model_store.get_model(model_id)

@@ -145,4 +150,7 @@ class PassthroughInferenceAdapter(Inference):
        return client.inference.embeddings(
            model_id=model.provider_resource_id,
            contents=contents,
            text_truncation=text_truncation,
            output_dimension=output_dimension,
            task_type=task_type,
        )
@@ -69,9 +69,10 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        request = ChatCompletionRequest(
            model=model,

@@ -119,6 +120,9 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedContent],
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
@@ -5,16 +5,38 @@
# the root directory of this source tree.

import json
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional

from openai import OpenAI

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionMessage,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    StopReason,
    SystemMessage,
    TextTruncation,
    ToolCall,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
    ToolResponseMessage,
    UserMessage,
)
from llama_stack.models.llama.datatypes import (
    GreedySamplingStrategy,
    TopKSamplingStrategy,

@@ -119,7 +141,10 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
@@ -10,18 +10,23 @@ from typing import AsyncGenerator, List, Optional

from huggingface_hub import AsyncInferenceClient, HfApi

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    ResponseFormatType,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,

@@ -268,7 +273,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
@@ -8,18 +8,23 @@ from typing import AsyncGenerator, List, Optional, Union

from together import Together

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    ResponseFormatType,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,

@@ -198,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
        input_dict = {}
        media_present = request_has_media(request)
        llama_model = self.get_llama_model(request.model)
        if isinstance(request, ChatCompletionRequest):
            if media_present:
            if media_present or not llama_model:
                input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
            else:
                input_dict["prompt"] = await chat_completion_request_to_prompt(
                    request, self.get_llama_model(request.model)
                )
                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
        else:
            assert not media_present, "Together does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(request)

@@ -219,7 +223,10 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        model = await self.model_store.get_model(model_id)
        assert all(not content_has_media(content) for content in contents), (
@@ -10,7 +10,13 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import StopReason, ToolCall
from openai import OpenAI

from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
    TextDelta,
    ToolCallDelta,
    ToolCallParseStatus,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,

@@ -22,18 +28,21 @@ from llama_stack.apis.inference import (
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    ResponseFormatType,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (

@@ -112,10 +121,16 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
            if tool_param.required:
                compat_required.append(tool_key)

        # The tool.tool_name can be a str or a BuiltinTool enum. If
        # it's the latter, convert to a string.
        tool_name = tool.tool_name
        if isinstance(tool_name, BuiltinTool):
            tool_name = tool_name.value

        compat_tool = {
            "type": "function",
            "function": {
                "name": tool.tool_name,
                "name": tool_name,
                "description": tool.description,
                "parameters": {
                    "type": "object",

@@ -369,7 +384,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        model = await self.model_store.get_model(model_id)
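The _convert_to_vllm_tools_in_request fix above guards against tool_name arriving as a BuiltinTool enum rather than a plain string. A minimal sketch of the normalization in isolation; the enum member shown is just an example.

    from llama_stack.models.llama.datatypes import BuiltinTool

    def normalize_tool_name(tool_name):
        # OpenAI-style tool specs want a plain string; BuiltinTool members carry it in .value.
        return tool_name.value if isinstance(tool_name, BuiltinTool) else tool_name

    assert normalize_tool_name("custom1") == "custom1"
    assert normalize_tool_name(BuiltinTool.code_interpreter) == "code_interpreter"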
@@ -8,7 +8,10 @@ import unittest

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    CompletionMessage,
    StopReason,
    SystemMessage,
    ToolCall,
    ToolConfig,
    UserMessage,
)

@@ -20,6 +23,7 @@ from llama_stack.models.llama.datatypes import (
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
    chat_completion_request_to_prompt,
)

MODEL = "Llama3.1-8B-Instruct"

@@ -119,6 +123,46 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
        self.assertTrue("Return function calls in JSON format" in messages[1].content)
        self.assertEqual(messages[-1].content, content)

    async def test_completion_message_encoding(self):
        request = ChatCompletionRequest(
            model=MODEL3_2,
            messages=[
                UserMessage(content="hello"),
                CompletionMessage(
                    content="",
                    stop_reason=StopReason.end_of_turn,
                    tool_calls=[
                        ToolCall(
                            tool_name="custom1",
                            arguments={"param1": "value1"},
                            call_id="123",
                        )
                    ],
                ),
            ],
            tools=[
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                ),
            ],
            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
        )
        prompt = await chat_completion_request_to_prompt(request, request.model)
        self.assertIn('[custom1(param1="value1")]', prompt)

        request.model = MODEL
        request.tool_config.tool_prompt_format = ToolPromptFormat.json
        prompt = await chat_completion_request_to_prompt(request, request.model)
        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)

    async def test_user_provided_system_message(self):
        content = "Hello !"
        system_prompt = "You are a pirate"
@@ -61,7 +61,7 @@ def vector_io_sqlite_vec() -> ProviderFixture:
        providers=[
            Provider(
                provider_id="sqlite_vec",
                provider_type="inline::sqlite_vec",
                provider_type="inline::sqlite-vec",
                config=SQLiteVectorIOConfig(
                    kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
                ).model_dump(),

@@ -5,13 +5,16 @@
# the root directory of this source tree.

import logging
from typing import List
from typing import List, Optional

from llama_stack.apis.inference import (
    EmbeddingsResponse,
    InterleavedContent,
    EmbeddingTaskType,
    InterleavedContentItem,
    ModelStore,
    TextTruncation,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str

EMBEDDING_MODELS = {}

@@ -25,11 +28,16 @@ class SentenceTransformerEmbeddingMixin:
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        model = await self.model_store.get_model(model_id)
        embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
        embeddings = embedding_model.encode(contents)
        embeddings = embedding_model.encode(
            [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
        )
        return EmbeddingsResponse(embeddings=embeddings)

    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
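The mixin above now funnels every element through interleaved_content_as_str before encoding, so callers can mix plain strings and content items. A small sketch of the equivalent preprocessing, assuming a locally available sentence-transformers model (the model name is a placeholder):

    from sentence_transformers import SentenceTransformer

    from llama_stack.apis.common.content_types import TextContentItem
    from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str

    contents = ["hello world", TextContentItem(text="goodbye world")]
    texts = [interleaved_content_as_str(c) for c in contents]  # ["hello world", "goodbye world"]

    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(texts, show_progress_bar=False)
    print(embeddings.shape)  # (2, 384) for this particular model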
@@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
            provider_resource_id = model.provider_resource_id
        else:
            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)

        if provider_resource_id:
            model.provider_resource_id = provider_resource_id
        else:
            if model.metadata.get("llama_model") is None:
                raise ValueError(
                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
                    "Please specify a llama_model in metadata or use a supported model identifier"
                )
            llama_model = model.metadata.get("llama_model")
            if llama_model is None:
                return model

            existing_llama_model = self.get_llama_model(model.provider_resource_id)
            if existing_llama_model:
                if existing_llama_model != model.metadata["llama_model"]:
                if existing_llama_model != llama_model:
                    raise ValueError(
                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                    )
            else:
                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                    raise ValueError(
                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
                        f"Invalid llama_model '{llama_model}' specified in metadata. "
                        f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                    )
                self.provider_id_to_llama_model_map[model.provider_resource_id] = (
                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                )

        return model
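After this refactor, a provider model id the helper does not recognize can still be registered when metadata.llama_model names a Hugging Face repo the stack knows; if no llama_model is given, registration now simply returns the model instead of raising. A hedged client-side sketch follows; the provider id, model ids, and repo string are illustrative and must match what your deployment actually supports.

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")

    # An unknown provider model id plus a known llama_model repo lets the helper
    # map the alias back to a core Llama model descriptor.
    client.models.register(
        model_id="my-finetuned-llama",
        provider_id="vllm",
        provider_model_id="org/my-finetuned-llama",
        metadata={"llama_model": "meta-llama/Llama-3.1-8B-Instruct"},
    )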
@@ -252,7 +252,9 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
    request = await convert_request_to_raw(request)

    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
    model_input = formatter.encode_dialog_prompt(request.messages)
    model_input = formatter.encode_dialog_prompt(
        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
    )
    return formatter.tokenizer.decode(model_input.tokens)


@@ -264,7 +266,9 @@ async def chat_completion_request_to_model_input_info(
    request = await convert_request_to_raw(request)

    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
    model_input = formatter.encode_dialog_prompt(request.messages)
    model_input = formatter.encode_dialog_prompt(
        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
    )
    return (
        formatter.tokenizer.decode(model_input.tokens),
        len(model_input.tokens),
@@ -5,12 +5,10 @@
# the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable, List, Optional, TypeVar
from typing import Any, Callable, List, Optional, Protocol, TypeVar

from .strong_typing.schema import json_schema_type, register_schema  # noqa: F401

T = TypeVar("T")


@dataclass
class WebMethod:

@@ -22,6 +20,13 @@ class WebMethod:
    raw_bytes_request_body: Optional[bool] = False


class HasWebMethod(Protocol):
    __webmethod__: WebMethod


T = TypeVar("T", bound=HasWebMethod)  # Bound T to classes that match this protocol


def webmethod(
    route: Optional[str] = None,
    method: Optional[str] = None,
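The new HasWebMethod protocol re-bounds T to objects that carry a __webmethod__ attribute, which is what the webmethod decorator attaches. A minimal sketch of the usual usage follows; the import path, route, and attribute layout are assumptions (that this file is llama_stack/schema_utils.py and that WebMethod stores the route it was given), not facts from this diff.

    from typing import Protocol

    from llama_stack.schema_utils import webmethod


    class ExampleAPI(Protocol):
        @webmethod(route="/example/ping", method="GET")
        async def ping(self) -> str: ...


    # The decorator attaches routing metadata without wrapping the function.
    print(ExampleAPI.ping.__webmethod__.route)  # -> /example/ping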
@@ -11,7 +11,7 @@ import subprocess
import sys
from functools import partial
from pathlib import Path
from typing import Iterator
from typing import Iterable

from rich.progress import Progress, SpinnerColumn, TextColumn

@@ -39,7 +39,7 @@ class ChangedPathTracker:
        return self._changed_paths


def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
def find_template_dirs(templates_dir: Path) -> Iterable[Path]:
    """Find immediate subdirectories in the templates folder."""
    if not templates_dir.exists():
        raise FileNotFoundError(f"Templates directory not found: {templates_dir}")

@@ -90,7 +90,7 @@ def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
    return has_changes


def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[str]]:
    try:
        module_name = f"llama_stack.templates.{template_dir.name}"
        module = importlib.import_module(module_name)

@@ -52,7 +52,7 @@ def main(parser: argparse.ArgumentParser):
            pytest_args,
            "-s",
            "-v",
            REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH,
            str(REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH),
        ]
    )
@@ -4,6 +4,7 @@ distribution_spec:
  providers:
    inference:
    - remote::cerebras
    - inline::sentence-transformers
    safety:
    - inline::llama-guard
    vector_io:

@@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::cerebras"],
        "inference": ["remote::cerebras", "inline::sentence-transformers"],
        "safety": ["inline::llama-guard"],
        "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "agents": ["inline::meta-reference"],

@@ -5,6 +5,7 @@ distribution_spec:
  providers:
    inference:
    - remote::tgi
    - inline::sentence-transformers
    vector_io:
    - inline::faiss
    - remote::chromadb

@@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::tgi"],
        "inference": ["remote::tgi", "inline::sentence-transformers"],
        "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],

@@ -4,6 +4,7 @@ distribution_spec:
  providers:
    inference:
    - remote::fireworks
    - inline::sentence-transformers
    vector_io:
    - inline::faiss
    - remote::chromadb

@@ -25,7 +25,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::fireworks"],
        "inference": ["remote::fireworks", "inline::sentence-transformers"],
        "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],

@@ -4,6 +4,7 @@ distribution_spec:
  providers:
    inference:
    - remote::hf::serverless
    - inline::sentence-transformers
    vector_io:
    - inline::faiss
    - remote::chromadb

@@ -21,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::hf::serverless"],
        "inference": ["remote::hf::serverless", "inline::sentence-transformers"],
        "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
@@ -55,9 +55,11 @@ def get_distribution_template() -> DistributionTemplate:
    core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
    default_models = [
        ModelInput(
            model_id=core_model_to_hf_repo[m.llama_model],
            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
            provider_model_id=m.provider_model_id,
            provider_id="nvidia",
            model_type=m.model_type,
            metadata=m.metadata,
        )
        for m in _MODEL_ENTRIES
    ]

@@ -135,6 +135,13 @@ models:
  provider_id: nvidia
  provider_model_id: meta/llama-3.2-90b-vision-instruct
  model_type: llm
- metadata:
    embedding_dimension: 1024
    context_length: 8192
  model_id: baai/bge-m3
  provider_id: nvidia
  provider_model_id: baai/bge-m3
  model_type: embedding
shields: []
vector_dbs: []
datasets: []
@@ -5,7 +5,7 @@ distribution_spec:
    inference:
    - remote::ollama
    vector_io:
    - inline::faiss
    - inline::sqlite-vec
    - remote::chromadb
    - remote::pgvector
    safety:

@@ -13,10 +13,6 @@ from llama_stack.distribution.datatypes import (
    ShieldInput,
    ToolGroupInput,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
    SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

@@ -25,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::ollama"],
        "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],

@@ -45,19 +41,9 @@ def get_distribution_template() -> DistributionTemplate:
        provider_type="remote::ollama",
        config=OllamaImplConfig.sample_run_config(),
    )
    embedding_provider = Provider(
        provider_id="sentence-transformers",
        provider_type="inline::sentence-transformers",
        config=SentenceTransformersInferenceConfig.sample_run_config(),
    )
    vector_io_provider_faiss = Provider(
        provider_id="faiss",
        provider_type="inline::faiss",
        config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
    )
    vector_io_provider_sqlite = Provider(
        provider_id="sqlite_vec",
        provider_type="inline::sqlite_vec",
        provider_id="sqlite-vec",
        provider_type="inline::sqlite-vec",
        config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
    )

@@ -104,19 +90,16 @@ def get_distribution_template() -> DistributionTemplate:
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [inference_provider, embedding_provider],
                    "vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite],
                    "inference": [inference_provider],
                    "vector_io": [vector_io_provider_sqlite],
                },
                default_models=[inference_model, embedding_model],
                default_models=[inference_model],
                default_tool_groups=default_tool_groups,
            ),
            "run-with-safety.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [
                        inference_provider,
                        embedding_provider,
                    ],
                    "vector_io": [vector_io_provider_faiss, vector_io_provider_faiss],
                    "inference": [inference_provider],
                    "vector_io": [vector_io_provider_sqlite],
                    "safety": [
                        Provider(
                            provider_id="llama-guard",
@@ -16,24 +16,11 @@ providers:
    provider_type: remote::ollama
    config:
      url: ${env.OLLAMA_URL:http://localhost:11434}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
  - provider_id: sqlite-vec
    provider_type: inline::sqlite-vec
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard

@@ -16,24 +16,11 @@ providers:
    provider_type: remote::ollama
    config:
      url: ${env.OLLAMA_URL:http://localhost:11434}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
  - provider_id: sqlite-vec
    provider_type: inline::sqlite-vec
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
  - provider_id: sqlite_vec
    provider_type: inline::sqlite_vec
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard

@@ -100,11 +87,14 @@ models:
  model_id: ${env.INFERENCE_MODEL}
  provider_id: ollama
  model_type: llm
<<<<<<< HEAD
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: sentence-transformers
  model_type: embedding
=======
>>>>>>> upstream/main
shields: []
vector_dbs: []
datasets: []
@@ -4,6 +4,7 @@ distribution_spec:
  providers:
    inference:
    - remote::vllm
    - inline::sentence-transformers
    vector_io:
    - inline::faiss
    - remote::chromadb

@@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::vllm"],
        "inference": ["remote::vllm", "inline::sentence-transformers"],
        "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],

@@ -4,6 +4,7 @@ distribution_spec:
  providers:
    inference:
    - remote::tgi
    - inline::sentence-transformers
    vector_io:
    - inline::faiss
    - remote::chromadb

@@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::tgi"],
        "inference": ["remote::tgi", "inline::sentence-transformers"],
        "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
Some files were not shown because too many files have changed in this diff.