Merge branch 'main' into fix/list-command-show-builtin-distros

This commit is contained in:
Roy Belio 2025-11-05 11:29:00 +02:00 committed by GitHub
commit f80c321604
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
1202 changed files with 942 additions and 324667 deletions

View file

@ -5,7 +5,7 @@ omit =
*/llama_stack/templates/*
.venv/*
*/llama_stack/cli/scripts/*
*/llama_stack/ui/*
*/llama_stack_ui/*
*/llama_stack/distribution/ui/*
*/llama_stack/strong_typing/*
*/llama_stack/env.py

View file

@ -22,7 +22,7 @@ updates:
prefix: chore(python-deps)
- package-ecosystem: npm
directory: "/llama_stack/ui"
directory: "/llama_stack_ui"
schedule:
interval: "weekly"
day: "saturday"

View file

@ -14,7 +14,7 @@ on:
paths:
- 'distributions/**'
- 'src/llama_stack/**'
- '!src/llama_stack/ui/**'
- '!src/llama_stack_ui/**'
- 'tests/integration/**'
- 'uv.lock'
- 'pyproject.toml'

View file

@ -14,7 +14,7 @@ on:
types: [opened, synchronize, reopened]
paths:
- 'src/llama_stack/**'
- '!src/llama_stack/ui/**'
- '!src/llama_stack_ui/**'
- 'tests/**'
- 'uv.lock'
- 'pyproject.toml'

View file

@ -13,7 +13,7 @@ on:
- 'release-[0-9]+.[0-9]+.x'
paths:
- 'src/llama_stack/**'
- '!src/llama_stack/ui/**'
- '!src/llama_stack_ui/**'
- 'tests/integration/vector_io/**'
- 'uv.lock'
- 'pyproject.toml'

View file

@ -43,14 +43,14 @@ jobs:
with:
node-version: '20'
cache: 'npm'
cache-dependency-path: 'src/llama_stack/ui/'
cache-dependency-path: 'src/llama_stack_ui/'
- name: Set up uv
uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
- name: Install npm dependencies
run: npm ci
working-directory: src/llama_stack/ui
working-directory: src/llama_stack_ui
- name: Install pre-commit
run: python -m pip install pre-commit

View file

@ -10,7 +10,7 @@ on:
branches:
- main
paths-ignore:
- 'src/llama_stack/ui/**'
- 'src/llama_stack_ui/**'
jobs:
build:

View file

@ -9,7 +9,7 @@ on:
branches: [ main ]
paths:
- 'src/llama_stack/**'
- '!src/llama_stack/ui/**'
- '!src/llama_stack_ui/**'
- 'tests/integration/**'
- 'uv.lock'
- 'pyproject.toml'

View file

@ -8,7 +8,7 @@ on:
pull_request:
branches: [ main ]
paths:
- 'src/llama_stack/ui/**'
- 'src/llama_stack_ui/**'
- '.github/workflows/ui-unit-tests.yml' # This workflow
workflow_dispatch:
@ -33,22 +33,22 @@ jobs:
with:
node-version: ${{ matrix.node-version }}
cache: 'npm'
cache-dependency-path: 'src/llama_stack/ui/package-lock.json'
cache-dependency-path: 'src/llama_stack_ui/package-lock.json'
- name: Install dependencies
working-directory: src/llama_stack/ui
working-directory: src/llama_stack_ui
run: npm ci
- name: Run linting
working-directory: src/llama_stack/ui
working-directory: src/llama_stack_ui
run: npm run lint
- name: Run format check
working-directory: src/llama_stack/ui
working-directory: src/llama_stack_ui
run: npm run format:check
- name: Run unit tests
working-directory: src/llama_stack/ui
working-directory: src/llama_stack_ui
env:
CI: true

View file

@ -13,7 +13,7 @@ on:
- 'release-[0-9]+.[0-9]+.x'
paths:
- 'src/llama_stack/**'
- '!src/llama_stack/ui/**'
- '!src/llama_stack_ui/**'
- 'tests/unit/**'
- 'uv.lock'
- 'pyproject.toml'

View file

@ -161,7 +161,7 @@ repos:
name: Format & Lint UI
entry: bash ./scripts/run-ui-linter.sh
language: system
files: ^src/llama_stack/ui/.*\.(ts|tsx)$
files: ^src/llama_stack_ui/.*\.(ts|tsx)$
pass_filenames: false
require_serial: true

View file

