Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-04 02:03:44 +00:00)

Merge branch 'main' into add-mcp-authentication-param
Commit 8632c705aa: 1250 changed files with 2278 additions and 343484 deletions

@@ -5,7 +5,7 @@ omit =
     */llama_stack/templates/*
     .venv/*
     */llama_stack/cli/scripts/*
-    */llama_stack/ui/*
+    */llama_stack_ui/*
     */llama_stack/distribution/ui/*
     */llama_stack/strong_typing/*
     */llama_stack/env.py
.github/dependabot.yml (2 changes)
@@ -22,7 +22,7 @@ updates:
       prefix: chore(python-deps)

   - package-ecosystem: npm
-    directory: "/llama_stack/ui"
+    directory: "/llama_stack_ui"
     schedule:
       interval: "weekly"
       day: "saturday"
.github/workflows/integration-auth-tests.yml (2 changes)
@@ -14,7 +14,7 @@ on:
     paths:
       - 'distributions/**'
       - 'src/llama_stack/**'
-      - '!src/llama_stack/ui/**'
+      - '!src/llama_stack_ui/**'
       - 'tests/integration/**'
       - 'uv.lock'
       - 'pyproject.toml'
.github/workflows/integration-tests.yml (3 changes)
@@ -14,7 +14,7 @@ on:
     types: [opened, synchronize, reopened]
     paths:
       - 'src/llama_stack/**'
-      - '!src/llama_stack/ui/**'
+      - '!src/llama_stack_ui/**'
       - 'tests/**'
       - 'uv.lock'
       - 'pyproject.toml'
@@ -22,6 +22,7 @@ on:
       - '.github/actions/setup-ollama/action.yml'
       - '.github/actions/setup-test-environment/action.yml'
       - '.github/actions/run-and-record-tests/action.yml'
+      - 'scripts/integration-tests.sh'
   schedule:
     # If changing the cron schedule, update the provider in the test-matrix job
     - cron: '0 0 * * *' # (test latest client) Daily at 12 AM UTC
@@ -13,7 +13,7 @@ on:
       - 'release-[0-9]+.[0-9]+.x'
     paths:
       - 'src/llama_stack/**'
-      - '!src/llama_stack/ui/**'
+      - '!src/llama_stack_ui/**'
       - 'tests/integration/vector_io/**'
       - 'uv.lock'
       - 'pyproject.toml'
.github/workflows/pre-commit.yml (4 changes)
@@ -43,14 +43,14 @@ jobs:
         with:
           node-version: '20'
           cache: 'npm'
-          cache-dependency-path: 'src/llama_stack/ui/'
+          cache-dependency-path: 'src/llama_stack_ui/'

       - name: Set up uv
         uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2

       - name: Install npm dependencies
         run: npm ci
-        working-directory: src/llama_stack/ui
+        working-directory: src/llama_stack_ui

       - name: Install pre-commit
         run: python -m pip install pre-commit
.github/workflows/python-build-test.yml (2 changes)
@@ -10,7 +10,7 @@ on:
     branches:
       - main
     paths-ignore:
-      - 'src/llama_stack/ui/**'
+      - 'src/llama_stack_ui/**'

 jobs:
   build:
.github/workflows/test-external.yml (2 changes)
@@ -9,7 +9,7 @@ on:
     branches: [ main ]
     paths:
       - 'src/llama_stack/**'
-      - '!src/llama_stack/ui/**'
+      - '!src/llama_stack_ui/**'
       - 'tests/integration/**'
       - 'uv.lock'
       - 'pyproject.toml'
.github/workflows/ui-unit-tests.yml (12 changes)
@@ -8,7 +8,7 @@ on:
   pull_request:
     branches: [ main ]
     paths:
-      - 'src/llama_stack/ui/**'
+      - 'src/llama_stack_ui/**'
       - '.github/workflows/ui-unit-tests.yml' # This workflow
   workflow_dispatch:
@@ -33,22 +33,22 @@ jobs:
         with:
           node-version: ${{ matrix.node-version }}
           cache: 'npm'
-          cache-dependency-path: 'src/llama_stack/ui/package-lock.json'
+          cache-dependency-path: 'src/llama_stack_ui/package-lock.json'

       - name: Install dependencies
-        working-directory: src/llama_stack/ui
+        working-directory: src/llama_stack_ui
         run: npm ci

       - name: Run linting
-        working-directory: src/llama_stack/ui
+        working-directory: src/llama_stack_ui
         run: npm run lint

       - name: Run format check
-        working-directory: src/llama_stack/ui
+        working-directory: src/llama_stack_ui
         run: npm run format:check

       - name: Run unit tests
-        working-directory: src/llama_stack/ui
+        working-directory: src/llama_stack_ui
         env:
           CI: true
.github/workflows/unit-tests.yml (2 changes)
@@ -13,7 +13,7 @@ on:
       - 'release-[0-9]+.[0-9]+.x'
     paths:
       - 'src/llama_stack/**'
-      - '!src/llama_stack/ui/**'
+      - '!src/llama_stack_ui/**'
       - 'tests/unit/**'
       - 'uv.lock'
       - 'pyproject.toml'
@@ -161,7 +161,7 @@ repos:
       name: Format & Lint UI
       entry: bash ./scripts/run-ui-linter.sh
       language: system
-      files: ^src/llama_stack/ui/.*\.(ts|tsx)$
+      files: ^src/llama_stack_ui/.*\.(ts|tsx)$
       pass_filenames: false
       require_serial: true
client-sdks/stainless/config-not-source-of-truth-yet.yml (new file, 525 lines)

# yaml-language-server: $schema=https://app.stainlessapi.com/config-internal.schema.json

organization:
  # Name of your organization or company, used to determine the name of the client
  # and headings.
  name: llama-stack-client
  docs: https://llama-stack.readthedocs.io/en/latest/
  contact: llamastack@meta.com

security:
  - {}
  - BearerAuth: []
security_schemes:
  BearerAuth:
    type: http
    scheme: bearer

# `targets` define the output targets and their customization options, such as
# whether to emit the Node SDK and what its package name should be.
targets:
  node:
    package_name: llama-stack-client
    production_repo: llamastack/llama-stack-client-typescript
    publish:
      npm: false
  python:
    package_name: llama_stack_client
    production_repo: llamastack/llama-stack-client-python
    options:
      use_uv: true
    publish:
      pypi: true
    project_name: llama_stack_client
  kotlin:
    reverse_domain: com.llama_stack_client.api
    production_repo: null
    publish:
      maven: false
  go:
    package_name: llama-stack-client
    production_repo: llamastack/llama-stack-client-go
    options:
      enable_v2: true
      back_compat_use_shared_package: false

# `client_settings` define settings for the API client, such as extra constructor
# arguments (used for authentication), retry behavior, idempotency, etc.
client_settings:
  default_env_prefix: LLAMA_STACK_CLIENT
  opts:
    api_key:
      type: string
      read_env: LLAMA_STACK_CLIENT_API_KEY
      auth: { security_scheme: BearerAuth }
      nullable: true

# `environments` are a map of the name of the environment (e.g. "sandbox",
# "production") to the corresponding url to use.
environments:
  production: http://any-hosted-llama-stack.com

# `pagination` defines [pagination schemes] which provides a template to match
# endpoints and generate next-page and auto-pagination helpers in the SDKs.
pagination:
  - name: datasets_iterrows
    type: offset
    request:
      dataset_id:
        type: string
      start_index:
        type: integer
        x-stainless-pagination-property:
          purpose: offset_count_param
      limit:
        type: integer
    response:
      data:
        type: array
        items:
          type: object
      next_index:
        type: integer
        x-stainless-pagination-property:
          purpose: offset_count_start_field
  - name: openai_cursor_page
    type: cursor
    request:
      limit:
        type: integer
      after:
        type: string
        x-stainless-pagination-property:
          purpose: next_cursor_param
    response:
      data:
        type: array
        items: {}
      has_more:
        type: boolean
      last_id:
        type: string
        x-stainless-pagination-property:
          purpose: next_cursor_field

# `resources` define the structure and organization for your API, such as how
# methods and models are grouped together and accessed. See the [configuration
# guide] for more information.
#
# [configuration guide]:
# https://app.stainlessapi.com/docs/guides/configure#resources
resources:
  $shared:
    models:
      interleaved_content_item: InterleavedContentItem
      interleaved_content: InterleavedContent
      param_type: ParamType
      safety_violation: SafetyViolation
      sampling_params: SamplingParams
      scoring_result: ScoringResult
      system_message: SystemMessage
      query_result: RAGQueryResult
      document: RAGDocument
      query_config: RAGQueryConfig

  toolgroups:
    models:
      tool_group: ToolGroup
      list_tool_groups_response: ListToolGroupsResponse
    methods:
      register: post /v1/toolgroups
      get: get /v1/toolgroups/{toolgroup_id}
      list: get /v1/toolgroups
      unregister: delete /v1/toolgroups/{toolgroup_id}

  tools:
    methods:
      get: get /v1/tools/{tool_name}
      list:
        endpoint: get /v1/tools
        paginated: false

  tool_runtime:
    models:
      tool_def: ToolDef
      tool_invocation_result: ToolInvocationResult
    methods:
      list_tools:
        endpoint: get /v1/tool-runtime/list-tools
        paginated: false
      invoke_tool: post /v1/tool-runtime/invoke
    subresources:
      rag_tool:
        methods:
          insert: post /v1/tool-runtime/rag-tool/insert
          query: post /v1/tool-runtime/rag-tool/query

  responses:
    models:
      response_object_stream: OpenAIResponseObjectStream
      response_object: OpenAIResponseObject
    methods:
      create:
        type: http
        endpoint: post /v1/responses
        streaming:
          stream_event_model: responses.response_object_stream
          param_discriminator: stream
      retrieve: get /v1/responses/{response_id}
      list:
        type: http
        endpoint: get /v1/responses
      delete:
        type: http
        endpoint: delete /v1/responses/{response_id}
    subresources:
      input_items:
        methods:
          list:
            type: http
            endpoint: get /v1/responses/{response_id}/input_items

  prompts:
    models:
      prompt: Prompt
      list_prompts_response: ListPromptsResponse
    methods:
      create: post /v1/prompts
      list:
        endpoint: get /v1/prompts
        paginated: false
      retrieve: get /v1/prompts/{prompt_id}
      update: post /v1/prompts/{prompt_id}
      delete: delete /v1/prompts/{prompt_id}
      set_default_version: post /v1/prompts/{prompt_id}/set-default-version
    subresources:
      versions:
        methods:
          list:
            endpoint: get /v1/prompts/{prompt_id}/versions
            paginated: false

  conversations:
    models:
      conversation_object: Conversation
    methods:
      create:
        type: http
        endpoint: post /v1/conversations
      retrieve: get /v1/conversations/{conversation_id}
      update:
        type: http
        endpoint: post /v1/conversations/{conversation_id}
      delete:
        type: http
        endpoint: delete /v1/conversations/{conversation_id}
    subresources:
      items:
        methods:
          get:
            type: http
            endpoint: get /v1/conversations/{conversation_id}/items/{item_id}
          list:
            type: http
            endpoint: get /v1/conversations/{conversation_id}/items
          create:
            type: http
            endpoint: post /v1/conversations/{conversation_id}/items

  inspect:
    models:
      healthInfo: HealthInfo
      providerInfo: ProviderInfo
      routeInfo: RouteInfo
      versionInfo: VersionInfo
    methods:
      health: get /v1/health
      version: get /v1/version

  embeddings:
    models:
      create_embeddings_response: OpenAIEmbeddingsResponse
    methods:
      create: post /v1/embeddings

  chat:
    models:
      chat_completion_chunk: OpenAIChatCompletionChunk
    subresources:
      completions:
        methods:
          create:
            type: http
            endpoint: post /v1/chat/completions
            streaming:
              stream_event_model: chat.chat_completion_chunk
              param_discriminator: stream
          list:
            type: http
            endpoint: get /v1/chat/completions
          retrieve:
            type: http
            endpoint: get /v1/chat/completions/{completion_id}

  completions:
    methods:
      create:
        type: http
        endpoint: post /v1/completions
        streaming:
          param_discriminator: stream

  vector_io:
    models:
      queryChunksResponse: QueryChunksResponse
    methods:
      insert: post /v1/vector-io/insert
      query: post /v1/vector-io/query

  vector_stores:
    models:
      vector_store: VectorStoreObject
      list_vector_stores_response: VectorStoreListResponse
      vector_store_delete_response: VectorStoreDeleteResponse
      vector_store_search_response: VectorStoreSearchResponsePage
    methods:
      create: post /v1/vector_stores
      list:
        endpoint: get /v1/vector_stores
      retrieve: get /v1/vector_stores/{vector_store_id}
      update: post /v1/vector_stores/{vector_store_id}
      delete: delete /v1/vector_stores/{vector_store_id}
      search: post /v1/vector_stores/{vector_store_id}/search
    subresources:
      files:
        models:
          vector_store_file: VectorStoreFileObject
        methods:
          list: get /v1/vector_stores/{vector_store_id}/files
          retrieve: get /v1/vector_stores/{vector_store_id}/files/{file_id}
          update: post /v1/vector_stores/{vector_store_id}/files/{file_id}
          delete: delete /v1/vector_stores/{vector_store_id}/files/{file_id}
          create: post /v1/vector_stores/{vector_store_id}/files
          content: get /v1/vector_stores/{vector_store_id}/files/{file_id}/content
      file_batches:
        models:
          vector_store_file_batches: VectorStoreFileBatchObject
          list_vector_store_files_in_batch_response: VectorStoreFilesListInBatchResponse
        methods:
          create: post /v1/vector_stores/{vector_store_id}/file_batches
          retrieve: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}
          list_files: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files
          cancel: post /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel

  models:
    models:
      model: OpenAIModel
      list_models_response: OpenAIListModelsResponse
    methods:
      list:
        endpoint: get /v1/models
        paginated: false
      retrieve: get /v1/models/{model_id}
      register: post /v1/models
      unregister: delete /v1/models/{model_id}
    subresources:
      openai:
        methods:
          list:
            endpoint: get /v1/models
            paginated: false

  providers:
    models:
      list_providers_response: ListProvidersResponse
    methods:
      list:
        endpoint: get /v1/providers
        paginated: false
      retrieve: get /v1/providers/{provider_id}

  routes:
    models:
      list_routes_response: ListRoutesResponse
    methods:
      list:
        endpoint: get /v1/inspect/routes
        paginated: false

  moderations:
    models:
      create_response: ModerationObject
    methods:
      create: post /v1/moderations

  safety:
    models:
      run_shield_response: RunShieldResponse
    methods:
      run_shield: post /v1/safety/run-shield

  shields:
    models:
      shield: Shield
      list_shields_response: ListShieldsResponse
    methods:
      retrieve: get /v1/shields/{identifier}
      list:
        endpoint: get /v1/shields
        paginated: false
      register: post /v1/shields
      delete: delete /v1/shields/{identifier}

  scoring:
    methods:
      score: post /v1/scoring/score
      score_batch: post /v1/scoring/score-batch

  scoring_functions:
    methods:
      retrieve: get /v1/scoring-functions/{scoring_fn_id}
      list:
        endpoint: get /v1/scoring-functions
        paginated: false
      register: post /v1/scoring-functions
    models:
      scoring_fn: ScoringFn
      scoring_fn_params: ScoringFnParams
      list_scoring_functions_response: ListScoringFunctionsResponse

  files:
    methods:
      create: post /v1/files
      list: get /v1/files
      retrieve: get /v1/files/{file_id}
      delete: delete /v1/files/{file_id}
      content: get /v1/files/{file_id}/content
    models:
      file: OpenAIFileObject
      list_files_response: ListOpenAIFileResponse
      delete_file_response: OpenAIFileDeleteResponse

  alpha:
    subresources:
      inference:
        methods:
          rerank: post /v1alpha/inference/rerank

      post_training:
        models:
          algorithm_config: AlgorithmConfig
          post_training_job: PostTrainingJob
          list_post_training_jobs_response: ListPostTrainingJobsResponse
        methods:
          preference_optimize: post /v1alpha/post-training/preference-optimize
          supervised_fine_tune: post /v1alpha/post-training/supervised-fine-tune
        subresources:
          job:
            methods:
              artifacts: get /v1alpha/post-training/job/artifacts
              cancel: post /v1alpha/post-training/job/cancel
              status: get /v1alpha/post-training/job/status
              list:
                endpoint: get /v1alpha/post-training/jobs
                paginated: false

      benchmarks:
        methods:
          retrieve: get /v1alpha/eval/benchmarks/{benchmark_id}
          list:
            endpoint: get /v1alpha/eval/benchmarks
            paginated: false
          register: post /v1alpha/eval/benchmarks
        models:
          benchmark: Benchmark
          list_benchmarks_response: ListBenchmarksResponse

      eval:
        methods:
          evaluate_rows: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations
          run_eval: post /v1alpha/eval/benchmarks/{benchmark_id}/jobs
          evaluate_rows_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations
          run_eval_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/jobs

        subresources:
          jobs:
            methods:
              cancel: delete /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}
              status: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}
              retrieve: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result
        models:
          evaluate_response: EvaluateResponse
          benchmark_config: BenchmarkConfig
          job: Job

  beta:
    subresources:
      datasets:
        models:
          list_datasets_response: ListDatasetsResponse
        methods:
          register: post /v1beta/datasets
          retrieve: get /v1beta/datasets/{dataset_id}
          list:
            endpoint: get /v1beta/datasets
            paginated: false
          unregister: delete /v1beta/datasets/{dataset_id}
          iterrows: get /v1beta/datasetio/iterrows/{dataset_id}
          appendrows: post /v1beta/datasetio/append-rows/{dataset_id}

settings:
  license: MIT
  unwrap_response_fields: [ data ]

openapi:
  transformations:
    - command: mergeObject
      reason: Better return_type using enum
      args:
        target:
          - '$.components.schemas'
        object:
          ReturnType:
            additionalProperties: false
            properties:
              type:
                enum:
                  - string
                  - number
                  - boolean
                  - array
                  - object
                  - json
                  - union
                  - chat_completion_input
                  - completion_input
                  - agent_turn_input
            required:
              - type
            type: object
    - command: replaceProperties
      reason: Replace return type properties with better model (see above)
      args:
        filter:
          only:
            - '$.components.schemas.ScoringFn.properties.return_type'
            - '$.components.schemas.RegisterScoringFunctionRequest.properties.return_type'
        value:
          $ref: '#/components/schemas/ReturnType'
    - command: oneOfToAnyOf
      reason: Prism (mock server) doesn't like one of our requests as it technically matches multiple variants

# `readme` is used to configure the code snippets that will be rendered in the
# README.md of various SDKs. In particular, you can change the `headline`
# snippet's endpoint and the arguments to call it with.
readme:
  example_requests:
    default:
      type: request
      endpoint: post /v1/chat/completions
      params: &ref_0 {}
    headline:
      type: request
      endpoint: post /v1/models
      params: *ref_0
    pagination:
      type: request
      endpoint: post /v1/chat/completions
      params: {}
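For orientation, here is a minimal sketch of constructing the Python client this config generates. The `base_url` and `models.list` call mirror the docs snippets later in this diff; reading the key from `LLAMA_STACK_CLIENT_API_KEY` follows the `client_settings` block above, and the exact constructor signature of the published package is an assumption, not confirmed by this diff.

# Sketch only: assumes the generated llama_stack_client package from the
# Stainless config above; constructor details may differ in the published SDK.
import os

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:8321",
    # client_settings.opts.api_key reads this env var (nullable, bearer auth)
    api_key=os.environ.get("LLAMA_STACK_CLIENT_API_KEY"),
)

# `models.list` maps to `get /v1/models` per the resources section above.
for model in client.models.list():
    print(model.id)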
File diff suppressed because it is too large.
@@ -23,5 +23,4 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s
 We are working on adding a few more APIs to complete the application lifecycle. These will include:
 - **Batch Inference**: run inference on a dataset of inputs
 - **Batch Agents**: run agents on a dataset of inputs
-- **Synthetic Data Generation**: generate synthetic data for model development
 - **Batches**: OpenAI-compatible batch management for inference
@@ -44,7 +44,7 @@ spec:
           # Navigate to the UI directory
           echo "Navigating to UI directory..."
-          cd /app/llama_stack/ui
+          cd /app/llama_stack_ui

           # Check if package.json exists
           if [ ! -f "package.json" ]; then
@@ -239,8 +239,13 @@ client = LlamaStackClient(base_url="http://localhost:8321")
 models = client.models.list()

 # Select the first LLM
-llm = next(m for m in models if m.model_type == "llm" and m.provider_id == "ollama")
-model_id = llm.identifier
+llm = next(
+    m for m in models
+    if m.custom_metadata
+    and m.custom_metadata.get("model_type") == "llm"
+    and m.custom_metadata.get("provider_id") == "ollama"
+)
+model_id = llm.id

 print("Model:", model_id)
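The same selection pattern recurs in the snippets below: Llama Stack-specific fields now live under `custom_metadata` on each OpenAI-style model object, and the identifier moves from `.identifier` to `.id`. A self-contained sketch of a small helper built on that pattern (the `pick_model` name is ours, not part of the SDK):

# Sketch of the post-change selection pattern; `pick_model` is a hypothetical helper.
from llama_stack_client import LlamaStackClient


def pick_model(client: LlamaStackClient, model_type: str, provider_id: str | None = None):
    for m in client.models.list():
        meta = m.custom_metadata or {}  # custom_metadata may be absent
        if meta.get("model_type") != model_type:
            continue
        if provider_id is not None and meta.get("provider_id") != provider_id:
            continue
        return m
    raise LookupError(f"no {model_type} model found")


client = LlamaStackClient(base_url="http://localhost:8321")
model_id = pick_model(client, "llm", provider_id="ollama").id
print("Model:", model_id)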
@@ -279,8 +284,13 @@ import uuid
 client = LlamaStackClient(base_url=f"http://localhost:8321")

 models = client.models.list()
-llm = next(m for m in models if m.model_type == "llm" and m.provider_id == "ollama")
-model_id = llm.identifier
+llm = next(
+    m for m in models
+    if m.custom_metadata
+    and m.custom_metadata.get("model_type") == "llm"
+    and m.custom_metadata.get("provider_id") == "ollama"
+)
+model_id = llm.id

 agent = Agent(client, model=model_id, instructions="You are a helpful assistant.")
@@ -450,8 +460,11 @@ import uuid
 client = LlamaStackClient(base_url="http://localhost:8321")

 # Create a vector database instance
-embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
-embedding_model = embed_lm.identifier
+embed_lm = next(
+    m for m in client.models.list()
+    if m.custom_metadata and m.custom_metadata.get("model_type") == "embedding"
+)
+embedding_model = embed_lm.id
 vector_db_id = f"v{uuid.uuid4().hex}"
 # The VectorDB API is deprecated; the server now returns its own authoritative ID.
 # We capture the correct ID from the response's .identifier attribute.
@@ -489,9 +502,11 @@ client.tool_runtime.rag_tool.insert(
 llm = next(
     m
     for m in client.models.list()
-    if m.model_type == "llm" and m.provider_id == "ollama"
+    if m.custom_metadata
+    and m.custom_metadata.get("model_type") == "llm"
+    and m.custom_metadata.get("provider_id") == "ollama"
 )
-model = llm.identifier
+model = llm.id

 # Create the RAG agent
 rag_agent = Agent(
@@ -196,16 +196,10 @@ def _get_endpoint_functions(
 def _get_defining_class(member_fn: str, derived_cls: type) -> type:
     "Find the class in which a member function is first defined in a class inheritance hierarchy."

-    # This import must be dynamic here
-    from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime
-
     # iterate in reverse member resolution order to find most specific class first
     for cls in reversed(inspect.getmro(derived_cls)):
         for name, _ in inspect.getmembers(cls, inspect.isfunction):
             if name == member_fn:
-                # HACK ALERT
-                if cls == RAGToolRuntime:
-                    return ToolRuntime
                 return cls

     raise ValidationError(
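For reference, the simplified resolver in this hunk can be exercised standalone. A sketch grounded in the diff's own logic; only the tiny demo classes and the error type are invented for illustration:

# Standalone sketch of the resolver this diff simplifies: walk the MRO from
# most-base to most-derived and return the first class defining the method.
import inspect


def get_defining_class(member_fn: str, derived_cls: type) -> type:
    # iterate in reverse member resolution order to find most specific class first
    for cls in reversed(inspect.getmro(derived_cls)):
        for name, _ in inspect.getmembers(cls, inspect.isfunction):
            if name == member_fn:
                return cls
    raise ValueError(f"{member_fn!r} not defined on {derived_cls.__name__}")


class Base:
    def greet(self):
        return "hi"


class Child(Base):
    pass


# greet is first defined on Base, even when queried through Child
assert get_defining_class("greet", Child) is Base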
docs/static/deprecated-llama-stack-spec.yaml (10706 changes; diff suppressed because it is too large)
docs/static/experimental-llama-stack-spec.yaml (2164 changes; diff suppressed because it is too large)
docs/static/llama-stack-spec.html (14048 changes; diff suppressed because it is too large)

docs/static/llama-stack-spec.yaml (796 changes)
@@ -974,11 +974,11 @@ paths:
     get:
       responses:
         '200':
-          description: A ListModelsResponse.
+          description: A OpenAIListModelsResponse.
           content:
             application/json:
               schema:
-                $ref: '#/components/schemas/ListModelsResponse'
+                $ref: '#/components/schemas/OpenAIListModelsResponse'
         '400':
           $ref: '#/components/responses/BadRequest400'
         '429':
@@ -991,8 +991,8 @@ paths:
           $ref: '#/components/responses/DefaultError'
       tags:
         - Models
-      summary: List all models.
-      description: List all models.
+      summary: List models using the OpenAI API.
+      description: List models using the OpenAI API.
       parameters: []
       deprecated: false
     post:
@@ -1982,40 +1982,6 @@ paths:
           schema:
             type: string
       deprecated: false
-  /v1/synthetic-data-generation/generate:
-    post:
-      responses:
-        '200':
-          description: >-
-            Response containing filtered synthetic data samples and optional statistics
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/SyntheticDataGenerationResponse'
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - SyntheticDataGeneration (Coming Soon)
-      summary: >-
-        Generate synthetic data based on input dialogs and apply filtering.
-      description: >-
-        Generate synthetic data based on input dialogs and apply filtering.
-      parameters: []
-      requestBody:
-        content:
-          application/json:
-            schema:
-              $ref: '#/components/schemas/SyntheticDataGenerateRequest'
-        required: true
-      deprecated: false
   /v1/tool-runtime/invoke:
     post:
       responses:
@@ -2086,69 +2052,6 @@ paths:
           schema:
             $ref: '#/components/schemas/URL'
       deprecated: false
-  /v1/tool-runtime/rag-tool/insert:
-    post:
-      responses:
-        '200':
-          description: OK
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - ToolRuntime
-      summary: >-
-        Index documents so they can be used by the RAG system.
-      description: >-
-        Index documents so they can be used by the RAG system.
-      parameters: []
-      requestBody:
-        content:
-          application/json:
-            schema:
-              $ref: '#/components/schemas/InsertRequest'
-        required: true
-      deprecated: false
-  /v1/tool-runtime/rag-tool/query:
-    post:
-      responses:
-        '200':
-          description: >-
-            RAGQueryResult containing the retrieved content and metadata
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/RAGQueryResult'
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - ToolRuntime
-      summary: >-
-        Query the RAG system for context; typically invoked by the agent.
-      description: >-
-        Query the RAG system for context; typically invoked by the agent.
-      parameters: []
-      requestBody:
-        content:
-          application/json:
-            schema:
-              $ref: '#/components/schemas/QueryRequest'
-        required: true
-      deprecated: false
   /v1/toolgroups:
     get:
       responses:
@@ -5619,6 +5522,88 @@ components:
       title: ListRoutesResponse
       description: >-
         Response containing a list of all available API routes.
+    OpenAIModel:
+      type: object
+      properties:
+        id:
+          type: string
+        object:
+          type: string
+          const: model
+          default: model
+        created:
+          type: integer
+        owned_by:
+          type: string
+        custom_metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+      additionalProperties: false
+      required:
+        - id
+        - object
+        - created
+        - owned_by
+      title: OpenAIModel
+      description: A model from OpenAI.
+    OpenAIListModelsResponse:
+      type: object
+      properties:
+        data:
+          type: array
+          items:
+            $ref: '#/components/schemas/OpenAIModel'
+      additionalProperties: false
+      required:
+        - data
+      title: OpenAIListModelsResponse
+    ModelType:
+      type: string
+      enum:
+        - llm
+        - embedding
+        - rerank
+      title: ModelType
+      description: >-
+        Enumeration of supported model types in Llama Stack.
+    RegisterModelRequest:
+      type: object
+      properties:
+        model_id:
+          type: string
+          description: The identifier of the model to register.
+        provider_model_id:
+          type: string
+          description: >-
+            The identifier of the model in the provider.
+        provider_id:
+          type: string
+          description: The identifier of the provider.
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+          description: Any additional metadata for this model.
+        model_type:
+          $ref: '#/components/schemas/ModelType'
+          description: The type of model to register.
+      additionalProperties: false
+      required:
+        - model_id
+      title: RegisterModelRequest
     Model:
       type: object
       properties:
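To make the new schemas concrete, an illustrative `get /v1/models` payload shaped by the `OpenAIListModelsResponse` schema above; all field values are invented for the example.

# Illustrative payload only; values are made up, field shapes follow the schema.
example_list_models_response = {
    "data": [
        {
            "id": "ollama/llama3.2:3b",  # OpenAIModel.id (required)
            "object": "model",            # const: model
            "created": 1733270400,        # integer timestamp
            "owned_by": "llama_stack",
            "custom_metadata": {           # Llama Stack extras live here
                "model_type": "llm",       # one of ModelType: llm / embedding / rerank
                "provider_id": "ollama",
            },
        }
    ]
}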
@@ -5676,57 +5661,6 @@ components:
       title: Model
       description: >-
         A model resource representing an AI model registered in Llama Stack.
-    ModelType:
-      type: string
-      enum:
-        - llm
-        - embedding
-        - rerank
-      title: ModelType
-      description: >-
-        Enumeration of supported model types in Llama Stack.
-    ListModelsResponse:
-      type: object
-      properties:
-        data:
-          type: array
-          items:
-            $ref: '#/components/schemas/Model'
-      additionalProperties: false
-      required:
-        - data
-      title: ListModelsResponse
-    RegisterModelRequest:
-      type: object
-      properties:
-        model_id:
-          type: string
-          description: The identifier of the model to register.
-        provider_model_id:
-          type: string
-          description: >-
-            The identifier of the model in the provider.
-        provider_id:
-          type: string
-          description: The identifier of the provider.
-        metadata:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: 'null'
-              - type: boolean
-              - type: number
-              - type: string
-              - type: array
-              - type: object
-          description: Any additional metadata for this model.
-        model_type:
-          $ref: '#/components/schemas/ModelType'
-          description: The type of model to register.
-      additionalProperties: false
-      required:
-        - model_id
-      title: RegisterModelRequest
     RunModerationRequest:
       type: object
       properties:
@@ -8144,20 +8078,6 @@ components:
         - error
       title: ViolationLevel
       description: Severity level of a safety violation.
-    AgentTurnInputType:
-      type: object
-      properties:
-        type:
-          type: string
-          const: agent_turn_input
-          default: agent_turn_input
-          description: >-
-            Discriminator type. Always "agent_turn_input"
-      additionalProperties: false
-      required:
-        - type
-      title: AgentTurnInputType
-      description: Parameter type for agent turn input.
     AggregationFunctionType:
       type: string
       enum:
@@ -8400,7 +8320,6 @@ components:
         - $ref: '#/components/schemas/UnionType'
         - $ref: '#/components/schemas/ChatCompletionInputType'
         - $ref: '#/components/schemas/CompletionInputType'
-        - $ref: '#/components/schemas/AgentTurnInputType'
       discriminator:
         propertyName: type
         mapping:
@@ -8413,7 +8332,6 @@ components:
           union: '#/components/schemas/UnionType'
           chat_completion_input: '#/components/schemas/ChatCompletionInputType'
           completion_input: '#/components/schemas/CompletionInputType'
-          agent_turn_input: '#/components/schemas/AgentTurnInputType'
       params:
         $ref: '#/components/schemas/ScoringFnParams'
       additionalProperties: false
@@ -8494,7 +8412,6 @@ components:
         - $ref: '#/components/schemas/UnionType'
         - $ref: '#/components/schemas/ChatCompletionInputType'
         - $ref: '#/components/schemas/CompletionInputType'
-        - $ref: '#/components/schemas/AgentTurnInputType'
       discriminator:
         propertyName: type
         mapping:
@@ -8507,7 +8424,6 @@ components:
           union: '#/components/schemas/UnionType'
          chat_completion_input: '#/components/schemas/ChatCompletionInputType'
          completion_input: '#/components/schemas/CompletionInputType'
-          agent_turn_input: '#/components/schemas/AgentTurnInputType'
     RegisterScoringFunctionRequest:
       type: object
       properties:
@@ -8744,45 +8660,29 @@ components:
       required:
         - shield_id
       title: RegisterShieldRequest
-    CompletionMessage:
+    InvokeToolRequest:
       type: object
       properties:
-        role:
+        tool_name:
           type: string
-          const: assistant
-          default: assistant
-          description: >-
-            Must be "assistant" to identify this as the model's response
-        content:
-          $ref: '#/components/schemas/InterleavedContent'
-          description: The content of the model's response
-        stop_reason:
-          type: string
-          enum:
-            - end_of_turn
-            - end_of_message
-            - out_of_tokens
-          description: >-
-            Reason why the model stopped generating. Options are: - `StopReason.end_of_turn`:
-            The model finished generating the entire response. - `StopReason.end_of_message`:
-            The model finished generating but generated a partial response -- usually,
-            a tool call. The user may call the tool and continue the conversation
-            with the tool's response. - `StopReason.out_of_tokens`: The model ran
-            out of token budget.
-        tool_calls:
-          type: array
-          items:
-            $ref: '#/components/schemas/ToolCall'
-          description: >-
-            List of tool calls. Each tool call is a ToolCall object.
+          description: The name of the tool to invoke.
+        kwargs:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+          description: >-
+            A dictionary of arguments to pass to the tool.
       additionalProperties: false
       required:
-        - role
-        - content
-        - stop_reason
-      title: CompletionMessage
-      description: >-
-        A message containing the model's (assistant) response in a chat conversation.
+        - tool_name
+        - kwargs
+      title: InvokeToolRequest
     ImageContentItem:
       type: object
       properties:
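An illustrative body for `post /v1/tool-runtime/invoke` under the relocated `InvokeToolRequest` schema above; the tool name and arguments are invented for the example.

# Illustrative request body; only `tool_name` and `kwargs` are required by the schema.
invoke_tool_request = {
    "tool_name": "web_search",  # hypothetical tool, required string
    "kwargs": {                  # required dict of arguments passed to the tool
        "query": "llama stack mcp authentication",
    },
}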
@@ -8829,41 +8729,6 @@ components:
         mapping:
           image: '#/components/schemas/ImageContentItem'
           text: '#/components/schemas/TextContentItem'
-    Message:
-      oneOf:
-        - $ref: '#/components/schemas/UserMessage'
-        - $ref: '#/components/schemas/SystemMessage'
-        - $ref: '#/components/schemas/ToolResponseMessage'
-        - $ref: '#/components/schemas/CompletionMessage'
-      discriminator:
-        propertyName: role
-        mapping:
-          user: '#/components/schemas/UserMessage'
-          system: '#/components/schemas/SystemMessage'
-          tool: '#/components/schemas/ToolResponseMessage'
-          assistant: '#/components/schemas/CompletionMessage'
-    SystemMessage:
-      type: object
-      properties:
-        role:
-          type: string
-          const: system
-          default: system
-          description: >-
-            Must be "system" to identify this as a system message
-        content:
-          $ref: '#/components/schemas/InterleavedContent'
-          description: >-
-            The content of the "system prompt". If multiple system messages are provided,
-            they are concatenated. The underlying Llama Stack code may also add other
-            system messages (for example, for formatting tool definitions).
-      additionalProperties: false
-      required:
-        - role
-        - content
-      title: SystemMessage
-      description: >-
-        A system message providing instructions or context to the model.
     TextContentItem:
       type: object
       properties:
@@ -8882,179 +8747,6 @@ components:
         - text
       title: TextContentItem
       description: A text content item
-    ToolCall:
-      type: object
-      properties:
-        call_id:
-          type: string
-        tool_name:
-          oneOf:
-            - type: string
-              enum:
-                - brave_search
-                - wolfram_alpha
-                - photogen
-                - code_interpreter
-              title: BuiltinTool
-            - type: string
-        arguments:
-          type: string
-      additionalProperties: false
-      required:
-        - call_id
-        - tool_name
-        - arguments
-      title: ToolCall
-    ToolResponseMessage:
-      type: object
-      properties:
-        role:
-          type: string
-          const: tool
-          default: tool
-          description: >-
-            Must be "tool" to identify this as a tool response
-        call_id:
-          type: string
-          description: >-
-            Unique identifier for the tool call this response is for
-        content:
-          $ref: '#/components/schemas/InterleavedContent'
-          description: The response content from the tool
-      additionalProperties: false
-      required:
-        - role
-        - call_id
-        - content
-      title: ToolResponseMessage
-      description: >-
-        A message representing the result of a tool invocation.
-    URL:
-      type: object
-      properties:
-        uri:
-          type: string
-          description: The URL string pointing to the resource
-      additionalProperties: false
-      required:
-        - uri
-      title: URL
-      description: A URL reference to external content.
-    UserMessage:
-      type: object
-      properties:
-        role:
-          type: string
-          const: user
-          default: user
-          description: >-
-            Must be "user" to identify this as a user message
-        content:
-          $ref: '#/components/schemas/InterleavedContent'
-          description: >-
-            The content of the message, which can include text and other media
-        context:
-          $ref: '#/components/schemas/InterleavedContent'
-          description: >-
-            (Optional) This field is used internally by Llama Stack to pass RAG context.
-            This field may be removed in the API in the future.
-      additionalProperties: false
-      required:
-        - role
-        - content
-      title: UserMessage
-      description: >-
-        A message from the user in a chat conversation.
-    SyntheticDataGenerateRequest:
-      type: object
-      properties:
-        dialogs:
-          type: array
-          items:
-            $ref: '#/components/schemas/Message'
-          description: >-
-            List of conversation messages to use as input for synthetic data generation
-        filtering_function:
-          type: string
-          enum:
-            - none
-            - random
-            - top_k
-            - top_p
-            - top_k_top_p
-            - sigmoid
-          description: >-
-            Type of filtering to apply to generated synthetic data samples
-        model:
-          type: string
-          description: >-
-            (Optional) The identifier of the model to use. The model must be registered
-            with Llama Stack and available via the /models endpoint
-      additionalProperties: false
-      required:
-        - dialogs
-        - filtering_function
-      title: SyntheticDataGenerateRequest
-    SyntheticDataGenerationResponse:
-      type: object
-      properties:
-        synthetic_data:
-          type: array
-          items:
-            type: object
-            additionalProperties:
-              oneOf:
-                - type: 'null'
-                - type: boolean
-                - type: number
-                - type: string
-                - type: array
-                - type: object
-          description: >-
-            List of generated synthetic data samples that passed the filtering criteria
-        statistics:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: 'null'
-              - type: boolean
-              - type: number
-              - type: string
-              - type: array
-              - type: object
-          description: >-
-            (Optional) Statistical information about the generation process and filtering
-            results
-      additionalProperties: false
-      required:
-        - synthetic_data
-      title: SyntheticDataGenerationResponse
-      description: >-
-        Response from the synthetic data generation. Batch of (prompt, response, score)
-        tuples that pass the threshold.
-    InvokeToolRequest:
-      type: object
-      properties:
-        tool_name:
-          type: string
-          description: The name of the tool to invoke.
-        kwargs:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: 'null'
-              - type: boolean
-              - type: number
-              - type: string
-              - type: array
-              - type: object
-          description: >-
-            A dictionary of arguments to pass to the tool.
-      additionalProperties: false
-      required:
-        - tool_name
-        - kwargs
-      title: InvokeToolRequest
     ToolInvocationResult:
       type: object
       properties:
@@ -9085,6 +8777,17 @@ components:
       additionalProperties: false
       title: ToolInvocationResult
       description: Result of a tool invocation.
+    URL:
+      type: object
+      properties:
+        uri:
+          type: string
+          description: The URL string pointing to the resource
+      additionalProperties: false
+      required:
+        - uri
+      title: URL
+      description: A URL reference to external content.
     ToolDef:
       type: object
       properties:
@@ -9155,274 +8858,6 @@ components:
       title: ListToolDefsResponse
       description: >-
         Response containing a list of tool definitions.
-    RAGDocument:
-      type: object
-      properties:
-        document_id:
-          type: string
-          description: The unique identifier for the document.
-        content:
-          oneOf:
-            - type: string
-            - $ref: '#/components/schemas/InterleavedContentItem'
-            - type: array
-              items:
-                $ref: '#/components/schemas/InterleavedContentItem'
-            - $ref: '#/components/schemas/URL'
-          description: The content of the document.
-        mime_type:
-          type: string
-          description: The MIME type of the document.
-        metadata:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: 'null'
-              - type: boolean
-              - type: number
-              - type: string
-              - type: array
-              - type: object
-          description: Additional metadata for the document.
-      additionalProperties: false
-      required:
-        - document_id
-        - content
-        - metadata
-      title: RAGDocument
-      description: >-
-        A document to be used for document ingestion in the RAG Tool.
-    InsertRequest:
-      type: object
-      properties:
-        documents:
-          type: array
-          items:
-            $ref: '#/components/schemas/RAGDocument'
-          description: >-
-            List of documents to index in the RAG system
-        vector_store_id:
-          type: string
-          description: >-
-            ID of the vector database to store the document embeddings
-        chunk_size_in_tokens:
-          type: integer
-          description: >-
-            (Optional) Size in tokens for document chunking during indexing
-      additionalProperties: false
-      required:
-        - documents
-        - vector_store_id
-        - chunk_size_in_tokens
-      title: InsertRequest
-    DefaultRAGQueryGeneratorConfig:
-      type: object
-      properties:
-        type:
-          type: string
-          const: default
-          default: default
-          description: >-
-            Type of query generator, always 'default'
-        separator:
-          type: string
-          default: ' '
-          description: >-
-            String separator used to join query terms
-      additionalProperties: false
-      required:
-        - type
-        - separator
-      title: DefaultRAGQueryGeneratorConfig
-      description: >-
-        Configuration for the default RAG query generator.
-    LLMRAGQueryGeneratorConfig:
-      type: object
-      properties:
-        type:
-          type: string
-          const: llm
-          default: llm
-          description: Type of query generator, always 'llm'
-        model:
-          type: string
-          description: >-
-            Name of the language model to use for query generation
-        template:
-          type: string
-          description: >-
-            Template string for formatting the query generation prompt
-      additionalProperties: false
-      required:
-        - type
-        - model
-        - template
-      title: LLMRAGQueryGeneratorConfig
-      description: >-
-        Configuration for the LLM-based RAG query generator.
-    RAGQueryConfig:
-      type: object
-      properties:
-        query_generator_config:
-          oneOf:
-            - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
-            - $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
-          discriminator:
-            propertyName: type
-            mapping:
-              default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
-              llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
-          description: Configuration for the query generator.
-        max_tokens_in_context:
-          type: integer
-          default: 4096
-          description: Maximum number of tokens in the context.
-        max_chunks:
-          type: integer
-          default: 5
-          description: Maximum number of chunks to retrieve.
-        chunk_template:
-          type: string
-          default: >
-            Result {index}
-
-            Content: {chunk.content}
-
-            Metadata: {metadata}
-          description: >-
-            Template for formatting each retrieved chunk in the context. Available
-            placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
|
|
||||||
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
|
|
||||||
{chunk.content}\nMetadata: {metadata}\n"
|
|
||||||
mode:
|
|
||||||
$ref: '#/components/schemas/RAGSearchMode'
|
|
||||||
default: vector
|
|
||||||
description: >-
|
|
||||||
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
|
|
||||||
"vector".
|
|
||||||
ranker:
|
|
||||||
$ref: '#/components/schemas/Ranker'
|
|
||||||
description: >-
|
|
||||||
Configuration for the ranker to use in hybrid search. Defaults to RRF
|
|
||||||
ranker.
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- query_generator_config
|
|
||||||
- max_tokens_in_context
|
|
||||||
- max_chunks
|
|
||||||
- chunk_template
|
|
||||||
title: RAGQueryConfig
|
|
||||||
description: >-
|
|
||||||
Configuration for the RAG query generation.
|
|
||||||
RAGSearchMode:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- vector
|
|
||||||
- keyword
|
|
||||||
- hybrid
|
|
||||||
title: RAGSearchMode
|
|
||||||
description: >-
|
|
||||||
Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
|
|
||||||
for semantic matching - KEYWORD: Uses keyword-based search for exact matching
|
|
||||||
- HYBRID: Combines both vector and keyword search for better results
|
|
||||||
RRFRanker:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
type:
|
|
||||||
type: string
|
|
||||||
const: rrf
|
|
||||||
default: rrf
|
|
||||||
description: The type of ranker, always "rrf"
|
|
||||||
impact_factor:
|
|
||||||
type: number
|
|
||||||
default: 60.0
|
|
||||||
description: >-
|
|
||||||
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
|
|
||||||
results. Must be greater than 0
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
- impact_factor
|
|
||||||
title: RRFRanker
|
|
||||||
description: >-
|
|
||||||
Reciprocal Rank Fusion (RRF) ranker configuration.
|
|
||||||
Ranker:
|
|
||||||
oneOf:
|
|
||||||
- $ref: '#/components/schemas/RRFRanker'
|
|
||||||
- $ref: '#/components/schemas/WeightedRanker'
|
|
||||||
discriminator:
|
|
||||||
propertyName: type
|
|
||||||
mapping:
|
|
||||||
rrf: '#/components/schemas/RRFRanker'
|
|
||||||
weighted: '#/components/schemas/WeightedRanker'
|
|
||||||
WeightedRanker:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
type:
|
|
||||||
type: string
|
|
||||||
const: weighted
|
|
||||||
default: weighted
|
|
||||||
description: The type of ranker, always "weighted"
|
|
||||||
alpha:
|
|
||||||
type: number
|
|
||||||
default: 0.5
|
|
||||||
description: >-
|
|
||||||
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
|
|
||||||
only use vector scores, values in between blend both scores.
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
- alpha
|
|
||||||
title: WeightedRanker
|
|
||||||
description: >-
|
|
||||||
Weighted ranker configuration that combines vector and keyword scores.
|
|
||||||
QueryRequest:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
content:
|
|
||||||
$ref: '#/components/schemas/InterleavedContent'
|
|
||||||
description: >-
|
|
||||||
The query content to search for in the indexed documents
|
|
||||||
vector_store_ids:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
type: string
|
|
||||||
description: >-
|
|
||||||
List of vector database IDs to search within
|
|
||||||
query_config:
|
|
||||||
$ref: '#/components/schemas/RAGQueryConfig'
|
|
||||||
description: >-
|
|
||||||
(Optional) Configuration parameters for the query operation
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- content
|
|
||||||
- vector_store_ids
|
|
||||||
title: QueryRequest
|
|
||||||
RAGQueryResult:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
content:
|
|
||||||
$ref: '#/components/schemas/InterleavedContent'
|
|
||||||
description: >-
|
|
||||||
(Optional) The retrieved content from the query
|
|
||||||
metadata:
|
|
||||||
type: object
|
|
||||||
additionalProperties:
|
|
||||||
oneOf:
|
|
||||||
- type: 'null'
|
|
||||||
- type: boolean
|
|
||||||
- type: number
|
|
||||||
- type: string
|
|
||||||
- type: array
|
|
||||||
- type: object
|
|
||||||
description: >-
|
|
||||||
Additional metadata about the query result
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- metadata
|
|
||||||
title: RAGQueryResult
|
|
||||||
description: >-
|
|
||||||
Result of a RAG query containing retrieved content and metadata.
|
|
||||||
ToolGroup:
|
ToolGroup:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
@ -10686,8 +10121,6 @@ tags:
|
||||||
description: ''
|
description: ''
|
||||||
- name: Shields
|
- name: Shields
|
||||||
description: ''
|
description: ''
|
||||||
- name: SyntheticDataGeneration (Coming Soon)
|
|
||||||
description: ''
|
|
||||||
- name: ToolGroups
|
- name: ToolGroups
|
||||||
description: ''
|
description: ''
|
||||||
- name: ToolRuntime
|
- name: ToolRuntime
|
||||||
|
|
@ -10710,7 +10143,6 @@ x-tagGroups:
|
||||||
- Scoring
|
- Scoring
|
||||||
- ScoringFunctions
|
- ScoringFunctions
|
||||||
- Shields
|
- Shields
|
||||||
- SyntheticDataGeneration (Coming Soon)
|
|
||||||
- ToolGroups
|
- ToolGroups
|
||||||
- ToolRuntime
|
- ToolRuntime
|
||||||
- VectorIO
|
- VectorIO
|
||||||
|
|
|
||||||
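The removed WeightedRanker schema above blends keyword and vector retrieval scores with a single alpha (0 means keyword only, 1 means vector only). As a quick illustration of that blending rule as the schema describes it, here is a minimal Python sketch; the function name is ours and this is not the provider's actual implementation:

    def weighted_score(vector_score: float, keyword_score: float, alpha: float = 0.5) -> float:
        # alpha = 0 -> keyword scores only, alpha = 1 -> vector scores only,
        # values in between blend both, per the WeightedRanker description.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("alpha must be between 0 and 1")
        return alpha * vector_score + (1.0 - alpha) * keyword_score

    # Example: weighted_score(0.9, 0.4, alpha=0.5) returns 0.65.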
2579 docs/static/stainless-llama-stack-spec.yaml vendored
File diff suppressed because it is too large
@@ -186,11 +186,35 @@ if ! command -v pytest &>/dev/null; then
     exit 1
 fi

+# Helper function to find next available port
+find_available_port() {
+    local start_port=$1
+    local port=$start_port
+    for ((i=0; i<100; i++)); do
+        if ! lsof -Pi :$port -sTCP:LISTEN -t >/dev/null 2>&1; then
+            echo $port
+            return 0
+        fi
+        ((port++))
+    done
+    echo "Failed to find available port starting from $start_port" >&2
+    return 1
+}
+
 # Start Llama Stack Server if needed
 if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
+    # Find an available port for the server
+    LLAMA_STACK_PORT=$(find_available_port 8321)
+    if [[ $? -ne 0 ]]; then
+        echo "Error: $LLAMA_STACK_PORT"
+        exit 1
+    fi
+    export LLAMA_STACK_PORT
+    echo "Will use port: $LLAMA_STACK_PORT"
+
     stop_server() {
         echo "Stopping Llama Stack Server..."
-        pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
+        pids=$(lsof -i :$LLAMA_STACK_PORT | awk 'NR>1 {print $2}')
         if [[ -n "$pids" ]]; then
             echo "Killing Llama Stack Server processes: $pids"
             kill -9 $pids
@@ -200,10 +224,6 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
         echo "Llama Stack Server stopped"
     }

-    # check if server is already running
-    if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
-        echo "Llama Stack Server is already running, skipping start"
-    else
     echo "=== Starting Llama Stack Server ==="
     export LLAMA_STACK_LOG_WIDTH=120

@@ -220,9 +240,9 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
     stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
     nohup llama stack run $stack_config >server.log 2>&1 &

-    echo "Waiting for Llama Stack Server to start..."
+    echo "Waiting for Llama Stack Server to start on port $LLAMA_STACK_PORT..."
     for i in {1..30}; do
-        if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
+        if curl -s http://localhost:$LLAMA_STACK_PORT/v1/health 2>/dev/null | grep -q "OK"; then
             echo "✅ Llama Stack Server started successfully"
             break
         fi
@@ -235,7 +255,6 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
         sleep 1
     done
     echo ""
-    fi

     trap stop_server EXIT ERR INT TERM
 fi
@@ -259,7 +278,14 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then

     # Extract distribution name from docker:distro format
     DISTRO=$(echo "$STACK_CONFIG" | sed 's/^docker://')
-    export LLAMA_STACK_PORT=8321
+    # Find an available port for the docker container
+    LLAMA_STACK_PORT=$(find_available_port 8321)
+    if [[ $? -ne 0 ]]; then
+        echo "Error: $LLAMA_STACK_PORT"
+        exit 1
+    fi
+    export LLAMA_STACK_PORT
+    echo "Will use port: $LLAMA_STACK_PORT"

     echo "=== Building Docker Image for distribution: $DISTRO ==="
     containerfile="$ROOT_DIR/containers/Containerfile"
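For reference, the sequential port probe added to the test script above can be expressed in a few lines of Python as well. This is a minimal sketch of the same idea, assuming a local bind test in place of the script's lsof check; the function name simply mirrors the shell helper:

    import socket

    def find_available_port(start_port: int = 8321, attempts: int = 100) -> int:
        # Scan sequentially from start_port, like the shell helper's loop.
        for port in range(start_port, start_port + attempts):
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                try:
                    s.bind(("127.0.0.1", port))
                except OSError:
                    continue  # port in use, try the next one
                return port
        raise RuntimeError(f"Failed to find available port starting from {start_port}")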
@@ -6,7 +6,7 @@
 # the root directory of this source tree.

 set -e
-cd src/llama_stack/ui
+cd src/llama_stack_ui

 if [ ! -d node_modules ] || [ ! -x node_modules/.bin/prettier ] || [ ! -x node_modules/.bin/eslint ]; then
     echo "UI dependencies not installed, skipping prettier/linter check"
@@ -5,30 +5,13 @@
 # the root directory of this source tree.

 from collections.abc import AsyncIterator
-from datetime import datetime
-from enum import StrEnum
-from typing import Annotated, Any, Literal, Protocol, runtime_checkable
+from typing import Annotated, Protocol, runtime_checkable

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel

-from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
-from llama_stack.apis.common.responses import Order, PaginatedResponse
-from llama_stack.apis.inference import (
-    CompletionMessage,
-    ResponseFormat,
-    SamplingParams,
-    ToolCall,
-    ToolChoice,
-    ToolConfig,
-    ToolPromptFormat,
-    ToolResponse,
-    ToolResponseMessage,
-    UserMessage,
-)
-from llama_stack.apis.safety import SafetyViolation
-from llama_stack.apis.tools import ToolDef
-from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
-from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
+from llama_stack.apis.common.responses import Order
+from llama_stack.apis.version import LLAMA_STACK_API_V1
+from llama_stack.schema_utils import ExtraBodyField, json_schema_type, webmethod

 from .openai_responses import (
     ListOpenAIResponseInputItem,
@@ -57,729 +40,12 @@ class ResponseGuardrailSpec(BaseModel):
 ResponseGuardrail = str | ResponseGuardrailSpec


-class Attachment(BaseModel):
-    """An attachment to an agent turn.
-
-    :param content: The content of the attachment.
-    :param mime_type: The MIME type of the attachment.
-    """
-
-    content: InterleavedContent | URL
-    mime_type: str
-
-
-class Document(BaseModel):
-    """A document to be used by an agent.
-
-    :param content: The content of the document.
-    :param mime_type: The MIME type of the document.
-    """
-
-    content: InterleavedContent | URL
-    mime_type: str
-
-
-class StepCommon(BaseModel):
-    """A common step in an agent turn.
-
-    :param turn_id: The ID of the turn.
-    :param step_id: The ID of the step.
-    :param started_at: The time the step started.
-    :param completed_at: The time the step completed.
-    """
-
-    turn_id: str
-    step_id: str
-    started_at: datetime | None = None
-    completed_at: datetime | None = None
-
-
-class StepType(StrEnum):
-    """Type of the step in an agent turn.
-
-    :cvar inference: The step is an inference step that calls an LLM.
-    :cvar tool_execution: The step is a tool execution step that executes a tool call.
-    :cvar shield_call: The step is a shield call step that checks for safety violations.
-    :cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
-    """
-
-    inference = "inference"
-    tool_execution = "tool_execution"
-    shield_call = "shield_call"
-    memory_retrieval = "memory_retrieval"
-
-
-@json_schema_type
-class InferenceStep(StepCommon):
-    """An inference step in an agent turn.
-
-    :param model_response: The response from the LLM.
-    """
-
-    model_config = ConfigDict(protected_namespaces=())
-
-    step_type: Literal[StepType.inference] = StepType.inference
-    model_response: CompletionMessage
-
-
-@json_schema_type
-class ToolExecutionStep(StepCommon):
-    """A tool execution step in an agent turn.
-
-    :param tool_calls: The tool calls to execute.
-    :param tool_responses: The tool responses from the tool calls.
-    """
-
-    step_type: Literal[StepType.tool_execution] = StepType.tool_execution
-    tool_calls: list[ToolCall]
-    tool_responses: list[ToolResponse]
-
-
-@json_schema_type
-class ShieldCallStep(StepCommon):
-    """A shield call step in an agent turn.
-
-    :param violation: The violation from the shield call.
-    """
-
-    step_type: Literal[StepType.shield_call] = StepType.shield_call
-    violation: SafetyViolation | None
-
-
-@json_schema_type
-class MemoryRetrievalStep(StepCommon):
-    """A memory retrieval step in an agent turn.
-
-    :param vector_store_ids: The IDs of the vector databases to retrieve context from.
-    :param inserted_context: The context retrieved from the vector databases.
-    """
-
-    step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
-    # TODO: should this be List[str]?
-    vector_store_ids: str
-    inserted_context: InterleavedContent
-
-
-Step = Annotated[
-    InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
-    Field(discriminator="step_type"),
-]
-
-
-@json_schema_type
-class Turn(BaseModel):
-    """A single turn in an interaction with an Agentic System.
-
-    :param turn_id: Unique identifier for the turn within a session
-    :param session_id: Unique identifier for the conversation session
-    :param input_messages: List of messages that initiated this turn
-    :param steps: Ordered list of processing steps executed during this turn
-    :param output_message: The model's generated response containing content and metadata
-    :param output_attachments: (Optional) Files or media attached to the agent's response
-    :param started_at: Timestamp when the turn began
-    :param completed_at: (Optional) Timestamp when the turn finished, if completed
-    """
-
-    turn_id: str
-    session_id: str
-    input_messages: list[UserMessage | ToolResponseMessage]
-    steps: list[Step]
-    output_message: CompletionMessage
-    output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
-
-    started_at: datetime
-    completed_at: datetime | None = None
-
-
-@json_schema_type
-class Session(BaseModel):
-    """A single session of an interaction with an Agentic System.
-
-    :param session_id: Unique identifier for the conversation session
-    :param session_name: Human-readable name for the session
-    :param turns: List of all turns that have occurred in this session
-    :param started_at: Timestamp when the session was created
-    """
-
-    session_id: str
-    session_name: str
-    turns: list[Turn]
-    started_at: datetime
-
-
-class AgentToolGroupWithArgs(BaseModel):
-    name: str
-    args: dict[str, Any]
-
-
-AgentToolGroup = str | AgentToolGroupWithArgs
-register_schema(AgentToolGroup, name="AgentTool")
-
-
-class AgentConfigCommon(BaseModel):
-    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
-
-    input_shields: list[str] | None = Field(default_factory=lambda: [])
-    output_shields: list[str] | None = Field(default_factory=lambda: [])
-    toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
-    client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
-    tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
-    tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
-    tool_config: ToolConfig | None = Field(default=None)
-
-    max_infer_iters: int | None = 10
-
-    def model_post_init(self, __context):
-        if self.tool_config:
-            if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
-                raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
-            if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
-                raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
-        else:
-            params = {}
-            if self.tool_choice:
-                params["tool_choice"] = self.tool_choice
-            if self.tool_prompt_format:
-                params["tool_prompt_format"] = self.tool_prompt_format
-            self.tool_config = ToolConfig(**params)
-
-
-@json_schema_type
-class AgentConfig(AgentConfigCommon):
-    """Configuration for an agent.
-
-    :param model: The model identifier to use for the agent
-    :param instructions: The system instructions for the agent
-    :param name: Optional name for the agent, used in telemetry and identification
-    :param enable_session_persistence: Optional flag indicating whether session data has to be persisted
-    :param response_format: Optional response format configuration
-    """
-
-    model: str
-    instructions: str
-    name: str | None = None
-    enable_session_persistence: bool | None = False
-    response_format: ResponseFormat | None = None
-
-
-@json_schema_type
-class Agent(BaseModel):
-    """An agent instance with configuration and metadata.
-
-    :param agent_id: Unique identifier for the agent
-    :param agent_config: Configuration settings for the agent
-    :param created_at: Timestamp when the agent was created
-    """
-
-    agent_id: str
-    agent_config: AgentConfig
-    created_at: datetime
-
-
-class AgentConfigOverridablePerTurn(AgentConfigCommon):
-    instructions: str | None = None
-
-
-class AgentTurnResponseEventType(StrEnum):
-    step_start = "step_start"
-    step_complete = "step_complete"
-    step_progress = "step_progress"
-
-    turn_start = "turn_start"
-    turn_complete = "turn_complete"
-    turn_awaiting_input = "turn_awaiting_input"
-
-
-@json_schema_type
-class AgentTurnResponseStepStartPayload(BaseModel):
-    """Payload for step start events in agent turn responses.
-
-    :param event_type: Type of event being reported
-    :param step_type: Type of step being executed
-    :param step_id: Unique identifier for the step within a turn
-    :param metadata: (Optional) Additional metadata for the step
-    """
-
-    event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
-    step_type: StepType
-    step_id: str
-    metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
-
-
-@json_schema_type
-class AgentTurnResponseStepCompletePayload(BaseModel):
-    """Payload for step completion events in agent turn responses.
-
-    :param event_type: Type of event being reported
-    :param step_type: Type of step being executed
-    :param step_id: Unique identifier for the step within a turn
-    :param step_details: Complete details of the executed step
-    """
-
-    event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
-    step_type: StepType
-    step_id: str
-    step_details: Step
-
-
-@json_schema_type
-class AgentTurnResponseStepProgressPayload(BaseModel):
-    """Payload for step progress events in agent turn responses.
-
-    :param event_type: Type of event being reported
-    :param step_type: Type of step being executed
-    :param step_id: Unique identifier for the step within a turn
-    :param delta: Incremental content changes during step execution
-    """
-
-    model_config = ConfigDict(protected_namespaces=())
-
-    event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
-    step_type: StepType
-    step_id: str
-
-    delta: ContentDelta
-
-
-@json_schema_type
-class AgentTurnResponseTurnStartPayload(BaseModel):
-    """Payload for turn start events in agent turn responses.
-
-    :param event_type: Type of event being reported
-    :param turn_id: Unique identifier for the turn within a session
-    """
-
-    event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
-    turn_id: str
-
-
-@json_schema_type
-class AgentTurnResponseTurnCompletePayload(BaseModel):
-    """Payload for turn completion events in agent turn responses.
-
-    :param event_type: Type of event being reported
-    :param turn: Complete turn data including all steps and results
-    """
-
-    event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
-    turn: Turn
-
-
-@json_schema_type
-class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
-    """Payload for turn awaiting input events in agent turn responses.
-
-    :param event_type: Type of event being reported
-    :param turn: Turn data when waiting for external tool responses
-    """
-
-    event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
-    turn: Turn
-
-
-AgentTurnResponseEventPayload = Annotated[
-    AgentTurnResponseStepStartPayload
-    | AgentTurnResponseStepProgressPayload
-    | AgentTurnResponseStepCompletePayload
-    | AgentTurnResponseTurnStartPayload
-    | AgentTurnResponseTurnCompletePayload
-    | AgentTurnResponseTurnAwaitingInputPayload,
-    Field(discriminator="event_type"),
-]
-register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
-
-
-@json_schema_type
-class AgentTurnResponseEvent(BaseModel):
-    """An event in an agent turn response stream.
-
-    :param payload: Event-specific payload containing event data
-    """
-
-    payload: AgentTurnResponseEventPayload
-
-
-@json_schema_type
-class AgentCreateResponse(BaseModel):
-    """Response returned when creating a new agent.
-
-    :param agent_id: Unique identifier for the created agent
-    """
-
-    agent_id: str
-
-
-@json_schema_type
-class AgentSessionCreateResponse(BaseModel):
-    """Response returned when creating a new agent session.
-
-    :param session_id: Unique identifier for the created session
-    """
-
-    session_id: str
-
-
-@json_schema_type
-class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
-    """Request to create a new turn for an agent.
-
-    :param agent_id: Unique identifier for the agent
-    :param session_id: Unique identifier for the conversation session
-    :param messages: List of messages to start the turn with
-    :param documents: (Optional) List of documents to provide to the agent
-    :param toolgroups: (Optional) List of tool groups to make available for this turn
-    :param stream: (Optional) Whether to stream the response
-    :param tool_config: (Optional) Tool configuration to override agent defaults
-    """
-
-    agent_id: str
-    session_id: str
-
-    # TODO: figure out how we can simplify this and make why
-    # ToolResponseMessage needs to be here (it is function call
-    # execution from outside the system)
-    messages: list[UserMessage | ToolResponseMessage]
-
-    documents: list[Document] | None = None
-    toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
-
-    stream: bool | None = False
-    tool_config: ToolConfig | None = None
-
-
-@json_schema_type
-class AgentTurnResumeRequest(BaseModel):
-    """Request to resume an agent turn with tool responses.
-
-    :param agent_id: Unique identifier for the agent
-    :param session_id: Unique identifier for the conversation session
-    :param turn_id: Unique identifier for the turn within a session
-    :param tool_responses: List of tool responses to submit to continue the turn
-    :param stream: (Optional) Whether to stream the response
-    """
-
-    agent_id: str
-    session_id: str
-    turn_id: str
-    tool_responses: list[ToolResponse]
-    stream: bool | None = False
-
-
-@json_schema_type
-class AgentTurnResponseStreamChunk(BaseModel):
-    """Streamed agent turn completion response.
-
-    :param event: Individual event in the agent turn response stream
-    """
-
-    event: AgentTurnResponseEvent
-
-
-@json_schema_type
-class AgentStepResponse(BaseModel):
-    """Response containing details of a specific agent step.
-
-    :param step: The complete step data and execution details
-    """
-
-    step: Step
-
-
 @runtime_checkable
 class Agents(Protocol):
     """Agents

     APIs for creating and interacting with agentic systems."""

-    @webmethod(
-        route="/agents",
-        method="POST",
-        descriptive_name="create_agent",
-        deprecated=True,
-        level=LLAMA_STACK_API_V1,
-    )
-    @webmethod(
-        route="/agents",
-        method="POST",
-        descriptive_name="create_agent",
-        level=LLAMA_STACK_API_V1ALPHA,
-    )
-    async def create_agent(
-        self,
-        agent_config: AgentConfig,
-    ) -> AgentCreateResponse:
-        """Create an agent with the given configuration.
-
-        :param agent_config: The configuration for the agent.
-        :returns: An AgentCreateResponse with the agent ID.
-        """
-        ...
-
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}/turn",
-        method="POST",
-        descriptive_name="create_agent_turn",
-        deprecated=True,
-        level=LLAMA_STACK_API_V1,
-    )
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}/turn",
-        method="POST",
-        descriptive_name="create_agent_turn",
-        level=LLAMA_STACK_API_V1ALPHA,
-    )
-    async def create_agent_turn(
-        self,
-        agent_id: str,
-        session_id: str,
-        messages: list[UserMessage | ToolResponseMessage],
-        stream: bool | None = False,
-        documents: list[Document] | None = None,
-        toolgroups: list[AgentToolGroup] | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
-        """Create a new turn for an agent.
-
-        :param agent_id: The ID of the agent to create the turn for.
-        :param session_id: The ID of the session to create the turn for.
-        :param messages: List of messages to start the turn with.
-        :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
-        :param documents: (Optional) List of documents to create the turn with.
-        :param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
-        :param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
-        :returns: If stream=False, returns a Turn object.
-                  If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
-        """
-        ...
-
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
-        method="POST",
-        descriptive_name="resume_agent_turn",
-        deprecated=True,
-        level=LLAMA_STACK_API_V1,
-    )
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
-        method="POST",
-        descriptive_name="resume_agent_turn",
-        level=LLAMA_STACK_API_V1ALPHA,
-    )
-    async def resume_agent_turn(
-        self,
-        agent_id: str,
-        session_id: str,
-        turn_id: str,
-        tool_responses: list[ToolResponse],
-        stream: bool | None = False,
-    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
-        """Resume an agent turn with executed tool call responses.
-
-        When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
-
-        :param agent_id: The ID of the agent to resume.
-        :param session_id: The ID of the session to resume.
-        :param turn_id: The ID of the turn to resume.
-        :param tool_responses: The tool call responses to resume the turn with.
-        :param stream: Whether to stream the response.
-        :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
-        """
-        ...
-
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
-        method="GET",
-        deprecated=True,
-        level=LLAMA_STACK_API_V1,
-    )
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
-        method="GET",
-        level=LLAMA_STACK_API_V1ALPHA,
-    )
-    async def get_agents_turn(
-        self,
-        agent_id: str,
-        session_id: str,
-        turn_id: str,
-    ) -> Turn:
-        """Retrieve an agent turn by its ID.
-
-        :param agent_id: The ID of the agent to get the turn for.
-        :param session_id: The ID of the session to get the turn for.
-        :param turn_id: The ID of the turn to get.
-        :returns: A Turn.
-        """
-        ...
-
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
-        method="GET",
-        deprecated=True,
-        level=LLAMA_STACK_API_V1,
-    )
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
-        method="GET",
-        level=LLAMA_STACK_API_V1ALPHA,
-    )
-    async def get_agents_step(
-        self,
-        agent_id: str,
-        session_id: str,
-        turn_id: str,
-        step_id: str,
-    ) -> AgentStepResponse:
-        """Retrieve an agent step by its ID.
-
-        :param agent_id: The ID of the agent to get the step for.
-        :param session_id: The ID of the session to get the step for.
-        :param turn_id: The ID of the turn to get the step for.
-        :param step_id: The ID of the step to get.
-        :returns: An AgentStepResponse.
-        """
-        ...
-
-    @webmethod(
-        route="/agents/{agent_id}/session",
-        method="POST",
-        descriptive_name="create_agent_session",
-        deprecated=True,
-        level=LLAMA_STACK_API_V1,
-    )
-    @webmethod(
-        route="/agents/{agent_id}/session",
-        method="POST",
-        descriptive_name="create_agent_session",
-        level=LLAMA_STACK_API_V1ALPHA,
-    )
-    async def create_agent_session(
-        self,
-        agent_id: str,
-        session_name: str,
-    ) -> AgentSessionCreateResponse:
-        """Create a new session for an agent.
-
-        :param agent_id: The ID of the agent to create the session for.
-        :param session_name: The name of the session to create.
-        :returns: An AgentSessionCreateResponse.
-        """
-        ...
-
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}",
-        method="GET",
-        deprecated=True,
-        level=LLAMA_STACK_API_V1,
-    )
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}",
-        method="GET",
-        level=LLAMA_STACK_API_V1ALPHA,
-    )
-    async def get_agents_session(
-        self,
-        session_id: str,
-        agent_id: str,
-        turn_ids: list[str] | None = None,
-    ) -> Session:
-        """Retrieve an agent session by its ID.
-
-        :param session_id: The ID of the session to get.
-        :param agent_id: The ID of the agent to get the session for.
-        :param turn_ids: (Optional) List of turn IDs to filter the session by.
-        :returns: A Session.
-        """
-        ...
-
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}",
-        method="DELETE",
-        deprecated=True,
-        level=LLAMA_STACK_API_V1,
-    )
-    @webmethod(
-        route="/agents/{agent_id}/session/{session_id}",
-        method="DELETE",
-        level=LLAMA_STACK_API_V1ALPHA,
-    )
-    async def delete_agents_session(
-        self,
-        session_id: str,
-        agent_id: str,
-    ) -> None:
-        """Delete an agent session by its ID and its associated turns.
-
-        :param session_id: The ID of the session to delete.
-        :param agent_id: The ID of the agent to delete the session for.
-        """
-        ...
-
-    @webmethod(
-        route="/agents/{agent_id}",
-        method="DELETE",
-        deprecated=True,
-        level=LLAMA_STACK_API_V1,
-    )
-    @webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
-    async def delete_agent(
-        self,
-        agent_id: str,
-    ) -> None:
-        """Delete an agent by its ID and its associated sessions and turns.
-
-        :param agent_id: The ID of the agent to delete.
-        """
-        ...
-
-    @webmethod(route="/agents", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
-    @webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
-    async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
-        """List all agents.
-
-        :param start_index: The index to start the pagination from.
-        :param limit: The number of agents to return.
-        :returns: A PaginatedResponse.
-        """
-        ...
-
-    @webmethod(
-        route="/agents/{agent_id}",
-        method="GET",
-        deprecated=True,
-        level=LLAMA_STACK_API_V1,
-    )
-    @webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
-    async def get_agent(self, agent_id: str) -> Agent:
-        """Describe an agent by its ID.
-
-        :param agent_id: ID of the agent.
-        :returns: An Agent of the agent.
-        """
-        ...
-
-    @webmethod(
-        route="/agents/{agent_id}/sessions",
-        method="GET",
-        deprecated=True,
-        level=LLAMA_STACK_API_V1,
-    )
-    @webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
-    async def list_agent_sessions(
-        self,
-        agent_id: str,
-        start_index: int | None = None,
-        limit: int | None = None,
-    ) -> PaginatedResponse:
-        """List all session(s) of a given agent.
-
-        :param agent_id: The ID of the agent to list sessions for.
-        :param start_index: The index to start the pagination from.
-        :param limit: The number of sessions to return.
-        :returns: A PaginatedResponse.
-        """
-        ...
-
     # We situate the OpenAI Responses API in the Agents API just like we did things
     # for Inference. The Responses API, in its intent, serves the same purpose as
     # the Agents API above -- it is essentially a lightweight "agentic loop" with
@@ -787,12 +53,6 @@ class Agents(Protocol):
     #
     # Both of these APIs are inherently stateful.

-    @webmethod(
-        route="/openai/v1/responses/{response_id}",
-        method="GET",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
     async def get_openai_response(
         self,
@@ -805,7 +65,6 @@ class Agents(Protocol):
         """
         ...

-    @webmethod(route="/openai/v1/responses", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
     async def create_openai_response(
         self,
@@ -842,7 +101,6 @@ class Agents(Protocol):
         """
         ...

-    @webmethod(route="/openai/v1/responses", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
     async def list_openai_responses(
         self,
@@ -861,9 +119,6 @@ class Agents(Protocol):
         """
         ...

-    @webmethod(
-        route="/openai/v1/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
-    )
     @webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
     async def list_openai_response_input_items(
         self,
@@ -886,7 +141,6 @@ class Agents(Protocol):
         """
         ...

-    @webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
     async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
         """Delete a response.
@@ -43,7 +43,6 @@ class Batches(Protocol):
     Note: This API is currently under active development and may undergo changes.
     """

-    @webmethod(route="/openai/v1/batches", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
     async def create_batch(
         self,
@@ -64,7 +63,6 @@ class Batches(Protocol):
         """
         ...

-    @webmethod(route="/openai/v1/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
     async def retrieve_batch(self, batch_id: str) -> BatchObject:
         """Retrieve information about a specific batch.
@@ -74,7 +72,6 @@ class Batches(Protocol):
         """
         ...

-    @webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
     async def cancel_batch(self, batch_id: str) -> BatchObject:
         """Cancel a batch that is in progress.
@@ -84,7 +81,6 @@ class Batches(Protocol):
         """
         ...

-    @webmethod(route="/openai/v1/batches", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
     async def list_batches(
         self,
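A pattern worth noting across these hunks: while a route migrates, two @webmethod decorators are stacked on the same method, the legacy path carrying deprecated=True and the replacement path left active; the hunks above then drop the legacy decorator once the alias is retired. A minimal sketch of the stacked shape (the Example protocol and its route are illustrative only; the decorator kwargs match those used in the diff):

    from typing import Protocol

    from llama_stack.apis.version import LLAMA_STACK_API_V1
    from llama_stack.schema_utils import webmethod


    class Example(Protocol):
        # Legacy alias registered alongside the current route until removal,
        # as /openai/v1/batches was kept next to /batches.
        @webmethod(route="/openai/v1/example", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
        @webmethod(route="/example", method="GET", level=LLAMA_STACK_API_V1)
        async def get_example(self) -> str: ...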
@@ -8,7 +8,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
 from pydantic import BaseModel, Field

 from llama_stack.apis.resource import Resource, ResourceType
-from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
+from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
 from llama_stack.schema_utils import json_schema_type, webmethod


@@ -54,7 +54,6 @@ class ListBenchmarksResponse(BaseModel):

 @runtime_checkable
 class Benchmarks(Protocol):
-    @webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1ALPHA)
     async def list_benchmarks(self) -> ListBenchmarksResponse:
         """List all benchmarks.
@@ -63,7 +62,6 @@ class Benchmarks(Protocol):
         """
         ...

-    @webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
     async def get_benchmark(
         self,
@@ -76,7 +74,6 @@ class Benchmarks(Protocol):
         """
         ...

-    @webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
     async def register_benchmark(
         self,
@@ -98,7 +95,6 @@ class Benchmarks(Protocol):
         """
         ...

-    @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
     async def unregister_benchmark(self, benchmark_id: str) -> None:
         """Unregister a benchmark.
@@ -56,14 +56,6 @@ class ToolGroupNotFoundError(ResourceNotFoundError):
         super().__init__(toolgroup_name, "Tool Group", "client.toolgroups.list()")


-class SessionNotFoundError(ValueError):
-    """raised when Llama Stack cannot find a referenced session or access is denied"""
-
-    def __init__(self, session_name: str) -> None:
-        message = f"Session '{session_name}' not found or access denied."
-        super().__init__(message)
-
-
 class ModelTypeError(TypeError):
     """raised when a model is present but not the correct type"""
@@ -103,17 +103,6 @@ class CompletionInputType(BaseModel):
     type: Literal["completion_input"] = "completion_input"


-@json_schema_type
-class AgentTurnInputType(BaseModel):
-    """Parameter type for agent turn input.
-
-    :param type: Discriminator type. Always "agent_turn_input"
-    """
-
-    # expects List[Message] for messages (may also include attachments?)
-    type: Literal["agent_turn_input"] = "agent_turn_input"
-
-
 @json_schema_type
 class DialogType(BaseModel):
     """Parameter type for dialog data with semantic output labels.
@@ -135,8 +124,7 @@ ParamType = Annotated[
     | JsonType
     | UnionType
     | ChatCompletionInputType
-    | CompletionInputType
-    | AgentTurnInputType,
+    | CompletionInputType,
     Field(discriminator="type"),
 ]
 register_schema(ParamType, name="ParamType")
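The ParamType hunk above edits a pydantic discriminated union: each member model carries a literal "type" field, and Field(discriminator="type") lets validation dispatch on it, which is why dropping AgentTurnInputType only requires removing one arm. A minimal self-contained sketch of the mechanism, assuming pydantic v2, with two illustrative member types rather than the real llama-stack ones:

    from typing import Annotated, Literal

    from pydantic import BaseModel, Field, TypeAdapter


    class StringType(BaseModel):
        type: Literal["string"] = "string"


    class JsonType(BaseModel):
        type: Literal["json"] = "json"


    ParamType = Annotated[StringType | JsonType, Field(discriminator="type")]

    # The "type" field selects the concrete model during validation.
    param = TypeAdapter(ParamType).validate_python({"type": "json"})
    assert isinstance(param, JsonType)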
@@ -8,7 +8,7 @@ from typing import Any, Protocol, runtime_checkable

 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasets import Dataset
-from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
+from llama_stack.apis.version import LLAMA_STACK_API_V1BETA
 from llama_stack.schema_utils import webmethod


@@ -21,7 +21,6 @@ class DatasetIO(Protocol):
     # keeping for aligning with inference/safety, but this is not used
     dataset_store: DatasetStore

-    @webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
     @webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
     async def iterrows(
         self,
@@ -46,9 +45,6 @@ class DatasetIO(Protocol):
         """
         ...

-    @webmethod(
-        route="/datasetio/append-rows/{dataset_id:path}", method="POST", deprecated=True, level=LLAMA_STACK_API_V1
-    )
     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1BETA)
     async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
         """Append rows to a dataset.
@@ -10,7 +10,7 @@ from typing import Annotated, Any, Literal, Protocol
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.resource import Resource, ResourceType
-from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
+from llama_stack.apis.version import LLAMA_STACK_API_V1BETA
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
 
@@ -146,7 +146,6 @@ class ListDatasetsResponse(BaseModel):
 
 
 class Datasets(Protocol):
-    @webmethod(route="/datasets", method="POST", deprecated=True, level=LLAMA_STACK_API_V1)
     @webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
     async def register_dataset(
         self,
@@ -216,7 +215,6 @@ class Datasets(Protocol):
         """
         ...
 
-    @webmethod(route="/datasets/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
     @webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
     async def get_dataset(
         self,
@@ -229,7 +227,6 @@ class Datasets(Protocol):
         """
         ...
 
-    @webmethod(route="/datasets", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
     @webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1BETA)
     async def list_datasets(self) -> ListDatasetsResponse:
         """List all datasets.
@@ -238,7 +235,6 @@ class Datasets(Protocol):
         """
         ...
 
-    @webmethod(route="/datasets/{dataset_id:path}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1)
     @webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
     async def unregister_dataset(
         self,
@@ -4,17 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Annotated, Any, Literal, Protocol
+from typing import Any, Literal, Protocol
 
 from pydantic import BaseModel, Field
 
-from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job
 from llama_stack.apis.inference import SamplingParams, SystemMessage
 from llama_stack.apis.scoring import ScoringResult
 from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
-from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
+from llama_stack.schema_utils import json_schema_type, webmethod
 
 
 @json_schema_type
@@ -32,19 +31,7 @@ class ModelCandidate(BaseModel):
     system_message: SystemMessage | None = None
 
 
-@json_schema_type
-class AgentCandidate(BaseModel):
-    """An agent candidate for evaluation.
-
-    :param config: The configuration for the agent candidate.
-    """
-
-    type: Literal["agent"] = "agent"
-    config: AgentConfig
-
-
-EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
-register_schema(EvalCandidate, name="EvalCandidate")
+EvalCandidate = ModelCandidate
 
 
 @json_schema_type
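With the union gone, `EvalCandidate` is a plain alias for `ModelCandidate`, so existing annotations keep working but only model candidates can be constructed. A sketch, assuming `ModelCandidate` carries the usual `type`/`model`/`sampling_params` fields and that the import path is as shown:

```python
from llama_stack.apis.eval import EvalCandidate, ModelCandidate  # assumed import path
from llama_stack.apis.inference import SamplingParams

candidate: EvalCandidate = ModelCandidate(
    type="model",  # discriminator retained on ModelCandidate itself
    model="llama3.2:3b",  # hypothetical model id
    sampling_params=SamplingParams(),
)
```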
@@ -86,7 +73,6 @@ class Eval(Protocol):
 
     Llama Stack Evaluation API for running evaluations on model and agent candidates."""
 
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)
     async def run_eval(
         self,
@@ -101,9 +87,6 @@ class Eval(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
-    )
     @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1ALPHA)
     async def evaluate_rows(
         self,
@@ -122,9 +105,6 @@ class Eval(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
-    )
     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
     async def job_status(self, benchmark_id: str, job_id: str) -> Job:
         """Get the status of a job.
@@ -135,12 +115,6 @@ class Eval(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
-        method="DELETE",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
     async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
         """Cancel a job.
@@ -150,12 +124,6 @@ class Eval(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
-        method="GET",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(
         route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET", level=LLAMA_STACK_API_V1ALPHA
     )
@@ -110,7 +110,6 @@ class Files(Protocol):
     """
 
     # OpenAI Files API Endpoints
-    @webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_upload_file(
         self,
@@ -134,7 +133,6 @@ class Files(Protocol):
         """
         ...
 
-    @webmethod(route="/openai/v1/files", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
     async def openai_list_files(
         self,
@@ -155,7 +153,6 @@ class Files(Protocol):
         """
         ...
 
-    @webmethod(route="/openai/v1/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
     async def openai_retrieve_file(
         self,
@@ -170,7 +167,6 @@ class Files(Protocol):
         """
         ...
 
-    @webmethod(route="/openai/v1/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
     async def openai_delete_file(
         self,
@@ -183,7 +179,6 @@ class Files(Protocol):
         """
         ...
 
-    @webmethod(route="/openai/v1/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
     async def openai_retrieve_file_content(
         self,
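These handlers keep their OpenAI-compatible shape; only the `/openai/v1` aliases are dropped. A hedged client-side sketch against the surviving upload route, assuming the server mounts v1 routes under a `/v1` prefix and accepts OpenAI-style multipart fields (the base URL, filename, and `purpose` value are all assumptions):

```python
import httpx

base_url = "http://localhost:8321/v1"  # hypothetical deployment
with open("dataset.jsonl", "rb") as f:
    resp = httpx.post(
        f"{base_url}/files",
        files={"file": ("dataset.jsonl", f)},
        data={"purpose": "batch"},  # OpenAI-style purpose field; value is an assumption
    )
resp.raise_for_status()
print(resp.json()["id"])  # server-assigned file id
```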
@@ -1189,7 +1189,6 @@ class InferenceProvider(Protocol):
         raise NotImplementedError("Reranking is not implemented")
         return  # this is so mypy's safe-super rule will consider the method concrete
 
-    @webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_completion(
         self,
@@ -1202,7 +1201,6 @@ class InferenceProvider(Protocol):
         """
         ...
 
-    @webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_chat_completion(
         self,
@@ -1215,7 +1213,6 @@ class InferenceProvider(Protocol):
         """
         ...
 
-    @webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_embeddings(
         self,
@@ -1240,7 +1237,6 @@ class Inference(InferenceProvider):
     - Rerank models: these models reorder the documents based on their relevance to a query.
     """
 
-    @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
     async def list_chat_completions(
         self,
@@ -1259,9 +1255,6 @@ class Inference(InferenceProvider):
         """
         raise NotImplementedError("List chat completions is not implemented")
 
-    @webmethod(
-        route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
-    )
     @webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
     async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
         """Get chat completion.
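The same pattern applies to inference: `/openai/v1/chat/completions` goes away and `/chat/completions` remains at v1. A minimal request sketch in the OpenAI-compatible shape (the base URL, `/v1` prefix, and model id are assumptions):

```python
import httpx

resp = httpx.post(
    "http://localhost:8321/v1/chat/completions",  # hypothetical deployment
    json={
        "model": "llama3.2:3b",  # hypothetical registered model
        "messages": [{"role": "user", "content": "Say hello."}],
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```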
@@ -90,12 +90,14 @@ class OpenAIModel(BaseModel):
     :object: The object type, which will be "model"
     :created: The Unix timestamp in seconds when the model was created
     :owned_by: The owner of the model
+    :custom_metadata: Llama Stack-specific metadata including model_type, provider info, and additional metadata
     """
 
     id: str
     object: Literal["model"] = "model"
     created: int
     owned_by: str
+    custom_metadata: dict[str, Any] | None = None
 
 
 class OpenAIListModelsResponse(BaseModel):
@@ -105,7 +107,6 @@ class OpenAIListModelsResponse(BaseModel):
 @runtime_checkable
 @trace_protocol
 class Models(Protocol):
-    @webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
     async def list_models(self) -> ListModelsResponse:
         """List all models.
 
@@ -113,7 +114,7 @@ class Models(Protocol):
         """
         ...
 
-    @webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
     async def openai_list_models(self) -> OpenAIListModelsResponse:
         """List models using the OpenAI API.
 
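A sketch of what the extended model object can carry; the metadata keys mirror those the routing table populates later in this diff, while the import path and provider values are assumptions:

```python
import time

from llama_stack.apis.models import OpenAIModel  # assumed import path

m = OpenAIModel(
    id="llama3.2:3b",
    created=int(time.time()),
    owned_by="llama_stack",
    custom_metadata={
        "model_type": "llm",
        "provider_id": "ollama",  # hypothetical provider
        "provider_resource_id": "llama3.2:3b",
    },
)
# The field is optional, so plain OpenAI clients that ignore it keep working.
assert m.custom_metadata and m.custom_metadata["model_type"] == "llm"
```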
@@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.common.training_types import Checkpoint
-from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
+from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
 
@@ -284,7 +284,6 @@ class PostTrainingJobArtifactsResponse(BaseModel):
 
 
 class PostTraining(Protocol):
-    @webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1ALPHA)
     async def supervised_fine_tune(
         self,
@@ -312,7 +311,6 @@ class PostTraining(Protocol):
         """
         ...
 
-    @webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1ALPHA)
     async def preference_optimize(
         self,
@@ -335,7 +333,6 @@ class PostTraining(Protocol):
         """
         ...
 
-    @webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
         """Get all training jobs.
@@ -344,7 +341,6 @@ class PostTraining(Protocol):
         """
         ...
 
-    @webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
     async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
         """Get the status of a training job.
@@ -354,7 +350,6 @@ class PostTraining(Protocol):
         """
         ...
 
-    @webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
     async def cancel_training_job(self, job_uuid: str) -> None:
         """Cancel a training job.
@@ -363,7 +358,6 @@ class PostTraining(Protocol):
         """
         ...
 
-    @webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
     async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
         """Get the artifacts of a training job.
@@ -121,7 +121,6 @@ class Safety(Protocol):
         """
         ...
 
-    @webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
     async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
         """Create moderation.
@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .synthetic_data_generation import *
@@ -1,77 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from enum import Enum
-from typing import Any, Protocol
-
-from pydantic import BaseModel
-
-from llama_stack.apis.inference import Message
-from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.schema_utils import json_schema_type, webmethod
-
-
-class FilteringFunction(Enum):
-    """The type of filtering function.
-
-    :cvar none: No filtering applied, accept all generated synthetic data
-    :cvar random: Random sampling of generated data points
-    :cvar top_k: Keep only the top-k highest scoring synthetic data samples
-    :cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold
-    :cvar top_k_top_p: Combined top-k and top-p filtering strategy
-    :cvar sigmoid: Apply sigmoid function for probability-based filtering
-    """
-
-    none = "none"
-    random = "random"
-    top_k = "top_k"
-    top_p = "top_p"
-    top_k_top_p = "top_k_top_p"
-    sigmoid = "sigmoid"
-
-
-@json_schema_type
-class SyntheticDataGenerationRequest(BaseModel):
-    """Request to generate synthetic data. A small batch of prompts and a filtering function
-
-    :param dialogs: List of conversation messages to use as input for synthetic data generation
-    :param filtering_function: Type of filtering to apply to generated synthetic data samples
-    :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
-    """
-
-    dialogs: list[Message]
-    filtering_function: FilteringFunction = FilteringFunction.none
-    model: str | None = None
-
-
-@json_schema_type
-class SyntheticDataGenerationResponse(BaseModel):
-    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.
-
-    :param synthetic_data: List of generated synthetic data samples that passed the filtering criteria
-    :param statistics: (Optional) Statistical information about the generation process and filtering results
-    """
-
-    synthetic_data: list[dict[str, Any]]
-    statistics: dict[str, Any] | None = None
-
-
-class SyntheticDataGeneration(Protocol):
-    @webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
-    def synthetic_data_generate(
-        self,
-        dialogs: list[Message],
-        filtering_function: FilteringFunction = FilteringFunction.none,
-        model: str | None = None,
-    ) -> SyntheticDataGenerationResponse:
-        """Generate synthetic data based on input dialogs and apply filtering.
-
-        :param dialogs: List of conversation messages to use as input for synthetic data generation
-        :param filtering_function: Type of filtering to apply to generated synthetic data samples
-        :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
-        :returns: Response containing filtered synthetic data samples and optional statistics
-        """
-        ...
@@ -5,18 +5,13 @@
 # the root directory of this source tree.
 
 from enum import Enum, StrEnum
-from typing import Annotated, Any, Literal, Protocol
+from typing import Annotated, Any, Literal
 
 from pydantic import BaseModel, Field, field_validator
-from typing_extensions import runtime_checkable
 
 from llama_stack.apis.common.content_types import URL, InterleavedContent
-from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
-from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
 
-@json_schema_type
 class RRFRanker(BaseModel):
     """
     Reciprocal Rank Fusion (RRF) ranker configuration.
@@ -30,7 +25,6 @@ class RRFRanker(BaseModel):
     impact_factor: float = Field(default=60.0, gt=0.0)  # default of 60 for optimal performance
 
 
-@json_schema_type
 class WeightedRanker(BaseModel):
     """
     Weighted ranker configuration that combines vector and keyword scores.
@@ -55,10 +49,8 @@ Ranker = Annotated[
     RRFRanker | WeightedRanker,
     Field(discriminator="type"),
 ]
-register_schema(Ranker, name="Ranker")
 
 
-@json_schema_type
 class RAGDocument(BaseModel):
     """
     A document to be used for document ingestion in the RAG Tool.
@@ -75,7 +67,6 @@ class RAGDocument(BaseModel):
     metadata: dict[str, Any] = Field(default_factory=dict)
 
 
-@json_schema_type
 class RAGQueryResult(BaseModel):
     """Result of a RAG query containing retrieved content and metadata.
 
@@ -87,7 +78,6 @@ class RAGQueryResult(BaseModel):
     metadata: dict[str, Any] = Field(default_factory=dict)
 
 
-@json_schema_type
 class RAGQueryGenerator(Enum):
     """Types of query generators for RAG systems.
 
@@ -101,7 +91,6 @@ class RAGQueryGenerator(Enum):
     custom = "custom"
 
 
-@json_schema_type
 class RAGSearchMode(StrEnum):
     """
     Search modes for RAG query retrieval:
@@ -115,7 +104,6 @@ class RAGSearchMode(StrEnum):
     HYBRID = "hybrid"
 
 
-@json_schema_type
 class DefaultRAGQueryGeneratorConfig(BaseModel):
     """Configuration for the default RAG query generator.
 
@@ -127,7 +115,6 @@ class DefaultRAGQueryGeneratorConfig(BaseModel):
     separator: str = " "
 
 
-@json_schema_type
 class LLMRAGQueryGeneratorConfig(BaseModel):
     """Configuration for the LLM-based RAG query generator.
 
@@ -145,10 +132,8 @@ RAGQueryGeneratorConfig = Annotated[
     DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
     Field(discriminator="type"),
 ]
-register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
 
 
-@json_schema_type
 class RAGQueryConfig(BaseModel):
     """
     Configuration for the RAG query generation.
@@ -181,38 +166,3 @@ class RAGQueryConfig(BaseModel):
         if len(v) == 0:
             raise ValueError("chunk_template must not be empty")
         return v
-
-
-@runtime_checkable
-@trace_protocol
-class RAGToolRuntime(Protocol):
-    @webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
-    async def insert(
-        self,
-        documents: list[RAGDocument],
-        vector_store_id: str,
-        chunk_size_in_tokens: int = 512,
-    ) -> None:
-        """Index documents so they can be used by the RAG system.
-
-        :param documents: List of documents to index in the RAG system
-        :param vector_store_id: ID of the vector database to store the document embeddings
-        :param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
-        """
-        ...
-
-    @webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
-    async def query(
-        self,
-        content: InterleavedContent,
-        vector_store_ids: list[str],
-        query_config: RAGQueryConfig | None = None,
-    ) -> RAGQueryResult:
-        """Query the RAG system for context; typically invoked by the agent.
-
-        :param content: The query content to search for in the indexed documents
-        :param vector_store_ids: List of vector database IDs to search within
-        :param query_config: (Optional) Configuration parameters for the query operation
-        :returns: RAGQueryResult containing the retrieved content and metadata
-        """
-        ...
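With the `RAGToolRuntime` protocol deleted, there is no longer a dedicated `/tool-runtime/rag-tool/*` surface; the vector-store and file routes that survive elsewhere in this diff cover ingestion and retrieval. A heavily hedged sketch of the equivalent flow over those routes (the `/v1` prefix, base URL, and payload field names are assumptions, not confirmed by this diff):

```python
import httpx

base = "http://localhost:8321/v1"  # hypothetical deployment

# 1. Create a vector store (POST /vector_stores survives at v1).
store = httpx.post(f"{base}/vector_stores", json={"name": "docs"}).json()

# 2. Upload a file, then attach it to the store
#    (POST /files, then POST /vector_stores/{id}/files).
with open("guide.md", "rb") as f:
    file = httpx.post(
        f"{base}/files",
        files={"file": f},
        data={"purpose": "assistants"},  # assumed purpose value
    ).json()
httpx.post(f"{base}/vector_stores/{store['id']}/files", json={"file_id": file["id"]})

# 3. Retrieve context (POST /vector_stores/{id}/search survives at v1).
hits = httpx.post(
    f"{base}/vector_stores/{store['id']}/search", json={"query": "chunking"}
).json()
```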
@@ -16,8 +16,6 @@ from llama_stack.apis.version import LLAMA_STACK_API_V1
 from llama_stack.core.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
 
-from .rag_tool import RAGToolRuntime
-
 
 @json_schema_type
 class ToolDef(BaseModel):
@@ -195,8 +193,6 @@ class SpecialToolGroup(Enum):
 class ToolRuntime(Protocol):
     tool_store: ToolStore | None = None
 
-    rag_tool: RAGToolRuntime | None = None
-
     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
     async def list_runtime_tools(
@@ -545,7 +545,6 @@ class VectorIO(Protocol):
         ...
 
     # OpenAI Vector Stores API endpoints
-    @webmethod(route="/openai/v1/vector_stores", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_create_vector_store(
         self,
@@ -558,7 +557,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(route="/openai/v1/vector_stores", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/vector_stores", method="GET", level=LLAMA_STACK_API_V1)
     async def openai_list_vector_stores(
         self,
@@ -577,9 +575,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
-    )
     @webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1)
     async def openai_retrieve_vector_store(
         self,
@@ -592,9 +587,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
-    )
     @webmethod(
         route="/vector_stores/{vector_store_id}",
         method="POST",
@@ -617,9 +609,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
-    )
     @webmethod(
         route="/vector_stores/{vector_store_id}",
         method="DELETE",
@@ -636,12 +625,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}/search",
-        method="POST",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(
         route="/vector_stores/{vector_store_id}/search",
         method="POST",
@@ -674,12 +657,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}/files",
-        method="POST",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(
         route="/vector_stores/{vector_store_id}/files",
         method="POST",
@@ -702,12 +679,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}/files",
-        method="GET",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(
         route="/vector_stores/{vector_store_id}/files",
         method="GET",
@@ -734,12 +705,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
-        method="GET",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(
         route="/vector_stores/{vector_store_id}/files/{file_id}",
        method="GET",
@@ -758,12 +723,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content",
-        method="GET",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(
         route="/vector_stores/{vector_store_id}/files/{file_id}/content",
         method="GET",
@@ -782,12 +741,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
-        method="POST",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(
         route="/vector_stores/{vector_store_id}/files/{file_id}",
         method="POST",
@@ -808,12 +761,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
-        method="DELETE",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(
         route="/vector_stores/{vector_store_id}/files/{file_id}",
         method="DELETE",
@@ -837,12 +784,6 @@ class VectorIO(Protocol):
         method="POST",
         level=LLAMA_STACK_API_V1,
     )
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}/file_batches",
-        method="POST",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     async def openai_create_vector_store_file_batch(
         self,
         vector_store_id: str,
@@ -861,12 +802,6 @@ class VectorIO(Protocol):
         method="GET",
         level=LLAMA_STACK_API_V1,
     )
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}",
-        method="GET",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     async def openai_retrieve_vector_store_file_batch(
         self,
         batch_id: str,
@@ -880,12 +815,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
-        method="GET",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(
         route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
         method="GET",
@@ -914,12 +843,6 @@ class VectorIO(Protocol):
         """
         ...
 
-    @webmethod(
-        route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
-        method="POST",
-        level=LLAMA_STACK_API_V1,
-        deprecated=True,
-    )
     @webmethod(
         route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
         method="POST",
@@ -253,7 +253,7 @@ class StackRun(Subcommand):
             )
             return
 
-        ui_dir = REPO_ROOT / "llama_stack" / "ui"
+        ui_dir = REPO_ROOT / "llama_stack_ui"
         logs_dir = Path("~/.llama/ui/logs").expanduser()
         try:
             # Create logs directory if it doesn't exist
@@ -8,14 +8,9 @@ from typing import Any
 
 from llama_stack.apis.common.content_types import (
     URL,
-    InterleavedContent,
 )
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    RAGDocument,
-    RAGQueryConfig,
-    RAGQueryResult,
-    RAGToolRuntime,
     ToolRuntime,
 )
 from llama_stack.log import get_logger
@@ -26,36 +21,6 @@ logger = get_logger(name=__name__, category="core::routers")
 
 
 class ToolRuntimeRouter(ToolRuntime):
-    class RagToolImpl(RAGToolRuntime):
-        def __init__(
-            self,
-            routing_table: ToolGroupsRoutingTable,
-        ) -> None:
-            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
-            self.routing_table = routing_table
-
-        async def query(
-            self,
-            content: InterleavedContent,
-            vector_store_ids: list[str],
-            query_config: RAGQueryConfig | None = None,
-        ) -> RAGQueryResult:
-            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
-            provider = await self.routing_table.get_provider_impl("knowledge_search")
-            return await provider.query(content, vector_store_ids, query_config)
-
-        async def insert(
-            self,
-            documents: list[RAGDocument],
-            vector_store_id: str,
-            chunk_size_in_tokens: int = 512,
-        ) -> None:
-            logger.debug(
-                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
-            )
-            provider = await self.routing_table.get_provider_impl("insert_into_memory")
-            return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
-
     def __init__(
         self,
         routing_table: ToolGroupsRoutingTable,
@@ -63,11 +28,6 @@ class ToolRuntimeRouter(ToolRuntime):
         logger.debug("Initializing ToolRuntimeRouter")
         self.routing_table = routing_table
 
-        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
-        self.rag_tool = self.RagToolImpl(routing_table)
-        for method in ("query", "insert"):
-            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
-
     async def initialize(self) -> None:
         logger.debug("ToolRuntimeRouter.initialize")
         pass
@@ -134,6 +134,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                 object="model",
                 created=int(time.time()),
                 owned_by="llama_stack",
+                custom_metadata={
+                    "model_type": model.model_type,
+                    "provider_id": model.provider_id,
+                    "provider_resource_id": model.provider_resource_id,
+                    **model.metadata,
+                },
             )
             for model in all_models
         ]
@@ -13,7 +13,6 @@ from aiohttp import hdrs
 from starlette.routing import Route
 
 from llama_stack.apis.datatypes import Api, ExternalApiSpec
-from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
 from llama_stack.core.resolver import api_protocol_map
 from llama_stack.schema_utils import WebMethod
 
@@ -25,33 +24,16 @@ RouteImpls = dict[str, PathImpl]
 RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
 
 
-def toolgroup_protocol_map():
-    return {
-        SpecialToolGroup.rag_tool: RAGToolRuntime,
-    }
-
-
 def get_all_api_routes(
     external_apis: dict[Api, ExternalApiSpec] | None = None,
 ) -> dict[Api, list[tuple[Route, WebMethod]]]:
     apis = {}
 
     protocols = api_protocol_map(external_apis)
-    toolgroup_protocols = toolgroup_protocol_map()
     for api, protocol in protocols.items():
         routes = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
 
-        # HACK ALERT
-        if api == Api.tool_runtime:
-            for tool_group in SpecialToolGroup:
-                sub_protocol = toolgroup_protocols[tool_group]
-                sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
-                for name, method in sub_protocol_methods:
-                    if not hasattr(method, "__webmethod__"):
-                        continue
-                    protocol_methods.append((f"{tool_group.value}.{name}", method))
-
         for name, method in protocol_methods:
             # Get all webmethods for this method (supports multiple decorators)
             webmethods = getattr(method, "__webmethods__", [])
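The route collector above relies on decorators accumulating metadata on the function itself (`__webmethods__`), which is also why stacking a deprecated and a current `@webmethod` on one handler registers both routes. A toy sketch of that mechanism; the decorator internals are simplified and are not the real implementation:

```python
def webmethod(route: str, method: str = "GET", deprecated: bool = False):
    def wrap(fn):
        # Append rather than overwrite, so stacked decorators all survive.
        fn.__dict__.setdefault("__webmethods__", []).append(
            {"route": route, "method": method, "deprecated": deprecated}
        )
        return fn
    return wrap


@webmethod("/files", method="POST")
@webmethod("/openai/v1/files", method="POST", deprecated=True)
async def openai_upload_file(): ...


print([m["route"] for m in openai_upload_file.__webmethods__])
# ['/openai/v1/files', '/files']  (innermost decorator runs first)
```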
@@ -31,8 +31,7 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
-from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
-from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
 from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
@@ -66,7 +65,6 @@ class LlamaStack(
     Agents,
     Batches,
     Safety,
-    SyntheticDataGeneration,
     Datasets,
     PostTraining,
     VectorIO,
@@ -80,7 +78,6 @@ class LlamaStack(
     Inspect,
     ToolGroups,
     ToolRuntime,
-    RAGToolRuntime,
     Files,
     Prompts,
     Conversations,
@@ -12,7 +12,7 @@ from llama_stack.core.ui.modules.api import llama_stack_api
 def models():
     # Models Section
     st.header("Models")
-    models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}
+    models_info = {m.id: m.model_dump() for m in llama_stack_api.client.models.list()}
 
     selected_model = st.selectbox("Select a model", list(models_info.keys()))
     st.json(models_info[selected_model])
@@ -12,7 +12,11 @@ from llama_stack.core.ui.modules.api import llama_stack_api
 with st.sidebar:
     st.header("Configuration")
     available_models = llama_stack_api.client.models.list()
-    available_models = [model.identifier for model in available_models if model.model_type == "llm"]
+    available_models = [
+        model.id
+        for model in available_models
+        if model.custom_metadata and model.custom_metadata.get("model_type") == "llm"
+    ]
     selected_model = st.selectbox(
         "Choose a model",
         available_models,
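The filter above replaces the old `model_type` attribute with a lookup into `custom_metadata`, and the same guard works anywhere the OpenAI-style listing is consumed; a short sketch, assuming `client` is the same llama-stack client the UI wraps and that `"embedding"` is a valid `model_type` value:

```python
# Models now come back in OpenAI shape, with stack-specific details tucked
# into the optional custom_metadata dict.
models = client.models.list()
llms = [m.id for m in models if (m.custom_metadata or {}).get("model_type") == "llm"]
embedding_models = [
    m.id for m in models if (m.custom_metadata or {}).get("model_type") == "embedding"
]
```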
File diff suppressed because it is too large
@ -4,21 +4,9 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import uuid
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
Agent,
|
|
||||||
AgentConfig,
|
|
||||||
AgentCreateResponse,
|
|
||||||
Agents,
|
Agents,
|
||||||
AgentSessionCreateResponse,
|
|
||||||
AgentStepResponse,
|
|
||||||
AgentToolGroup,
|
|
||||||
AgentTurnCreateRequest,
|
|
||||||
AgentTurnResumeRequest,
|
|
||||||
Document,
|
|
||||||
ListOpenAIResponseInputItem,
|
ListOpenAIResponseInputItem,
|
||||||
ListOpenAIResponseObject,
|
ListOpenAIResponseObject,
|
||||||
OpenAIDeleteResponseObject,
|
OpenAIDeleteResponseObject,
|
||||||
|
|
@ -26,19 +14,12 @@ from llama_stack.apis.agents import (
|
||||||
OpenAIResponseInputTool,
|
OpenAIResponseInputTool,
|
||||||
OpenAIResponseObject,
|
OpenAIResponseObject,
|
||||||
Order,
|
Order,
|
||||||
Session,
|
|
||||||
Turn,
|
|
||||||
)
|
)
|
||||||
from llama_stack.apis.agents.agents import ResponseGuardrail
|
from llama_stack.apis.agents.agents import ResponseGuardrail
|
||||||
from llama_stack.apis.agents.openai_responses import OpenAIResponsePrompt, OpenAIResponseText
|
from llama_stack.apis.agents.openai_responses import OpenAIResponsePrompt, OpenAIResponseText
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
|
||||||
from llama_stack.apis.conversations import Conversations
|
from llama_stack.apis.conversations import Conversations
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
Inference,
|
Inference,
|
||||||
ToolConfig,
|
|
||||||
ToolResponse,
|
|
||||||
ToolResponseMessage,
|
|
||||||
UserMessage,
|
|
||||||
)
|
)
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||||
|
|
@ -46,12 +27,9 @@ from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.core.datatypes import AccessRule
|
from llama_stack.core.datatypes import AccessRule
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||||
from llama_stack.providers.utils.pagination import paginate_records
|
|
||||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||||
|
|
||||||
from .agent_instance import ChatAgent
|
|
||||||
from .config import MetaReferenceAgentsImplConfig
|
from .config import MetaReferenceAgentsImplConfig
|
||||||
from .persistence import AgentInfo
|
|
||||||
from .responses.openai_responses import OpenAIResponsesImpl
|
from .responses.openai_responses import OpenAIResponsesImpl
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||||
|
|
@@ -97,229 +75,6 @@ class MetaReferenceAgentsImpl(Agents):
             conversations_api=self.conversations_api,
         )

-    async def create_agent(
-        self,
-        agent_config: AgentConfig,
-    ) -> AgentCreateResponse:
-        agent_id = str(uuid.uuid4())
-        created_at = datetime.now(UTC)
-
-        agent_info = AgentInfo(
-            **agent_config.model_dump(),
-            created_at=created_at,
-        )
-
-        # Store the agent info
-        await self.persistence_store.set(
-            key=f"agent:{agent_id}",
-            value=agent_info.model_dump_json(),
-        )
-
-        return AgentCreateResponse(
-            agent_id=agent_id,
-        )
-
-    async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
-        agent_info_json = await self.persistence_store.get(
-            key=f"agent:{agent_id}",
-        )
-        if not agent_info_json:
-            raise ValueError(f"Could not find agent info for {agent_id}")
-
-        try:
-            agent_info = AgentInfo.model_validate_json(agent_info_json)
-        except Exception as e:
-            raise ValueError(f"Could not validate agent info for {agent_id}") from e
-
-        return ChatAgent(
-            agent_id=agent_id,
-            agent_config=agent_info,
-            inference_api=self.inference_api,
-            safety_api=self.safety_api,
-            vector_io_api=self.vector_io_api,
-            tool_runtime_api=self.tool_runtime_api,
-            tool_groups_api=self.tool_groups_api,
-            persistence_store=(
-                self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
-            ),
-            created_at=agent_info.created_at.isoformat(),
-            policy=self.policy,
-            telemetry_enabled=self.telemetry_enabled,
-        )
-
-    async def create_agent_session(
-        self,
-        agent_id: str,
-        session_name: str,
-    ) -> AgentSessionCreateResponse:
-        agent = await self._get_agent_impl(agent_id)
-
-        session_id = await agent.create_session(session_name)
-        return AgentSessionCreateResponse(
-            session_id=session_id,
-        )
-
-    async def create_agent_turn(
-        self,
-        agent_id: str,
-        session_id: str,
-        messages: list[UserMessage | ToolResponseMessage],
-        stream: bool | None = False,
-        documents: list[Document] | None = None,
-        toolgroups: list[AgentToolGroup] | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> AsyncGenerator:
-        request = AgentTurnCreateRequest(
-            agent_id=agent_id,
-            session_id=session_id,
-            messages=messages,
-            stream=True,
-            toolgroups=toolgroups,
-            documents=documents,
-            tool_config=tool_config,
-        )
-        if stream:
-            return self._create_agent_turn_streaming(request)
-        else:
-            raise NotImplementedError("Non-streaming agent turns not yet implemented")
-
-    async def _create_agent_turn_streaming(
-        self,
-        request: AgentTurnCreateRequest,
-    ) -> AsyncGenerator:
-        agent = await self._get_agent_impl(request.agent_id)
-        async for event in agent.create_and_execute_turn(request):
-            yield event
-
-    async def resume_agent_turn(
-        self,
-        agent_id: str,
-        session_id: str,
-        turn_id: str,
-        tool_responses: list[ToolResponse],
-        stream: bool | None = False,
-    ) -> AsyncGenerator:
-        request = AgentTurnResumeRequest(
-            agent_id=agent_id,
-            session_id=session_id,
-            turn_id=turn_id,
-            tool_responses=tool_responses,
-            stream=stream,
-        )
-        if stream:
-            return self._continue_agent_turn_streaming(request)
-        else:
-            raise NotImplementedError("Non-streaming agent turns not yet implemented")
-
-    async def _continue_agent_turn_streaming(
-        self,
-        request: AgentTurnResumeRequest,
-    ) -> AsyncGenerator:
-        agent = await self._get_agent_impl(request.agent_id)
-        async for event in agent.resume_turn(request):
-            yield event
-
-    async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
-        agent = await self._get_agent_impl(agent_id)
-        turn = await agent.storage.get_session_turn(session_id, turn_id)
-        if turn is None:
-            raise ValueError(f"Turn {turn_id} not found in session {session_id}")
-        return turn
-
-    async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
-        turn = await self.get_agents_turn(agent_id, session_id, turn_id)
-        for step in turn.steps:
-            if step.step_id == step_id:
-                return AgentStepResponse(step=step)
-        raise ValueError(f"Provided step_id {step_id} could not be found")
-
-    async def get_agents_session(
-        self,
-        session_id: str,
-        agent_id: str,
-        turn_ids: list[str] | None = None,
-    ) -> Session:
-        agent = await self._get_agent_impl(agent_id)
-
-        session_info = await agent.storage.get_session_info(session_id)
-        if session_info is None:
-            raise ValueError(f"Session {session_id} not found")
-        turns = await agent.storage.get_session_turns(session_id)
-        if turn_ids:
-            turns = [turn for turn in turns if turn.turn_id in turn_ids]
-        return Session(
-            session_name=session_info.session_name,
-            session_id=session_id,
-            turns=turns,
-            started_at=session_info.started_at,
-        )
-
-    async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
-        agent = await self._get_agent_impl(agent_id)
-
-        # Delete turns first, then the session
-        await agent.storage.delete_session_turns(session_id)
-        await agent.storage.delete_session(session_id)
-
-    async def delete_agent(self, agent_id: str) -> None:
-        # First get all sessions for this agent
-        agent = await self._get_agent_impl(agent_id)
-        sessions = await agent.storage.list_sessions()
-
-        # Delete all sessions
-        for session in sessions:
-            await self.delete_agents_session(agent_id, session.session_id)
-
-        # Finally delete the agent itself
-        await self.persistence_store.delete(f"agent:{agent_id}")
-
-    async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
-        agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
-        agent_list: list[Agent] = []
-        for agent_key in agent_keys:
-            agent_id = agent_key.split(":")[1]
-
-            # Get the agent info using the key
-            agent_info_json = await self.persistence_store.get(agent_key)
-            if not agent_info_json:
-                logger.error(f"Could not find agent info for key {agent_key}")
-                continue
-
-            try:
-                agent_info = AgentInfo.model_validate_json(agent_info_json)
-                agent_list.append(
-                    Agent(
-                        agent_id=agent_id,
-                        agent_config=agent_info,
-                        created_at=agent_info.created_at,
-                    )
-                )
-            except Exception as e:
-                logger.error(f"Error parsing agent info for {agent_id}: {e}")
-                continue
-
-        # Convert Agent objects to dictionaries
-        agent_dicts = [agent.model_dump() for agent in agent_list]
-        return paginate_records(agent_dicts, start_index, limit)
-
-    async def get_agent(self, agent_id: str) -> Agent:
-        chat_agent = await self._get_agent_impl(agent_id)
-        agent = Agent(
-            agent_id=agent_id,
-            agent_config=chat_agent.agent_config,
-            created_at=datetime.fromisoformat(chat_agent.created_at),
-        )
-        return agent
-
-    async def list_agent_sessions(
-        self, agent_id: str, start_index: int | None = None, limit: int | None = None
-    ) -> PaginatedResponse:
-        agent = await self._get_agent_impl(agent_id)
-        sessions = await agent.storage.list_sessions()
-        # Convert Session objects to dictionaries
-        session_dicts = [session.model_dump() for session in sessions]
-        return paginate_records(session_dicts, start_index, limit)
-
     async def shutdown(self) -> None:
         pass
@@ -1,261 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import json
-import uuid
-from dataclasses import dataclass
-from datetime import UTC, datetime
-
-from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
-from llama_stack.apis.common.errors import SessionNotFoundError
-from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
-from llama_stack.core.access_control.conditions import User as ProtocolUser
-from llama_stack.core.access_control.datatypes import AccessRule, Action
-from llama_stack.core.datatypes import User
-from llama_stack.core.request_headers import get_authenticated_user
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import KVStore
-
-log = get_logger(name=__name__, category="agents::meta_reference")
-
-
-class AgentSessionInfo(Session):
-    # TODO: is this used anywhere?
-    vector_store_id: str | None = None
-    started_at: datetime
-    owner: User | None = None
-    identifier: str | None = None
-    type: str = "session"
-
-
-class AgentInfo(AgentConfig):
-    created_at: datetime
-
-
-@dataclass
-class SessionResource:
-    """Concrete implementation of ProtectedResource for session access control."""
-
-    type: str
-    identifier: str
-    owner: ProtocolUser  # Use the protocol type for structural compatibility
-
-
-class AgentPersistence:
-    def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
-        self.agent_id = agent_id
-        self.kvstore = kvstore
-        self.policy = policy
-
-    async def create_session(self, name: str) -> str:
-        session_id = str(uuid.uuid4())
-
-        # Get current user's auth attributes for new sessions
-        user = get_authenticated_user()
-
-        session_info = AgentSessionInfo(
-            session_id=session_id,
-            session_name=name,
-            started_at=datetime.now(UTC),
-            owner=user,
-            turns=[],
-            identifier=name,  # should this be qualified in any way?
-        )
-        # Only perform access control if we have an authenticated user
-        if user is not None and session_info.identifier is not None:
-            resource = SessionResource(
-                type=session_info.type,
-                identifier=session_info.identifier,
-                owner=user,
-            )
-            if not is_action_allowed(self.policy, Action.CREATE, resource, user):
-                raise AccessDeniedError(Action.CREATE, resource, user)
-
-        await self.kvstore.set(
-            key=f"session:{self.agent_id}:{session_id}",
-            value=session_info.model_dump_json(),
-        )
-        return session_id
-
-    async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
-        value = await self.kvstore.get(
-            key=f"session:{self.agent_id}:{session_id}",
-        )
-        if not value:
-            raise SessionNotFoundError(session_id)
-
-        session_info = AgentSessionInfo(**json.loads(value))
-
-        # Check access to session
-        if not self._check_session_access(session_info):
-            return None
-
-        return session_info
-
-    def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
-        """Check if current user has access to the session."""
-        # Handle backward compatibility for old sessions without access control
-        if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
-            return True
-
-        # Get current user - if None, skip access control (e.g., in tests)
-        user = get_authenticated_user()
-        if user is None:
-            return True
-
-        # Access control requires identifier and owner to be set
-        if session_info.identifier is None or session_info.owner is None:
-            return True
-
-        # At this point, both identifier and owner are guaranteed to be non-None
-        resource = SessionResource(
-            type=session_info.type,
-            identifier=session_info.identifier,
-            owner=session_info.owner,
-        )
-        return is_action_allowed(self.policy, Action.READ, resource, user)
-
-    async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
-        """Get session info if the user has access to it. For internal use by sub-session methods."""
-        session_info = await self.get_session_info(session_id)
-        if not session_info:
-            return None
-
-        return session_info
-
-    async def add_vector_db_to_session(self, session_id: str, vector_store_id: str):
-        session_info = await self.get_session_if_accessible(session_id)
-        if session_info is None:
-            raise SessionNotFoundError(session_id)
-
-        session_info.vector_store_id = vector_store_id
-        await self.kvstore.set(
-            key=f"session:{self.agent_id}:{session_id}",
-            value=session_info.model_dump_json(),
-        )
-
-    async def add_turn_to_session(self, session_id: str, turn: Turn):
-        if not await self.get_session_if_accessible(session_id):
-            raise SessionNotFoundError(session_id)
-
-        await self.kvstore.set(
-            key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
-            value=turn.model_dump_json(),
-        )
-
-    async def get_session_turns(self, session_id: str) -> list[Turn]:
-        if not await self.get_session_if_accessible(session_id):
-            raise SessionNotFoundError(session_id)
-
-        values = await self.kvstore.values_in_range(
-            start_key=f"session:{self.agent_id}:{session_id}:",
-            end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
-        )
-        turns = []
-        for value in values:
-            try:
-                turn = Turn(**json.loads(value))
-                turns.append(turn)
-            except Exception as e:
-                log.error(f"Error parsing turn: {e}")
-                continue
-
-        # The kvstore does not guarantee order, so we sort by started_at
-        # to ensure consistent ordering of turns.
-        turns.sort(key=lambda t: t.started_at)
-
-        return turns
-
-    async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
-        if not await self.get_session_if_accessible(session_id):
-            raise SessionNotFoundError(session_id)
-
-        value = await self.kvstore.get(
-            key=f"session:{self.agent_id}:{session_id}:{turn_id}",
-        )
-        if not value:
-            return None
-        return Turn(**json.loads(value))
-
-    async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
-        if not await self.get_session_if_accessible(session_id):
-            raise SessionNotFoundError(session_id)
-
-        await self.kvstore.set(
-            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
-            value=step.model_dump_json(),
-        )
-
-    async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
-        if not await self.get_session_if_accessible(session_id):
-            return None
-
-        value = await self.kvstore.get(
-            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
-        )
-        return ToolExecutionStep(**json.loads(value)) if value else None
-
-    async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
-        if not await self.get_session_if_accessible(session_id):
-            raise SessionNotFoundError(session_id)
-
-        await self.kvstore.set(
-            key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
-            value=str(num_infer_iters),
-        )
-
-    async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
-        if not await self.get_session_if_accessible(session_id):
-            return None
-
-        value = await self.kvstore.get(
-            key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
-        )
-        return int(value) if value else None
-
-    async def list_sessions(self) -> list[Session]:
-        values = await self.kvstore.values_in_range(
-            start_key=f"session:{self.agent_id}:",
-            end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
-        )
-        sessions = []
-        for value in values:
-            try:
-                data = json.loads(value)
-                if "turn_id" in data:
-                    continue
-
-                session_info = Session(**data)
-                sessions.append(session_info)
-            except Exception as e:
-                log.error(f"Error parsing session info: {e}")
-                continue
-        return sessions
-
-    async def delete_session_turns(self, session_id: str) -> None:
-        """Delete all turns and their associated data for a session.
-
-        Args:
-            session_id: The ID of the session whose turns should be deleted.
-        """
-        turns = await self.get_session_turns(session_id)
-        for turn in turns:
-            await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
-
-    async def delete_session(self, session_id: str) -> None:
-        """Delete a session and all its associated turns.
-
-        Args:
-            session_id: The ID of the session to delete.
-
-        Raises:
-            ValueError: If the session does not exist.
-        """
-        session_info = await self.get_session_info(session_id)
-        if session_info is None:
-            raise SessionNotFoundError(session_id)
-
-        await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")
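The deleted persistence layer leaned on a sorted-KV idiom that appears throughout the removed code: scan every key under a prefix by ranging from `prefix` to `prefix + "\xff"`, which sorts after any printable continuation. A minimal sketch of why that works, using an invented in-memory stand-in rather than the project's `KVStore`:

```python
from bisect import bisect_left, bisect_right


class SortedKV:
    """Illustrative in-memory KV store with ordered range scans."""

    def __init__(self) -> None:
        self._keys: list[str] = []
        self._data: dict[str, str] = {}

    def set(self, key: str, value: str) -> None:
        if key not in self._data:
            self._keys.insert(bisect_left(self._keys, key), key)
        self._data[key] = value

    def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        # Inclusive range over the sorted key list.
        lo = bisect_left(self._keys, start_key)
        hi = bisect_right(self._keys, end_key)
        return [self._data[k] for k in self._keys[lo:hi]]


kv = SortedKV()
kv.set("session:agent-1:s1", "a")
kv.set("session:agent-1:s2", "b")
kv.set("session:agent-2:s1", "c")
# "\xff" sorts after every ASCII character, so this captures all agent-1 keys only.
print(kv.values_in_range("session:agent-1:", "session:agent-1:\xff"))  # ['a', 'b']
```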
@@ -8,7 +8,7 @@ from typing import Any

 from tqdm import tqdm

-from llama_stack.apis.agents import Agents, StepType
+from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmark
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
@@ -18,13 +18,9 @@ from llama_stack.apis.inference import (
     OpenAICompletionRequestWithExtraBody,
     OpenAISystemMessageParam,
     OpenAIUserMessageParam,
-    UserMessage,
 )
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
-from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
-    MEMORY_QUERY_TOOL,
-)
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack.providers.utils.kvstore import kvstore_impl

@@ -118,49 +114,6 @@ class MetaReferenceEvalImpl(
         self.jobs[job_id] = res
         return Job(job_id=job_id, status=JobStatus.completed)

-    async def _run_agent_generation(
-        self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
-    ) -> list[dict[str, Any]]:
-        candidate = benchmark_config.eval_candidate
-        create_response = await self.agents_api.create_agent(candidate.config)
-        agent_id = create_response.agent_id
-
-        generations = []
-        for i, x in tqdm(enumerate(input_rows)):
-            assert ColumnName.chat_completion_input.value in x, "Invalid input row"
-            input_messages = json.loads(x[ColumnName.chat_completion_input.value])
-            input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
-
-            # NOTE: only single-turn agent generation is supported. Create a new session for each input row
-            session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
-            session_id = session_create_response.session_id
-
-            turn_request = dict(
-                agent_id=agent_id,
-                session_id=session_id,
-                messages=input_messages,
-                stream=True,
-            )
-            turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
-            final_event = turn_response[-1].event.payload
-
-            # check if there's a memory retrieval step and extract the context
-            memory_rag_context = None
-            for step in final_event.turn.steps:
-                if step.step_type == StepType.tool_execution.value:
-                    for tool_response in step.tool_responses:
-                        if tool_response.tool_name == MEMORY_QUERY_TOOL:
-                            memory_rag_context = " ".join(x.text for x in tool_response.content)
-
-            agent_generation = {}
-            agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
-            if memory_rag_context:
-                agent_generation[ColumnName.context.value] = memory_rag_context
-
-            generations.append(agent_generation)
-
-        return generations
-
     async def _run_model_generation(
         self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
     ) -> list[dict[str, Any]]:
@@ -215,9 +168,8 @@ class MetaReferenceEvalImpl(
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         candidate = benchmark_config.eval_candidate
-        if candidate.type == "agent":
-            generations = await self._run_agent_generation(input_rows, benchmark_config)
-        elif candidate.type == "model":
+        # Agent evaluation removed
+        if candidate.type == "model":
             generations = await self._run_model_generation(input_rows, benchmark_config)
         else:
             raise ValueError(f"Invalid candidate type: {candidate.type}")
@@ -27,7 +27,6 @@ from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
     RAGQueryResult,
-    RAGToolRuntime,
     ToolDef,
     ToolGroup,
     ToolInvocationResult,
@@ -91,7 +90,7 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
     return content_str.encode("utf-8"), "text/plain"


-class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
+class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
     def __init__(
         self,
         config: RagToolRuntimeConfig,
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import Iterable

 import google.auth.transport.requests
 from google.auth import default
@@ -42,3 +43,12 @@ class VertexAIInferenceAdapter(OpenAIMixin):
         Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
         """
         return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        VertexAI doesn't currently offer a way to query a list of available models from Google's Model Garden
+        For now we return a hardcoded version of the available models
+
+        :return: An iterable of model IDs
+        """
+        return ["vertexai/gemini-2.0-flash", "vertexai/gemini-2.5-flash", "vertexai/gemini-2.5-pro"]
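Since Vertex AI exposes no model-listing endpoint, the adapter above returns a fixed catalog. A self-contained sketch of the same pattern, with invented names (`StaticModelCatalog` is illustrative, not part of the adapter):

```python
import asyncio
from collections.abc import Iterable


class StaticModelCatalog:
    """Illustrative stand-in for a provider with no model-listing API."""

    _MODELS = ["vertexai/gemini-2.0-flash", "vertexai/gemini-2.5-flash", "vertexai/gemini-2.5-pro"]

    async def list_provider_model_ids(self) -> Iterable[str]:
        # No remote call is possible, so return the hardcoded catalog.
        return list(self._MODELS)


async def main() -> None:
    catalog = StaticModelCatalog()
    for model_id in await catalog.list_provider_model_ids():
        print(model_id)


asyncio.run(main())
```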
@@ -35,6 +35,7 @@ class InferenceStore:
         self.reference = reference
         self.sql_store = None
         self.policy = policy
+        self.enable_write_queue = True

         # Async write queue and worker control
         self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
@@ -47,14 +48,13 @@ class InferenceStore:
         base_store = sqlstore_impl(self.reference)
         self.sql_store = AuthorizedSqlStore(base_store, self.policy)

-        # Disable write queue for SQLite to avoid concurrency issues
-        backend_name = self.reference.backend
-        backend_config = _SQLSTORE_BACKENDS.get(backend_name)
-        if backend_config is None:
-            raise ValueError(
-                f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
-            )
-        self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
+        # Disable write queue for SQLite since WAL mode handles concurrency
+        # Keep it enabled for other backends (like Postgres) for performance
+        backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
+        if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
+            self.enable_write_queue = False
+            logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")

         await self.sql_store.create_table(
             "chat_completions",
             {
@@ -70,8 +70,9 @@ class InferenceStore:
             self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
             for _ in range(self._num_writers):
                 self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
-        else:
-            logger.info("Write queue disabled for SQLite to avoid concurrency issues")
+            logger.debug(
+                f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
+            )

     async def shutdown(self) -> None:
         if not self._worker_tasks:
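The new logic defaults the write queue on and switches it off only when the registered backend is SQLite. A standalone sketch of that gate, with local stand-ins for the patch's `_SQLSTORE_BACKENDS` registry and `StorageBackendType` enum (the registry contents here are invented for illustration):

```python
from dataclasses import dataclass
from enum import Enum


class StorageBackendType(Enum):
    SQL_SQLITE = "sql_sqlite"
    SQL_POSTGRES = "sql_postgres"


@dataclass
class BackendConfig:
    type: StorageBackendType


# Stand-in for the backend registry the patch consults.
BACKENDS = {
    "sqlite": BackendConfig(StorageBackendType.SQL_SQLITE),
    "postgres": BackendConfig(StorageBackendType.SQL_POSTGRES),
}


def write_queue_enabled(backend_name: str) -> bool:
    """SQLite relies on WAL mode for concurrency, so the async write queue
    is skipped; other backends keep it for write throughput."""
    config = BACKENDS.get(backend_name)
    return not (config and config.type is StorageBackendType.SQL_SQLITE)


assert write_queue_enabled("postgres") is True
assert write_queue_enabled("sqlite") is False
assert write_queue_enabled("unknown") is True  # unregistered backends keep the default
```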
@@ -70,13 +70,13 @@ class ResponsesStore:
         base_store = sqlstore_impl(self.reference)
         self.sql_store = AuthorizedSqlStore(base_store, self.policy)

+        # Disable write queue for SQLite since WAL mode handles concurrency
+        # Keep it enabled for other backends (like Postgres) for performance
         backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
-        if backend_config is None:
-            raise ValueError(
-                f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
-            )
-        if backend_config.type == StorageBackendType.SQL_SQLITE:
+        if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
             self.enable_write_queue = False
+            logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")

         await self.sql_store.create_table(
             "openai_responses",
             {
@@ -99,8 +99,9 @@ class ResponsesStore:
             self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
             for _ in range(self._num_writers):
                 self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
-        else:
-            logger.debug("Write queue disabled for SQLite to avoid concurrency issues")
+            logger.debug(
+                f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
+            )

     async def shutdown(self) -> None:
         if not self._worker_tasks:
@@ -17,6 +17,7 @@ from sqlalchemy import (
     String,
     Table,
     Text,
+    event,
     inspect,
     select,
     text,
@@ -75,7 +76,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         self.metadata = MetaData()

     def create_engine(self) -> AsyncEngine:
-        return create_async_engine(self.config.engine_str, pool_pre_ping=True)
+        # Configure connection args for better concurrency support
+        connect_args = {}
+        if "sqlite" in self.config.engine_str:
+            # SQLite-specific optimizations for concurrent access
+            # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
+            connect_args["timeout"] = 5.0
+            connect_args["check_same_thread"] = False  # Allow usage across asyncio tasks
+
+        engine = create_async_engine(
+            self.config.engine_str,
+            pool_pre_ping=True,
+            connect_args=connect_args,
+        )
+
+        # Enable WAL mode for SQLite to support concurrent readers and writers
+        if "sqlite" in self.config.engine_str:
+
+            @event.listens_for(engine.sync_engine, "connect")
+            def set_sqlite_pragma(dbapi_conn, connection_record):
+                cursor = dbapi_conn.cursor()
+                # Enable Write-Ahead Logging for better concurrency
+                cursor.execute("PRAGMA journal_mode=WAL")
+                # Set busy timeout to 5 seconds (retry instead of immediate failure)
+                # With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
+                cursor.execute("PRAGMA busy_timeout=5000")
+                # Use NORMAL synchronous mode for better performance (still safe with WAL)
+                cursor.execute("PRAGMA synchronous=NORMAL")
+                cursor.close()
+
+        return engine

     async def create_table(
         self,
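The PRAGMAs applied in the connect listener are plain SQLite settings, so their effect can be checked with the standard library alone, independent of SQLAlchemy. A quick sketch (file-backed database, since WAL is not meaningful for `:memory:`):

```python
import sqlite3
import tempfile

with tempfile.NamedTemporaryFile(suffix=".db") as tmp:
    conn = sqlite3.connect(tmp.name)
    cursor = conn.cursor()
    # Same settings the engine listener applies on every new connection.
    cursor.execute("PRAGMA journal_mode=WAL")
    cursor.execute("PRAGMA busy_timeout=5000")
    cursor.execute("PRAGMA synchronous=NORMAL")

    journal_mode = cursor.execute("PRAGMA journal_mode").fetchone()[0]
    busy_timeout = cursor.execute("PRAGMA busy_timeout").fetchone()[0]
    print(journal_mode, busy_timeout)  # expected: wal 5000
    conn.close()
```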
@@ -156,7 +156,7 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
     }

     # Include test_id for isolation, except for shared infrastructure endpoints
-    if parsed.path not in ("/api/tags", "/v1/models"):
+    if parsed.path not in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
         normalized["test_id"] = test_id

     normalized_json = json.dumps(normalized, sort_keys=True)
@@ -430,7 +430,7 @@ class ResponseStorage:

         # For model-list endpoints, include digest in filename to distinguish different model sets
         endpoint = request.get("endpoint")
-        if endpoint in ("/api/tags", "/v1/models"):
+        if endpoint in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
             digest = _model_identifiers_digest(endpoint, response)
             response_file = f"models-{request_hash}-{digest}.json"

@@ -554,13 +554,14 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
     Supported endpoints:
     - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
     - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
+    - '/v1/openai/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
     Returns a list of unique identifiers or None if structure doesn't match.
     """
     if "models" in response["body"]:
         # ollama
         items = response["body"]["models"]
     else:
-        # openai
+        # openai or openai-style endpoints
         items = response["body"]
     idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
     return sorted(set(idents))
@@ -581,7 +582,7 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
-        if endpoint == "/v1/models":
+        if endpoint in ("/v1/models", "/v1/openai/v1/models"):
             for m in body:
                 key = m.id
                 seen[key] = m
@@ -665,7 +666,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         logger.info(f"  Test context: {get_test_context()}")

     if mode == APIRecordingMode.LIVE or storage is None:
-        if endpoint == "/v1/models":
+        if endpoint in ("/v1/models", "/v1/openai/v1/models"):
             return original_method(self, *args, **kwargs)
         else:
             return await original_method(self, *args, **kwargs)
@@ -699,7 +700,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     recording = None
     if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
         # Special handling for model-list endpoints: merge all recordings with this hash
-        if endpoint in ("/api/tags", "/v1/models"):
+        if endpoint in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
            records = storage._model_list_responses(request_hash)
            recording = _combine_model_list_responses(endpoint, records)
         else:
@@ -739,13 +740,13 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         )

     if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
-        if endpoint == "/v1/models":
+        if endpoint in ("/v1/models", "/v1/openai/v1/models"):
             response = original_method(self, *args, **kwargs)
         else:
             response = await original_method(self, *args, **kwargs)

         # we want to store the result of the iterator, not the iterator itself
-        if endpoint == "/v1/models":
+        if endpoint in ("/v1/models", "/v1/openai/v1/models"):
             response = [m async for m in response]

         request_data = {
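These hunks all extend the same allow-lists so `/v1/openai/v1/models` is treated like the existing model-list endpoints, which the recorder keys by a digest of the returned identifiers. A simplified, self-contained sketch of that normalization (plain dicts instead of client objects; the digest length and format here are assumptions, not the recorder's exact scheme):

```python
import hashlib
import json

MODEL_LIST_ENDPOINTS = ("/api/tags", "/v1/models", "/v1/openai/v1/models")


def model_identifiers_digest(body) -> str:
    """Return a stable digest over the unique model identifiers in a response body."""
    if isinstance(body, dict) and "models" in body:
        # Ollama-style: {"models": [{"model": ...}, ...]}
        idents = [m["model"] for m in body["models"]]
    else:
        # OpenAI-style (including /v1/openai/v1/models): [{"id": ...}, ...]
        idents = [m["id"] for m in body]
    canonical = json.dumps(sorted(set(idents)))
    return hashlib.sha256(canonical.encode()).hexdigest()[:12]


ollama_body = {"models": [{"model": "llama3"}, {"model": "llama3"}]}
openai_body = [{"id": "gpt-4o"}, {"id": "o3"}]
print(model_identifiers_digest(ollama_body))
print(model_identifiers_digest(openai_body))
```

Deduplicating and sorting before hashing means the digest depends only on which models are present, not on response ordering.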