@ -0,0 +1,525 @@
# yaml-language-server: $schema=https://app.stainlessapi.com/config-internal.schema.json
organization:
# Name of your organization or company, used to determine the name of the client
# and headings.
name: llama-stack-client
docs: https://llama-stack.readthedocs.io/en/latest/
contact: llamastack@meta.com
security:
- {}
- BearerAuth: []
security_schemes:
BearerAuth:
type: http
scheme: bearer
# `targets` define the output targets and their customization options, such as
# whether to emit the Node SDK and what it's package name should be.
targets:
node:
package_name: llama-stack-client
production_repo: llamastack/llama-stack-client-typescript
publish:
npm: false
python:
package_name: llama_stack_client
production_repo: llamastack/llama-stack-client-python
options:
use_uv: true
publish:
pypi: true
project_name: llama_stack_client
kotlin:
reverse_domain: com.llama_stack_client.api
production_repo: null
publish:
maven: false
go:
package_name: llama-stack-client
production_repo: llamastack/llama-stack-client-go
options:
enable_v2: true
back_compat_use_shared_package: false
# `client_settings` define settings for the API client, such as extra constructor
# arguments (used for authentication), retry behavior, idempotency, etc.
client_settings:
default_env_prefix: LLAMA_STACK_CLIENT
opts:
api_key:
type: string
read_env: LLAMA_STACK_CLIENT_API_KEY
auth: { security_scheme: BearerAuth }
nullable: true
# `environments` are a map of the name of the environment (e.g. "sandbox",
# "production") to the corresponding url to use.
environments:
production: http://any-hosted-llama-stack.com
# `pagination` defines [pagination schemes] which provides a template to match
# endpoints and generate next-page and auto-pagination helpers in the SDKs.
pagination:
- name: datasets_iterrows
type: offset
request:
dataset_id:
type: string
start_index:
type: integer
x-stainless-pagination-property:
purpose: offset_count_param
limit:
type: integer
response:
data:
type: array
items:
type: object
next_index:
type: integer
x-stainless-pagination-property:
purpose: offset_count_start_field
- name: openai_cursor_page
type: cursor
request:
limit:
type: integer
after:
type: string
x-stainless-pagination-property:
purpose: next_cursor_param
response:
data:
type: array
items: {}
has_more:
type: boolean
last_id:
type: string
x-stainless-pagination-property:
purpose: next_cursor_field
# `resources` define the structure and organziation for your API, such as how
# methods and models are grouped together and accessed. See the [configuration
# guide] for more information.
#
# [configuration guide]:
# https://app.stainlessapi.com/docs/guides/configure#resources
resources:
$shared:
models:
interleaved_content_item: InterleavedContentItem
interleaved_content: InterleavedContent
param_type: ParamType
safety_violation: SafetyViolation
sampling_params: SamplingParams
scoring_result: ScoringResult
system_message: SystemMessage
query_result: RAGQueryResult
document: RAGDocument
query_config: RAGQueryConfig
toolgroups:
models:
tool_group: ToolGroup
list_tool_groups_response: ListToolGroupsResponse
methods:
register: post /v1/toolgroups
get: get /v1/toolgroups/{toolgroup_id}
list: get /v1/toolgroups
unregister: delete /v1/toolgroups/{toolgroup_id}
tools:
methods:
get: get /v1/tools/{tool_name}
list:
endpoint: get /v1/tools
paginated: false
tool_runtime:
models:
tool_def: ToolDef
tool_invocation_result: ToolInvocationResult
methods:
list_tools:
endpoint: get /v1/tool-runtime/list-tools
paginated: false
invoke_tool: post /v1/tool-runtime/invoke
subresources:
rag_tool:
methods:
insert: post /v1/tool-runtime/rag-tool/insert
query: post /v1/tool-runtime/rag-tool/query
responses:
models:
response_object_stream: OpenAIResponseObjectStream
response_object: OpenAIResponseObject
methods:
create:
type: http
endpoint: post /v1/responses
streaming:
stream_event_model: responses.response_object_stream
param_discriminator: stream
retrieve: get /v1/responses/{response_id}
list:
type: http
endpoint: get /v1/responses
delete:
type: http
endpoint: delete /v1/responses/{response_id}
subresources:
input_items:
methods:
list:
type: http
endpoint: get /v1/responses/{response_id}/input_items
prompts:
models:
prompt: Prompt
list_prompts_response: ListPromptsResponse
methods:
create: post /v1/prompts
list:
endpoint: get /v1/prompts
paginated: false
retrieve: get /v1/prompts/{prompt_id}
update: post /v1/prompts/{prompt_id}
delete: delete /v1/prompts/{prompt_id}
set_default_version: post /v1/prompts/{prompt_id}/set-default-version
subresources:
versions:
methods:
list:
endpoint: get /v1/prompts/{prompt_id}/versions
paginated: false
conversations:
models:
conversation_object: Conversation
methods:
create:
type: http
endpoint: post /v1/conversations
retrieve: get /v1/conversations/{conversation_id}
update:
type: http
endpoint: post /v1/conversations/{conversation_id}
delete:
type: http
endpoint: delete /v1/conversations/{conversation_id}
subresources:
items:
methods:
get:
type: http
endpoint: get /v1/conversations/{conversation_id}/items/{item_id}
list:
type: http
endpoint: get /v1/conversations/{conversation_id}/items
create:
type: http
endpoint: post /v1/conversations/{conversation_id}/items
inspect:
models:
healthInfo: HealthInfo
providerInfo: ProviderInfo
routeInfo: RouteInfo
versionInfo: VersionInfo
methods:
health: get /v1/health
version: get /v1/version
embeddings:
models:
create_embeddings_response: OpenAIEmbeddingsResponse
methods:
create: post /v1/embeddings
chat:
models:
chat_completion_chunk: OpenAIChatCompletionChunk
subresources:
completions:
methods:
create:
type: http
endpoint: post /v1/chat/completions
streaming:
stream_event_model: chat.chat_completion_chunk
param_discriminator: stream
list:
type: http
endpoint: get /v1/chat/completions
retrieve:
type: http
endpoint: get /v1/chat/completions/{completion_id}
completions:
methods:
create:
type: http
endpoint: post /v1/completions
streaming:
param_discriminator: stream
vector_io:
models:
queryChunksResponse: QueryChunksResponse
methods:
insert: post /v1/vector-io/insert
query: post /v1/vector-io/query
vector_stores:
models:
vector_store: VectorStoreObject
list_vector_stores_response: VectorStoreListResponse
vector_store_delete_response: VectorStoreDeleteResponse
vector_store_search_response: VectorStoreSearchResponsePage
methods:
create: post /v1/vector_stores
list:
endpoint: get /v1/vector_stores
retrieve: get /v1/vector_stores/{vector_store_id}
update: post /v1/vector_stores/{vector_store_id}
delete: delete /v1/vector_stores/{vector_store_id}
search: post /v1/vector_stores/{vector_store_id}/search
subresources:
files:
models:
vector_store_file: VectorStoreFileObject
methods:
list: get /v1/vector_stores/{vector_store_id}/files
retrieve: get /v1/vector_stores/{vector_store_id}/files/{file_id}
update: post /v1/vector_stores/{vector_store_id}/files/{file_id}
delete: delete /v1/vector_stores/{vector_store_id}/files/{file_id}
create: post /v1/vector_stores/{vector_store_id}/files
content: get /v1/vector_stores/{vector_store_id}/files/{file_id}/content
file_batches:
models:
vector_store_file_batches: VectorStoreFileBatchObject
list_vector_store_files_in_batch_response: VectorStoreFilesListInBatchResponse
methods:
create: post /v1/vector_stores/{vector_store_id}/file_batches
retrieve: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}
list_files: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files
cancel: post /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel
models:
models:
model: OpenAIModel
list_models_response: OpenAIListModelsResponse
methods:
list:
endpoint: get /v1/models
paginated: false
retrieve: get /v1/models/{model_id}
register: post /v1/models
unregister: delete /v1/models/{model_id}
subresources:
openai:
methods:
list:
endpoint: get /v1/models
paginated: false
providers:
models:
list_providers_response: ListProvidersResponse
methods:
list:
endpoint: get /v1/providers
paginated: false
retrieve: get /v1/providers/{provider_id}
routes:
models:
list_routes_response: ListRoutesResponse
methods:
list:
endpoint: get /v1/inspect/routes
paginated: false
moderations:
models:
create_response: ModerationObject
methods:
create: post /v1/moderations
safety:
models:
run_shield_response: RunShieldResponse
methods:
run_shield: post /v1/safety/run-shield
shields:
models:
shield: Shield
list_shields_response: ListShieldsResponse
methods:
retrieve: get /v1/shields/{identifier}
list:
endpoint: get /v1/shields
paginated: false
register: post /v1/shields
delete: delete /v1/shields/{identifier}
scoring:
methods:
score: post /v1/scoring/score
score_batch: post /v1/scoring/score-batch
scoring_functions:
methods:
retrieve: get /v1/scoring-functions/{scoring_fn_id}
list:
endpoint: get /v1/scoring-functions
paginated: false
register: post /v1/scoring-functions
models:
scoring_fn: ScoringFn
scoring_fn_params: ScoringFnParams
list_scoring_functions_response: ListScoringFunctionsResponse
files:
methods:
create: post /v1/files
list: get /v1/files
retrieve: get /v1/files/{file_id}
delete: delete /v1/files/{file_id}
content: get /v1/files/{file_id}/content
models:
file: OpenAIFileObject
list_files_response: ListOpenAIFileResponse
delete_file_response: OpenAIFileDeleteResponse
alpha:
subresources:
inference:
methods:
rerank: post /v1alpha/inference/rerank
post_training:
models:
algorithm_config: AlgorithmConfig
post_training_job: PostTrainingJob
list_post_training_jobs_response: ListPostTrainingJobsResponse
methods:
preference_optimize: post /v1alpha/post-training/preference-optimize
supervised_fine_tune: post /v1alpha/post-training/supervised-fine-tune
subresources:
job:
methods:
artifacts: get /v1alpha/post-training/job/artifacts
cancel: post /v1alpha/post-training/job/cancel
status: get /v1alpha/post-training/job/status
list:
endpoint: get /v1alpha/post-training/jobs
paginated: false
benchmarks:
methods:
retrieve: get /v1alpha/eval/benchmarks/{benchmark_id}
list:
endpoint: get /v1alpha/eval/benchmarks
paginated: false
register: post /v1alpha/eval/benchmarks
models:
benchmark: Benchmark
list_benchmarks_response: ListBenchmarksResponse
eval:
methods:
evaluate_rows: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations
run_eval: post /v1alpha/eval/benchmarks/{benchmark_id}/jobs
evaluate_rows_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations
run_eval_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/jobs
subresources:
jobs:
methods:
cancel: delete /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}
status: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}
retrieve: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result
models:
evaluate_response: EvaluateResponse
benchmark_config: BenchmarkConfig
job: Job
beta:
subresources:
datasets:
models:
list_datasets_response: ListDatasetsResponse
methods:
register: post /v1beta/datasets
retrieve: get /v1beta/datasets/{dataset_id}
list:
endpoint: get /v1beta/datasets
paginated: false
unregister: delete /v1beta/datasets/{dataset_id}
iterrows: get /v1beta/datasetio/iterrows/{dataset_id}
appendrows: post /v1beta/datasetio/append-rows/{dataset_id}
settings:
license: MIT
unwrap_response_fields: [ data ]
openapi:
transformations:
- command: mergeObject
reason: Better return_type using enum
args:
target:
- '$.components.schemas'
object:
ReturnType:
additionalProperties: false
properties:
type:
enum:
- string
- number
- boolean
- array
- object
- json
- union
- chat_completion_input
- completion_input
- agent_turn_input
required:
- type
type: object
- command: replaceProperties
reason: Replace return type properties with better model (see above)
args:
filter:
only:
- '$.components.schemas.ScoringFn.properties.return_type'
- '$.components.schemas.RegisterScoringFunctionRequest.properties.return_type'
value:
$ref: '#/components/schemas/ReturnType'
- command: oneOfToAnyOf
reason: Prism (mock server) doesn't like one of our requests as it technically matches multiple variants
# `readme` is used to configure the code snippets that will be rendered in the
# README.md of various SDKs. In particular, you can change the `headline`
# snippet's endpoint and the arguments to call it with.
readme:
example_requests:
default:
type: request
endpoint: post /v1/chat/completions
params: &ref_0 {}
headline:
type: request
endpoint: post /v1/models
params: *ref_0
pagination:
type: request
endpoint: post /v1/chat/completions
params: {}

File diff suppressed because it is too large Load diff

View file

@ -44,7 +44,7 @@ spec:
# Navigate to the UI directory
echo "Navigating to UI directory..."
cd /app/llama_stack/ui
cd /app/llama_stack_ui
# Check if package.json exists
if [ ! -f "package.json" ]; then

View file

@ -170,7 +170,7 @@ def _get_endpoint_functions(
for webmethod in webmethods:
print(f"Processing {colored(func_name, 'white')}...")
operation_name = func_name
if webmethod.method == "GET":
prefix = "get"
elif webmethod.method == "DELETE":
@ -196,16 +196,10 @@ def _get_endpoint_functions(
def _get_defining_class(member_fn: str, derived_cls: type) -> type:
"Find the class in which a member function is first defined in a class inheritance hierarchy."
# This import must be dynamic here
from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime
# iterate in reverse member resolution order to find most specific class first
for cls in reversed(inspect.getmro(derived_cls)):
for name, _ in inspect.getmembers(cls, inspect.isfunction):
if name == member_fn:
# HACK ALERT
if cls == RAGToolRuntime:
return ToolRuntime
return cls
raise ValidationError(

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -2052,69 +2052,6 @@ paths:
schema:
$ref: '#/components/schemas/URL'
deprecated: false
/v1/tool-runtime/rag-tool/insert:
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Index documents so they can be used by the RAG system.
description: >-
Index documents so they can be used by the RAG system.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/InsertRequest'
required: true
deprecated: false
/v1/tool-runtime/rag-tool/query:
post:
responses:
'200':
description: >-
RAGQueryResult containing the retrieved content and metadata
content:
application/json:
schema:
$ref: '#/components/schemas/RAGQueryResult'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Query the RAG system for context; typically invoked by the agent.
description: >-
Query the RAG system for context; typically invoked by the agent.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/QueryRequest'
required: true
deprecated: false
/v1/toolgroups:
get:
responses:
@ -8137,20 +8074,6 @@ components:
- error
title: ViolationLevel
description: Severity level of a safety violation.
AgentTurnInputType:
type: object
properties:
type:
type: string
const: agent_turn_input
default: agent_turn_input
description: >-
Discriminator type. Always "agent_turn_input"
additionalProperties: false
required:
- type
title: AgentTurnInputType
description: Parameter type for agent turn input.
AggregationFunctionType:
type: string
enum:
@ -8393,7 +8316,6 @@ components:
- $ref: '#/components/schemas/UnionType'
- $ref: '#/components/schemas/ChatCompletionInputType'
- $ref: '#/components/schemas/CompletionInputType'
- $ref: '#/components/schemas/AgentTurnInputType'
discriminator:
propertyName: type
mapping:
@ -8406,7 +8328,6 @@ components:
union: '#/components/schemas/UnionType'
chat_completion_input: '#/components/schemas/ChatCompletionInputType'
completion_input: '#/components/schemas/CompletionInputType'
agent_turn_input: '#/components/schemas/AgentTurnInputType'
params:
$ref: '#/components/schemas/ScoringFnParams'
additionalProperties: false
@ -8487,7 +8408,6 @@ components:
- $ref: '#/components/schemas/UnionType'
- $ref: '#/components/schemas/ChatCompletionInputType'
- $ref: '#/components/schemas/CompletionInputType'
- $ref: '#/components/schemas/AgentTurnInputType'
discriminator:
propertyName: type
mapping:
@ -8500,7 +8420,6 @@ components:
union: '#/components/schemas/UnionType'
chat_completion_input: '#/components/schemas/ChatCompletionInputType'
completion_input: '#/components/schemas/CompletionInputType'
agent_turn_input: '#/components/schemas/AgentTurnInputType'
RegisterScoringFunctionRequest:
type: object
properties:
@ -8935,274 +8854,6 @@ components:
title: ListToolDefsResponse
description: >-
Response containing a list of tool definitions.
RAGDocument:
type: object
properties:
document_id:
type: string
description: The unique identifier for the document.
content:
oneOf:
- type: string
- $ref: '#/components/schemas/InterleavedContentItem'
- type: array
items:
$ref: '#/components/schemas/InterleavedContentItem'
- $ref: '#/components/schemas/URL'
description: The content of the document.
mime_type:
type: string
description: The MIME type of the document.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Additional metadata for the document.
additionalProperties: false
required:
- document_id
- content
- metadata
title: RAGDocument
description: >-
A document to be used for document ingestion in the RAG Tool.
InsertRequest:
type: object
properties:
documents:
type: array
items:
$ref: '#/components/schemas/RAGDocument'
description: >-
List of documents to index in the RAG system
vector_store_id:
type: string
description: >-
ID of the vector database to store the document embeddings
chunk_size_in_tokens:
type: integer
description: >-
(Optional) Size in tokens for document chunking during indexing
additionalProperties: false
required:
- documents
- vector_store_id
- chunk_size_in_tokens
title: InsertRequest
DefaultRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: default
default: default
description: >-
Type of query generator, always 'default'
separator:
type: string
default: ' '
description: >-
String separator used to join query terms
additionalProperties: false
required:
- type
- separator
title: DefaultRAGQueryGeneratorConfig
description: >-
Configuration for the default RAG query generator.
LLMRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: llm
default: llm
description: Type of query generator, always 'llm'
model:
type: string
description: >-
Name of the language model to use for query generation
template:
type: string
description: >-
Template string for formatting the query generation prompt
additionalProperties: false
required:
- type
- model
- template
title: LLMRAGQueryGeneratorConfig
description: >-
Configuration for the LLM-based RAG query generator.
RAGQueryConfig:
type: object
properties:
query_generator_config:
oneOf:
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
- $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
discriminator:
propertyName: type
mapping:
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
description: Configuration for the query generator.
max_tokens_in_context:
type: integer
default: 4096
description: Maximum number of tokens in the context.
max_chunks:
type: integer
default: 5
description: Maximum number of chunks to retrieve.
chunk_template:
type: string
default: >
Result {index}
Content: {chunk.content}
Metadata: {metadata}
description: >-
Template for formatting each retrieved chunk in the context. Available
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
{chunk.content}\nMetadata: {metadata}\n"
mode:
$ref: '#/components/schemas/RAGSearchMode'
default: vector
description: >-
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
"vector".
ranker:
$ref: '#/components/schemas/Ranker'
description: >-
Configuration for the ranker to use in hybrid search. Defaults to RRF
ranker.
additionalProperties: false
required:
- query_generator_config
- max_tokens_in_context
- max_chunks
- chunk_template
title: RAGQueryConfig
description: >-
Configuration for the RAG query generation.
RAGSearchMode:
type: string
enum:
- vector
- keyword
- hybrid
title: RAGSearchMode
description: >-
Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
for semantic matching - KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
RRFRanker:
type: object
properties:
type:
type: string
const: rrf
default: rrf
description: The type of ranker, always "rrf"
impact_factor:
type: number
default: 60.0
description: >-
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
results. Must be greater than 0
additionalProperties: false
required:
- type
- impact_factor
title: RRFRanker
description: >-
Reciprocal Rank Fusion (RRF) ranker configuration.
Ranker:
oneOf:
- $ref: '#/components/schemas/RRFRanker'
- $ref: '#/components/schemas/WeightedRanker'
discriminator:
propertyName: type
mapping:
rrf: '#/components/schemas/RRFRanker'
weighted: '#/components/schemas/WeightedRanker'
WeightedRanker:
type: object
properties:
type:
type: string
const: weighted
default: weighted
description: The type of ranker, always "weighted"
alpha:
type: number
default: 0.5
description: >-
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
only use vector scores, values in between blend both scores.
additionalProperties: false
required:
- type
- alpha
title: WeightedRanker
description: >-
Weighted ranker configuration that combines vector and keyword scores.
QueryRequest:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The query content to search for in the indexed documents
vector_store_ids:
type: array
items:
type: string
description: >-
List of vector database IDs to search within
query_config:
$ref: '#/components/schemas/RAGQueryConfig'
description: >-
(Optional) Configuration parameters for the query operation
additionalProperties: false
required:
- content
- vector_store_ids
title: QueryRequest
RAGQueryResult:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
(Optional) The retrieved content from the query
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
Additional metadata about the query result
additionalProperties: false
required:
- metadata
title: RAGQueryResult
description: >-
Result of a RAG query containing retrieved content and metadata.
ToolGroup:
type: object
properties:

File diff suppressed because it is too large Load diff

View file

@ -6,7 +6,7 @@
# the root directory of this source tree.
set -e
cd src/llama_stack/ui
cd src/llama_stack_ui
if [ ! -d node_modules ] || [ ! -x node_modules/.bin/prettier ] || [ ! -x node_modules/.bin/eslint ]; then
echo "UI dependencies not installed, skipping prettier/linter check"

View file

@ -5,30 +5,13 @@
# the root directory of this source tree.
from collections.abc import AsyncIterator
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from typing import Annotated, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
SamplingParams,
ToolCall,
ToolChoice,
ToolConfig,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
from llama_stack.apis.common.responses import Order
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, webmethod
from .openai_responses import (
ListOpenAIResponseInputItem,
@ -57,658 +40,12 @@ class ResponseGuardrailSpec(BaseModel):
ResponseGuardrail = str | ResponseGuardrailSpec
class Attachment(BaseModel):
"""An attachment to an agent turn.
:param content: The content of the attachment.
:param mime_type: The MIME type of the attachment.
"""
content: InterleavedContent | URL
mime_type: str
class Document(BaseModel):
"""A document to be used by an agent.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
"""
content: InterleavedContent | URL
mime_type: str
class StepCommon(BaseModel):
"""A common step in an agent turn.
:param turn_id: The ID of the turn.
:param step_id: The ID of the step.
:param started_at: The time the step started.
:param completed_at: The time the step completed.
"""
turn_id: str
step_id: str
started_at: datetime | None = None
completed_at: datetime | None = None
class StepType(StrEnum):
"""Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM.
:cvar tool_execution: The step is a tool execution step that executes a tool call.
:cvar shield_call: The step is a shield call step that checks for safety violations.
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
"""
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
memory_retrieval = "memory_retrieval"
@json_schema_type
class InferenceStep(StepCommon):
"""An inference step in an agent turn.
:param model_response: The response from the LLM.
"""
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage
@json_schema_type
class ToolExecutionStep(StepCommon):
"""A tool execution step in an agent turn.
:param tool_calls: The tool calls to execute.
:param tool_responses: The tool responses from the tool calls.
"""
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
@json_schema_type
class ShieldCallStep(StepCommon):
"""A shield call step in an agent turn.
:param violation: The violation from the shield call.
"""
step_type: Literal[StepType.shield_call] = StepType.shield_call
violation: SafetyViolation | None
@json_schema_type
class MemoryRetrievalStep(StepCommon):
"""A memory retrieval step in an agent turn.
:param vector_store_ids: The IDs of the vector databases to retrieve context from.
:param inserted_context: The context retrieved from the vector databases.
"""
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]?
vector_store_ids: str
inserted_context: InterleavedContent
Step = Annotated[
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
Field(discriminator="step_type"),
]
@json_schema_type
class Turn(BaseModel):
"""A single turn in an interaction with an Agentic System.
:param turn_id: Unique identifier for the turn within a session
:param session_id: Unique identifier for the conversation session
:param input_messages: List of messages that initiated this turn
:param steps: Ordered list of processing steps executed during this turn
:param output_message: The model's generated response containing content and metadata
:param output_attachments: (Optional) Files or media attached to the agent's response
:param started_at: Timestamp when the turn began
:param completed_at: (Optional) Timestamp when the turn finished, if completed
"""
turn_id: str
session_id: str
input_messages: list[UserMessage | ToolResponseMessage]
steps: list[Step]
output_message: CompletionMessage
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime
completed_at: datetime | None = None
@json_schema_type
class Session(BaseModel):
"""A single session of an interaction with an Agentic System.
:param session_id: Unique identifier for the conversation session
:param session_name: Human-readable name for the session
:param turns: List of all turns that have occurred in this session
:param started_at: Timestamp when the session was created
"""
session_id: str
session_name: str
turns: list[Turn]
started_at: datetime
class AgentToolGroupWithArgs(BaseModel):
name: str
args: dict[str, Any]
AgentToolGroup = str | AgentToolGroupWithArgs
register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None)
max_infer_iters: int | None = 10
def model_post_init(self, __context):
if self.tool_config:
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
@json_schema_type
class AgentConfig(AgentConfigCommon):
"""Configuration for an agent.
:param model: The model identifier to use for the agent
:param instructions: The system instructions for the agent
:param name: Optional name for the agent, used in telemetry and identification
:param enable_session_persistence: Optional flag indicating whether session data has to be persisted
:param response_format: Optional response format configuration
"""
model: str
instructions: str
name: str | None = None
enable_session_persistence: bool | None = False
response_format: ResponseFormat | None = None
@json_schema_type
class Agent(BaseModel):
"""An agent instance with configuration and metadata.
:param agent_id: Unique identifier for the agent
:param agent_config: Configuration settings for the agent
:param created_at: Timestamp when the agent was created
"""
agent_id: str
agent_config: AgentConfig
created_at: datetime
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None
class AgentTurnResponseEventType(StrEnum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
turn_start = "turn_start"
turn_complete = "turn_complete"
turn_awaiting_input = "turn_awaiting_input"
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
"""Payload for step start events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param metadata: (Optional) Additional metadata for the step
"""
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
step_type: StepType
step_id: str
metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
"""Payload for step completion events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param step_details: Complete details of the executed step
"""
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
step_type: StepType
step_id: str
step_details: Step
@json_schema_type
class AgentTurnResponseStepProgressPayload(BaseModel):
"""Payload for step progress events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param delta: Incremental content changes during step execution
"""
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
step_type: StepType
step_id: str
delta: ContentDelta
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
"""Payload for turn start events in agent turn responses.
:param event_type: Type of event being reported
:param turn_id: Unique identifier for the turn within a session
"""
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
turn_id: str
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
"""Payload for turn completion events in agent turn responses.
:param event_type: Type of event being reported
:param turn: Complete turn data including all steps and results
"""
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
turn: Turn
@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
"""Payload for turn awaiting input events in agent turn responses.
:param event_type: Type of event being reported
:param turn: Turn data when waiting for external tool responses
"""
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
turn: Turn
AgentTurnResponseEventPayload = Annotated[
AgentTurnResponseStepStartPayload
| AgentTurnResponseStepProgressPayload
| AgentTurnResponseStepCompletePayload
| AgentTurnResponseTurnStartPayload
| AgentTurnResponseTurnCompletePayload
| AgentTurnResponseTurnAwaitingInputPayload,
Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@json_schema_type
class AgentTurnResponseEvent(BaseModel):
"""An event in an agent turn response stream.
:param payload: Event-specific payload containing event data
"""
payload: AgentTurnResponseEventPayload
@json_schema_type
class AgentCreateResponse(BaseModel):
"""Response returned when creating a new agent.
:param agent_id: Unique identifier for the created agent
"""
agent_id: str
@json_schema_type
class AgentSessionCreateResponse(BaseModel):
"""Response returned when creating a new agent session.
:param session_id: Unique identifier for the created session
"""
session_id: str
@json_schema_type
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
"""Request to create a new turn for an agent.
:param agent_id: Unique identifier for the agent
:param session_id: Unique identifier for the conversation session
:param messages: List of messages to start the turn with
:param documents: (Optional) List of documents to provide to the agent
:param toolgroups: (Optional) List of tool groups to make available for this turn
:param stream: (Optional) Whether to stream the response
:param tool_config: (Optional) Tool configuration to override agent defaults
"""
agent_id: str
session_id: str
# TODO: figure out how we can simplify this and make why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: list[UserMessage | ToolResponseMessage]
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
stream: bool | None = False
tool_config: ToolConfig | None = None
@json_schema_type
class AgentTurnResumeRequest(BaseModel):
"""Request to resume an agent turn with tool responses.
:param agent_id: Unique identifier for the agent
:param session_id: Unique identifier for the conversation session
:param turn_id: Unique identifier for the turn within a session
:param tool_responses: List of tool responses to submit to continue the turn
:param stream: (Optional) Whether to stream the response
"""
agent_id: str
session_id: str
turn_id: str
tool_responses: list[ToolResponse]
stream: bool | None = False
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
"""Streamed agent turn completion response.
:param event: Individual event in the agent turn response stream
"""
event: AgentTurnResponseEvent
@json_schema_type
class AgentStepResponse(BaseModel):
"""Response containing details of a specific agent step.
:param step: The complete step data and execution details
"""
step: Step
@runtime_checkable
class Agents(Protocol):
"""Agents
APIs for creating and interacting with agentic systems."""
@webmethod(
route="/agents",
method="POST",
descriptive_name="create_agent",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse:
"""Create an agent with the given configuration.
:param agent_config: The configuration for the agent.
:returns: An AgentCreateResponse with the agent ID.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
descriptive_name="create_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Create a new turn for an agent.
:param agent_id: The ID of the agent to create the turn for.
:param session_id: The ID of the session to create the turn for.
:param messages: List of messages to start the turn with.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param documents: (Optional) List of documents to create the turn with.
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
stream: bool | None = False,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
:param agent_id: The ID of the agent to resume.
:param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with.
:param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
) -> Turn:
"""Retrieve an agent turn by its ID.
:param agent_id: The ID of the agent to get the turn for.
:param session_id: The ID of the session to get the turn for.
:param turn_id: The ID of the turn to get.
:returns: A Turn.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_step(
self,
agent_id: str,
session_id: str,
turn_id: str,
step_id: str,
) -> AgentStepResponse:
"""Retrieve an agent step by its ID.
:param agent_id: The ID of the agent to get the step for.
:param session_id: The ID of the session to get the step for.
:param turn_id: The ID of the turn to get the step for.
:param step_id: The ID of the step to get.
:returns: An AgentStepResponse.
"""
...
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
descriptive_name="create_agent_session",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
"""Create a new session for an agent.
:param agent_id: The ID of the agent to create the session for.
:param session_name: The name of the session to create.
:returns: An AgentSessionCreateResponse.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_session(
self,
session_id: str,
agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
"""Retrieve an agent session by its ID.
:param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by.
:returns: A Session.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="DELETE",
level=LLAMA_STACK_API_V1ALPHA,
)
async def delete_agents_session(
self,
session_id: str,
agent_id: str,
) -> None:
"""Delete an agent session by its ID and its associated turns.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
"""
...
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def delete_agent(
self,
agent_id: str,
) -> None:
"""Delete an agent by its ID and its associated sessions and turns.
:param agent_id: The ID of the agent to delete.
"""
...
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
:param start_index: The index to start the pagination from.
:param limit: The number of agents to return.
:returns: A PaginatedResponse.
"""
...
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
:param agent_id: ID of the agent.
:returns: An Agent of the agent.
"""
...
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agent_sessions(
self,
agent_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:param start_index: The index to start the pagination from.
:param limit: The number of sessions to return.
:returns: A PaginatedResponse.
"""
...
# We situate the OpenAI Responses API in the Agents API just like we did things
# for Inference. The Responses API, in its intent, serves the same purpose as
# the Agents API above -- it is essentially a lightweight "agentic loop" with

View file

@ -56,14 +56,6 @@ class ToolGroupNotFoundError(ResourceNotFoundError):
super().__init__(toolgroup_name, "Tool Group", "client.toolgroups.list()")
class SessionNotFoundError(ValueError):
"""raised when Llama Stack cannot find a referenced session or access is denied"""
def __init__(self, session_name: str) -> None:
message = f"Session '{session_name}' not found or access denied."
super().__init__(message)
class ModelTypeError(TypeError):
"""raised when a model is present but not the correct type"""

View file

@ -103,17 +103,6 @@ class CompletionInputType(BaseModel):
type: Literal["completion_input"] = "completion_input"
@json_schema_type
class AgentTurnInputType(BaseModel):
"""Parameter type for agent turn input.
:param type: Discriminator type. Always "agent_turn_input"
"""
# expects List[Message] for messages (may also include attachments?)
type: Literal["agent_turn_input"] = "agent_turn_input"
@json_schema_type
class DialogType(BaseModel):
"""Parameter type for dialog data with semantic output labels.
@ -135,8 +124,7 @@ ParamType = Annotated[
| JsonType
| UnionType
| ChatCompletionInputType
| CompletionInputType
| AgentTurnInputType,
| CompletionInputType,
Field(discriminator="type"),
]
register_schema(ParamType, name="ParamType")

View file

@ -4,17 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Any, Literal, Protocol
from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
@ -32,19 +31,7 @@ class ModelCandidate(BaseModel):
system_message: SystemMessage | None = None
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
config: AgentConfig
EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
EvalCandidate = ModelCandidate
@json_schema_type

View file

@ -5,18 +5,13 @@
# the root directory of this source tree.
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, field_validator
from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class RRFRanker(BaseModel):
"""
Reciprocal Rank Fusion (RRF) ranker configuration.
@ -30,7 +25,6 @@ class RRFRanker(BaseModel):
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
@json_schema_type
class WeightedRanker(BaseModel):
"""
Weighted ranker configuration that combines vector and keyword scores.
@ -55,10 +49,8 @@ Ranker = Annotated[
RRFRanker | WeightedRanker,
Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")
@json_schema_type
class RAGDocument(BaseModel):
"""
A document to be used for document ingestion in the RAG Tool.
@ -75,7 +67,6 @@ class RAGDocument(BaseModel):
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryResult(BaseModel):
"""Result of a RAG query containing retrieved content and metadata.
@ -87,7 +78,6 @@ class RAGQueryResult(BaseModel):
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryGenerator(Enum):
"""Types of query generators for RAG systems.
@ -101,7 +91,6 @@ class RAGQueryGenerator(Enum):
custom = "custom"
@json_schema_type
class RAGSearchMode(StrEnum):
"""
Search modes for RAG query retrieval:
@ -115,7 +104,6 @@ class RAGSearchMode(StrEnum):
HYBRID = "hybrid"
@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the default RAG query generator.
@ -127,7 +115,6 @@ class DefaultRAGQueryGeneratorConfig(BaseModel):
separator: str = " "
@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the LLM-based RAG query generator.
@ -145,10 +132,8 @@ RAGQueryGeneratorConfig = Annotated[
DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type
class RAGQueryConfig(BaseModel):
"""
Configuration for the RAG query generation.
@ -181,38 +166,3 @@ class RAGQueryConfig(BaseModel):
if len(v) == 0:
raise ValueError("chunk_template must not be empty")
return v
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert(
self,
documents: list[RAGDocument],
vector_store_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
"""Index documents so they can be used by the RAG system.
:param documents: List of documents to index in the RAG system
:param vector_store_id: ID of the vector database to store the document embeddings
:param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
"""
...
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
async def query(
self,
content: InterleavedContent,
vector_store_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent.
:param content: The query content to search for in the indexed documents
:param vector_store_ids: List of vector database IDs to search within
:param query_config: (Optional) Configuration parameters for the query operation
:returns: RAGQueryResult containing the retrieved content and metadata
"""
...

View file

@ -16,8 +16,6 @@ from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime
@json_schema_type
class ToolDef(BaseModel):
@ -195,8 +193,6 @@ class SpecialToolGroup(Enum):
class ToolRuntime(Protocol):
tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_runtime_tools(

View file

@ -253,7 +253,7 @@ class StackRun(Subcommand):
)
return
ui_dir = REPO_ROOT / "llama_stack" / "ui"
ui_dir = REPO_ROOT / "llama_stack_ui"
logs_dir = Path("~/.llama/ui/logs").expanduser()
try:
# Create logs directory if it doesn't exist

View file

@ -8,14 +8,9 @@ from typing import Any
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
)
from llama_stack.log import get_logger
@ -26,36 +21,6 @@ logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_store_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_store_ids, query_config)
async def insert(
self,
documents: list[RAGDocument],
vector_store_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
@ -63,11 +28,6 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass

View file

@ -13,7 +13,6 @@ from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod
@ -25,33 +24,16 @@ RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map():
return {
SpecialToolGroup.rag_tool: RAGToolRuntime,
}
def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {}
protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods:
# Get all webmethods for this method (supports multiple decorators)
webmethods = getattr(method, "__webmethods__", [])

View file

@ -31,7 +31,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
@ -78,7 +78,6 @@ class LlamaStack(
Inspect,
ToolGroups,
ToolRuntime,
RAGToolRuntime,
Files,
Prompts,
Conversations,

View file

@ -4,21 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
@ -26,19 +14,12 @@ from llama_stack.apis.agents import (
OpenAIResponseInputTool,
OpenAIResponseObject,
Order,
Session,
Turn,
)
from llama_stack.apis.agents.agents import ResponseGuardrail
from llama_stack.apis.agents.openai_responses import OpenAIResponsePrompt, OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
@ -46,12 +27,9 @@ from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = get_logger(name=__name__, category="agents::meta_reference")
@ -97,229 +75,6 @@ class MetaReferenceAgentsImpl(Agents):
conversations_api=self.conversations_api,
)
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
created_at = datetime.now(UTC)
agent_info = AgentInfo(
**agent_config.model_dump(),
created_at=created_at,
)
# Store the agent info
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_info.model_dump_json(),
)
return AgentCreateResponse(
agent_id=agent_id,
)
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_info_json = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
if not agent_info_json:
raise ValueError(f"Could not find agent info for {agent_id}")
try:
agent_info = AgentInfo.model_validate_json(agent_info_json)
except Exception as e:
raise ValueError(f"Could not validate agent info for {agent_id}") from e
return ChatAgent(
agent_id=agent_id,
agent_config=agent_info,
inference_api=self.inference_api,
safety_api=self.safety_api,
vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at.isoformat(),
policy=self.policy,
telemetry_enabled=self.telemetry_enabled,
)
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
agent = await self._get_agent_impl(agent_id)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
session_id=session_id,
)
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,
stream=True,
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
)
if stream:
return self._create_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _create_agent_turn_streaming(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.create_and_execute_turn(request):
yield event
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
stream: bool | None = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
agent_id=agent_id,
session_id=session_id,
turn_id=turn_id,
tool_responses=tool_responses,
stream=stream,
)
if stream:
return self._continue_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _continue_agent_turn_streaming(
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
if turn is None:
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
turn = await self.get_agents_turn(agent_id, session_id, turn_id)
for step in turn.steps:
if step.step_id == step_id:
return AgentStepResponse(step=step)
raise ValueError(f"Provided step_id {step_id} could not be found")
async def get_agents_session(
self,
session_id: str,
agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id)
if turn_ids:
turns = [turn for turn in turns if turn.turn_id in turn_ids]
return Session(
session_name=session_info.session_name,
session_id=session_id,
turns=turns,
started_at=session_info.started_at,
)
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
agent = await self._get_agent_impl(agent_id)
# Delete turns first, then the session
await agent.storage.delete_session_turns(session_id)
await agent.storage.delete_session(session_id)
async def delete_agent(self, agent_id: str) -> None:
# First get all sessions for this agent
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Delete all sessions
for session in sessions:
await self.delete_agents_session(agent_id, session.session_id)
# Finally delete the agent itself
await self.persistence_store.delete(f"agent:{agent_id}")
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
agent_list: list[Agent] = []
for agent_key in agent_keys:
agent_id = agent_key.split(":")[1]
# Get the agent info using the key
agent_info_json = await self.persistence_store.get(agent_key)
if not agent_info_json:
logger.error(f"Could not find agent info for key {agent_key}")
continue
try:
agent_info = AgentInfo.model_validate_json(agent_info_json)
agent_list.append(
Agent(
agent_id=agent_id,
agent_config=agent_info,
created_at=agent_info.created_at,
)
)
except Exception as e:
logger.error(f"Error parsing agent info for {agent_id}: {e}")
continue
# Convert Agent objects to dictionaries
agent_dicts = [agent.model_dump() for agent in agent_list]
return paginate_records(agent_dicts, start_index, limit)
async def get_agent(self, agent_id: str) -> Agent:
chat_agent = await self._get_agent_impl(agent_id)
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=datetime.fromisoformat(chat_agent.created_at),
)
return agent
async def list_agent_sessions(
self, agent_id: str, start_index: int | None = None, limit: int | None = None
) -> PaginatedResponse:
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Convert Session objects to dictionaries
session_dicts = [session.model_dump() for session in sessions]
return paginate_records(session_dicts, start_index, limit)
async def shutdown(self) -> None:
pass

View file

@ -1,261 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.conditions import User as ProtocolUser
from llama_stack.core.access_control.datatypes import AccessRule, Action
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
log = get_logger(name=__name__, category="agents::meta_reference")
class AgentSessionInfo(Session):
# TODO: is this used anywhere?
vector_store_id: str | None = None
started_at: datetime
owner: User | None = None
identifier: str | None = None
type: str = "session"
class AgentInfo(AgentConfig):
created_at: datetime
@dataclass
class SessionResource:
"""Concrete implementation of ProtectedResource for session access control."""
type: str
identifier: str
owner: ProtocolUser # Use the protocol type for structural compatibility
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id
self.kvstore = kvstore
self.policy = policy
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions
user = get_authenticated_user()
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(UTC),
owner=user,
turns=[],
identifier=name, # should this be qualified in any way?
)
# Only perform access control if we have an authenticated user
if user is not None and session_info.identifier is not None:
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=user,
)
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
raise AccessDeniedError(Action.CREATE, resource, user)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
return session_id
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}",
)
if not value:
raise SessionNotFoundError(session_id)
session_info = AgentSessionInfo(**json.loads(value))
# Check access to session
if not self._check_session_access(session_info):
return None
return session_info
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True
# Get current user - if None, skip access control (e.g., in tests)
user = get_authenticated_user()
if user is None:
return True
# Access control requires identifier and owner to be set
if session_info.identifier is None or session_info.owner is None:
return True
# At this point, both identifier and owner are guaranteed to be non-None
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=session_info.owner,
)
return is_action_allowed(self.policy, Action.READ, resource, user)
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""
session_info = await self.get_session_info(session_id)
if not session_info:
return None
return session_info
async def add_vector_db_to_session(self, session_id: str, vector_store_id: str):
session_info = await self.get_session_if_accessible(session_id)
if session_info is None:
raise SessionNotFoundError(session_id)
session_info.vector_store_id = vector_store_id
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
async def add_turn_to_session(self, session_id: str, turn: Turn):
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.model_dump_json(),
)
async def get_session_turns(self, session_id: str) -> list[Turn]:
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
turns = []
for value in values:
try:
turn = Turn(**json.loads(value))
turns.append(turn)
except Exception as e:
log.error(f"Error parsing turn: {e}")
continue
# The kvstore does not guarantee order, so we sort by started_at
# to ensure consistent ordering of turns.
turns.sort(key=lambda t: t.started_at)
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
)
if not value:
return None
return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters),
)
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)
return int(value) if value else None
async def list_sessions(self) -> list[Session]:
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:",
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
)
sessions = []
for value in values:
try:
data = json.loads(value)
if "turn_id" in data:
continue
session_info = Session(**data)
sessions.append(session_info)
except Exception as e:
log.error(f"Error parsing session info: {e}")
continue
return sessions
async def delete_session_turns(self, session_id: str) -> None:
"""Delete all turns and their associated data for a session.
Args:
session_id: The ID of the session whose turns should be deleted.
"""
turns = await self.get_session_turns(session_id)
for turn in turns:
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
async def delete_session(self, session_id: str) -> None:
"""Delete a session and all its associated turns.
Args:
session_id: The ID of the session to delete.
Raises:
ValueError: If the session does not exist.
"""
session_info = await self.get_session_info(session_id)
if session_info is None:
raise SessionNotFoundError(session_id)
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")

View file

@ -8,7 +8,7 @@ from typing import Any
from tqdm import tqdm
from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@ -18,13 +18,9 @@ from llama_stack.apis.inference import (
OpenAICompletionRequestWithExtraBody,
OpenAISystemMessageParam,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -118,49 +114,6 @@ class MetaReferenceEvalImpl(
self.jobs[job_id] = res
return Job(job_id=job_id, status=JobStatus.completed)
async def _run_agent_generation(
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
) -> list[dict[str, Any]]:
candidate = benchmark_config.eval_candidate
create_response = await self.agents_api.create_agent(candidate.config)
agent_id = create_response.agent_id
generations = []
for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
session_id = session_create_response.session_id
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=input_messages,
stream=True,
)
turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
final_event = turn_response[-1].event.payload
# check if there's a memory retrieval step and extract the context
memory_rag_context = None
for step in final_event.turn.steps:
if step.step_type == StepType.tool_execution.value:
for tool_response in step.tool_responses:
if tool_response.tool_name == MEMORY_QUERY_TOOL:
memory_rag_context = " ".join(x.text for x in tool_response.content)
agent_generation = {}
agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
if memory_rag_context:
agent_generation[ColumnName.context.value] = memory_rag_context
generations.append(agent_generation)
return generations
async def _run_model_generation(
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
) -> list[dict[str, Any]]:
@ -215,9 +168,8 @@ class MetaReferenceEvalImpl(
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
candidate = benchmark_config.eval_candidate
if candidate.type == "agent":
generations = await self._run_agent_generation(input_rows, benchmark_config)
elif candidate.type == "model":
# Agent evaluation removed
if candidate.type == "model":
generations = await self._run_model_generation(input_rows, benchmark_config)
else:
raise ValueError(f"Invalid candidate type: {candidate.type}")

View file

@ -27,7 +27,6 @@ from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolGroup,
ToolInvocationResult,
@ -91,7 +90,7 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
return content_str.encode("utf-8"), "text/plain"
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
def __init__(
self,
config: RagToolRuntimeConfig,

Some files were not shown because too many files have changed in this diff Show more