Merge branch 'main' into fix-tool-call-args

Ashwin Bharambe 2025-09-30 14:59:22 -07:00 committed by GitHub
commit cbc1b6889e
89 changed files with 14920 additions and 2301 deletions


@ -27,7 +27,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from .openai_responses import (
@ -482,7 +482,10 @@ class Agents(Protocol):
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
"""
@webmethod(route="/agents", method="POST", descriptive_name="create_agent", level=LLAMA_STACK_API_V1)
@webmethod(
route="/agents", method="POST", descriptive_name="create_agent", deprecated=True, level=LLAMA_STACK_API_V1
)
@webmethod(route="/agents", method="POST", descriptive_name="create_agent", level=LLAMA_STACK_API_V1ALPHA)
async def create_agent(
self,
agent_config: AgentConfig,
@ -498,8 +501,15 @@ class Agents(Protocol):
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
descriptive_name="create_agent_turn",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
descriptive_name="create_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_turn(
self,
agent_id: str,
@ -528,8 +538,15 @@ class Agents(Protocol):
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def resume_agent_turn(
self,
agent_id: str,
@ -554,8 +571,14 @@ class Agents(Protocol):
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_turn(
self,
agent_id: str,
@ -574,8 +597,14 @@ class Agents(Protocol):
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_step(
self,
agent_id: str,
@ -597,8 +626,15 @@ class Agents(Protocol):
route="/agents/{agent_id}/session",
method="POST",
descriptive_name="create_agent_session",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
descriptive_name="create_agent_session",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_session(
self,
agent_id: str,
@ -612,7 +648,8 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_agents_session(
self,
session_id: str,
@ -628,7 +665,10 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(
route="/agents/{agent_id}/session/{session_id}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1
)
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def delete_agents_session(
self,
session_id: str,
@ -641,7 +681,8 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def delete_agent(
self,
agent_id: str,
@ -652,7 +693,8 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(route="/agents", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
@ -662,7 +704,8 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
@ -671,7 +714,8 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}/sessions", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agent_sessions(
self,
agent_id: str,
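
Taken together, these hunks keep every `/v1` agents route alive but flag it as deprecated, while registering the same handlers a second time under `LLAMA_STACK_API_V1ALPHA`. A hedged sketch of the migration from the caller's side, assuming the server mounts API levels as URL prefixes, listens on the default local port, and wraps the single parameter as `agent_config` (all of which are assumptions about the deployment, not part of this diff):

```python
# Hedged sketch: the same create_agent handler is now reachable at both levels.
# Assumptions: path-prefix mounting of levels, default local port 8321, and a
# request body keyed by the parameter name "agent_config".
import httpx

BASE_URL = "http://localhost:8321"  # assumed local deployment

agent_config = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "instructions": "You are a helpful assistant.",
}

with httpx.Client(base_url=BASE_URL) as client:
    # Still served, but marked deprecated by the first decorator.
    legacy = client.post("/v1/agents", json={"agent_config": agent_config})

    # The replacement route registered by the second decorator.
    current = client.post("/v1alpha/agents", json={"agent_config": agent_config})

    print(legacy.status_code, current.status_code)
```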


@ -276,13 +276,40 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
tools: list[MCPListToolsTool]
@json_schema_type
class OpenAIResponseMCPApprovalRequest(BaseModel):
"""
A request for human approval of a tool invocation.
"""
arguments: str
id: str
name: str
server_label: str
type: Literal["mcp_approval_request"] = "mcp_approval_request"
@json_schema_type
class OpenAIResponseMCPApprovalResponse(BaseModel):
"""
A response to an MCP approval request.
"""
approval_request_id: str
approve: bool
type: Literal["mcp_approval_response"] = "mcp_approval_response"
id: str | None = None
reason: str | None = None
OpenAIResponseOutput = Annotated[
OpenAIResponseMessage
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
| OpenAIResponseOutputMessageMCPListTools
| OpenAIResponseMCPApprovalRequest,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@ -723,6 +750,8 @@ OpenAIResponseInput = Annotated[
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseInputFunctionToolCallOutput
| OpenAIResponseMCPApprovalRequest
| OpenAIResponseMCPApprovalResponse
|
# Fallback to the generic message type as a last resort
OpenAIResponseMessage,
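
The two new models are ordinary discriminated Pydantic types, which is what lets them join the `OpenAIResponseOutput` and `OpenAIResponseInput` unions here. A minimal sketch of constructing them (IDs and arguments are illustrative):

```python
# Minimal sketch using the models added above; IDs and arguments are illustrative.
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseMCPApprovalRequest,
    OpenAIResponseMCPApprovalResponse,
)

# Emitted by the server as an output item when a tool call needs sign-off.
request = OpenAIResponseMCPApprovalRequest(
    arguments='{"city": "San Francisco"}',
    id="approval_123",
    name="get_weather",
    server_label="weather-mcp",
)

# Sent back by the caller as an input item on the next turn.
answer = OpenAIResponseMCPApprovalResponse(
    approval_request_id=request.id,
    approve=True,
)

assert request.type == "mcp_approval_request"
assert answer.type == "mcp_approval_response"
```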


@ -1030,7 +1030,6 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/inference/chat-completion", method="POST", level=LLAMA_STACK_API_V1)
async def chat_completion(
self,
model_id: str,


@ -318,7 +318,8 @@ class VectorStoreChunkingStrategyStatic(BaseModel):
VectorStoreChunkingStrategy = Annotated[
VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic, Field(discriminator="type")
VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic,
Field(discriminator="type"),
]
register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")
@ -427,6 +428,44 @@ class VectorStoreFileDeleteResponse(BaseModel):
deleted: bool = True
@json_schema_type
class VectorStoreFileBatchObject(BaseModel):
"""OpenAI Vector Store File Batch object.
:param id: Unique identifier for the file batch
:param object: Object type identifier, always "vector_store.file_batch"
:param created_at: Timestamp when the file batch was created
:param vector_store_id: ID of the vector store containing the file batch
:param status: Current processing status of the file batch
:param file_counts: File processing status counts for the batch
"""
id: str
object: str = "vector_store.file_batch"
created_at: int
vector_store_id: str
status: VectorStoreFileStatus
file_counts: VectorStoreFileCounts
@json_schema_type
class VectorStoreFilesListInBatchResponse(BaseModel):
"""Response from listing files in a vector store file batch.
:param object: Object type identifier, always "list"
:param data: List of vector store file objects in the batch
:param first_id: (Optional) ID of the first file in the list for pagination
:param last_id: (Optional) ID of the last file in the list for pagination
:param has_more: Whether there are more files available beyond this page
"""
object: str = "list"
data: list[VectorStoreFileObject]
first_id: str | None = None
last_id: str | None = None
has_more: bool = False
class VectorDBStore(Protocol):
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
@ -529,7 +568,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/vector_stores/{vector_store_id}", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(
route="/vector_stores/{vector_store_id}",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_update_vector_store(
self,
vector_store_id: str,
@ -547,7 +590,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/vector_stores/{vector_store_id}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(
route="/vector_stores/{vector_store_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
@ -559,7 +606,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/vector_stores/{vector_store_id}/search", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(
route="/vector_stores/{vector_store_id}/search",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
@ -568,7 +619,9 @@ class VectorIO(Protocol):
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
search_mode: (
str | None
) = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
) -> VectorStoreSearchResponsePage:
"""Search for chunks in a vector store.
@ -585,7 +638,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/vector_stores/{vector_store_id}/files", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
@ -603,7 +660,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/vector_stores/{vector_store_id}/files", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
@ -625,7 +686,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/vector_stores/{vector_store_id}/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
@ -657,7 +722,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/vector_stores/{vector_store_id}/files/{file_id}", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
@ -673,7 +742,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/vector_stores/{vector_store_id}/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
@ -686,3 +759,89 @@ class VectorIO(Protocol):
:returns: A VectorStoreFileDeleteResponse indicating the deletion status.
"""
...
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileBatchObject:
"""Create a vector store file batch.
:param vector_store_id: The ID of the vector store to create the file batch for.
:param file_ids: A list of File IDs that the vector store should use.
:param attributes: (Optional) Key-value attributes to store with the files.
:param chunking_strategy: (Optional) The chunking strategy used to chunk the file(s). Defaults to auto.
:returns: A VectorStoreFileBatchObject representing the created file batch.
"""
...
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Retrieve a vector store file batch.
:param batch_id: The ID of the file batch to retrieve.
:param vector_store_id: The ID of the vector store containing the file batch.
:returns: A VectorStoreFileBatchObject representing the file batch.
"""
...
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
"""Returns a list of vector store files in a batch.
:param batch_id: The ID of the file batch to list files from.
:param vector_store_id: The ID of the vector store containing the file batch.
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list.
:param before: A cursor for use in pagination. `before` is an object ID that defines your place in the list.
:param filter: Filter by file status. One of in_progress, completed, failed, cancelled.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
:returns: A VectorStoreFilesListInBatchResponse containing the list of files in the batch.
"""
...
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Cancels a vector store file batch.
:param batch_id: The ID of the file batch to cancel.
:param vector_store_id: The ID of the vector store containing the file batch.
:returns: A VectorStoreFileBatchObject representing the cancelled file batch.
"""
...
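
These four webmethods mirror the OpenAI vector store file-batch endpoints: create a batch, poll it, list its files, or cancel it. A sketch of driving them through any wired-up `VectorIO` implementation (for example the `VectorIORouter` further down in this diff); the IDs, polling cadence, and the `"in_progress"` status string are illustrative assumptions:

```python
# Sketch of exercising the new endpoints against a VectorIO implementation.
import asyncio


async def ingest_files(vector_io, vector_store_id: str, file_ids: list[str]) -> None:
    batch = await vector_io.openai_create_vector_store_file_batch(
        vector_store_id=vector_store_id,
        file_ids=file_ids,
    )

    # Poll until the batch finishes processing.
    while batch.status == "in_progress":
        await asyncio.sleep(1)
        batch = await vector_io.openai_retrieve_vector_store_file_batch(
            batch_id=batch.id,
            vector_store_id=vector_store_id,
        )

    # List only the files that completed successfully.
    files = await vector_io.openai_list_files_in_vector_store_file_batch(
        batch_id=batch.id,
        vector_store_id=vector_store_id,
        filter="completed",
    )
    print(batch.status, [f.id for f in files.data])
```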


@ -8,9 +8,7 @@ import asyncio
import uuid
from typing import Any
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
@ -19,9 +17,11 @@ from llama_stack.apis.vector_io import (
VectorIO,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileBatchObject,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFilesListInBatchResponse,
VectorStoreFileStatus,
VectorStoreListResponse,
VectorStoreObject,
@ -193,7 +193,10 @@ class VectorIORouter(VectorIO):
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store.id == before), len(all_stores))
before_index = next(
(i for i, store in enumerate(all_stores) if store.id == before),
len(all_stores),
)
all_stores = all_stores[:before_index]
# Apply limit
@ -363,3 +366,61 @@ class VectorIORouter(VectorIO):
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(file_ids)} files")
return await self.routing_table.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id,
file_ids=file_ids,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
after=after,
before=before,
filter=filter,
limit=limit,
order=order,
)
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_cancel_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)


@ -159,7 +159,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:


@ -50,7 +50,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:


@ -46,7 +46,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:


@ -61,7 +61,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:


@ -51,7 +51,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:


@ -53,7 +53,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:


@ -48,7 +48,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:


@ -81,7 +81,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:


@ -159,7 +159,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:


@ -159,7 +159,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:


@ -237,6 +237,7 @@ class OpenAIResponsesImpl:
response_tools=tools,
temperature=temperature,
response_format=response_format,
inputs=input,
)
# Create orchestrator and delegate streaming logic


@ -10,10 +10,12 @@ from typing import Any
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
ApprovalFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
@ -127,13 +129,16 @@ class StreamingResponseOrchestrator:
messages = self.ctx.messages.copy()
while True:
# Text is the default response format for chat completion, so we don't need to pass it
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
completion_result = await self.inference_api.openai_chat_completion(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=self.ctx.response_format,
response_format=response_format,
)
# Process streaming chunks and build complete response
@ -147,10 +152,17 @@ class StreamingResponseOrchestrator:
raise ValueError("Streaming chunk processor failed to return completion data")
current_response = self._build_chat_completion(completion_result_data)
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
current_response, messages
)
# add any approval requests required
for tool_call in approvals:
async for evt in self._add_mcp_approval_request(
tool_call.function.name, tool_call.function.arguments, output_messages
):
yield evt
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
@ -194,10 +206,11 @@ class StreamingResponseOrchestrator:
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
"""Separate tool calls into function and non-function categories."""
function_tool_calls = []
non_function_tool_calls = []
approvals = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
@ -208,9 +221,23 @@ class StreamingResponseOrchestrator:
if is_function_tool_call(tool_call, self.ctx.response_tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
if self._approval_required(tool_call.function.name):
approval_response = self.ctx.approval_response(
tool_call.function.name, tool_call.function.arguments
)
if approval_response:
if approval_response.approve:
logger.info(f"Approval granted for {tool_call.id} on {tool_call.function.name}")
non_function_tool_calls.append(tool_call)
else:
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
else:
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
approvals.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
return function_tool_calls, non_function_tool_calls, next_turn_messages
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
@ -649,3 +676,46 @@ class StreamingResponseOrchestrator:
# TODO: Emit mcp_list_tools.failed event if needed
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
raise
def _approval_required(self, tool_name: str) -> bool:
if tool_name not in self.mcp_tool_to_server:
return False
mcp_server = self.mcp_tool_to_server[tool_name]
if mcp_server.require_approval == "always":
return True
if mcp_server.require_approval == "never":
return False
if isinstance(mcp_server, ApprovalFilter):
if tool_name in mcp_server.always:
return True
if tool_name in mcp_server.never:
return False
return True
async def _add_mcp_approval_request(
self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
mcp_server = self.mcp_tool_to_server[tool_name]
mcp_approval_request = OpenAIResponseMCPApprovalRequest(
arguments=arguments,
id=f"approval_{uuid.uuid4()}",
name=tool_name,
server_label=mcp_server.server_label,
)
output_messages.append(mcp_approval_request)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_approval_request,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_approval_request,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
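
The net effect of the orchestrator changes: when an MCP server demands approval, the pending tool call is surfaced as an `mcp_approval_request` output item instead of being executed, and the tool only runs on a later turn if the caller sends back a matching `mcp_approval_response`. A hedged sketch of that round trip from the client side; the `client.responses.create` surface and `previous_response_id` chaining are assumptions about the caller, while the item shapes come from this diff:

```python
# Illustrative round trip. `client` is assumed to be an OpenAI-Responses-style
# client pointed at the stack; the MCP server URL is hypothetical.
mcp_tool = {
    "type": "mcp",
    "server_label": "calendar",
    "server_url": "http://localhost:3000/sse",
    "require_approval": "always",
}

first = client.responses.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    input="What's on my calendar today?",
    tools=[mcp_tool],
)

# Instead of tool results, the output now contains approval requests.
approvals = [item for item in first.output if item.type == "mcp_approval_request"]

# Approve them and continue the same conversation.
second = client.responses.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    previous_response_id=first.id,
    input=[
        {"type": "mcp_approval_response", "approval_request_id": req.id, "approve": True}
        for req in approvals
    ],
    tools=[mcp_tool],
)
```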


@ -10,7 +10,10 @@ from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseObjectStream,
OpenAIResponseOutput,
)
@ -58,3 +61,37 @@ class ChatCompletionContext(BaseModel):
chat_tools: list[ChatCompletionToolParam] | None = None
temperature: float | None
response_format: OpenAIResponseFormatParam
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
def __init__(
self,
model: str,
messages: list[OpenAIMessageParam],
response_tools: list[OpenAIResponseInputTool] | None,
temperature: float | None,
response_format: OpenAIResponseFormatParam,
inputs: list[OpenAIResponseInput] | str,
):
super().__init__(
model=model,
messages=messages,
response_tools=response_tools,
temperature=temperature,
response_format=response_format,
)
if not isinstance(inputs, str):
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
self.approval_responses = {
input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
}
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
request = self._approval_request(tool_name, arguments)
return self.approval_responses.get(request.id, None) if request else None
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
for request in self.approval_requests:
if request.name == tool_name and request.arguments == arguments:
return request
return None
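
The context pairs an earlier approval request with its response by tool name plus the exact argument string, which is how `_separate_tool_calls` in the orchestrator knows whether a pending call was already approved or denied. A small sketch under the assumptions noted in the comments (notably the `OpenAIResponseFormatText` import):

```python
# Sketch of the matching behaviour added above. ChatCompletionContext is the
# class from this hunk (import it from this module); OpenAIResponseFormatText
# is assumed to be the plain-text response-format param in the inference API.
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseMCPApprovalRequest,
    OpenAIResponseMCPApprovalResponse,
)
from llama_stack.apis.inference import OpenAIResponseFormatText

request = OpenAIResponseMCPApprovalRequest(
    arguments='{"city": "Paris"}',
    id="approval_abc",
    name="get_weather",
    server_label="weather-mcp",
)
answer = OpenAIResponseMCPApprovalResponse(approval_request_id="approval_abc", approve=True)

ctx = ChatCompletionContext(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[],
    response_tools=None,
    temperature=None,
    response_format=OpenAIResponseFormatText(),
    inputs=[request, answer],
)

# Same tool name and identical argument string: the stored approval is found.
assert ctx.approval_response("get_weather", '{"city": "Paris"}').approve is True
# Different arguments do not match, so no approval is returned.
assert ctx.approval_response("get_weather", '{"city": "Rome"}') is None
```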


@ -13,6 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
@ -149,6 +151,11 @@ async def convert_response_input_to_chat_messages(
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
# the tool list will be handled separately
pass
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
input_item, OpenAIResponseMCPApprovalResponse
):
# these are handled by the responses impl itself and are not passed through to chat completions
pass
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)


@ -9,7 +9,7 @@ import uuid
from pathlib import Path
from typing import Annotated
from fastapi import File, Form, Response, UploadFile
from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
@ -87,7 +88,7 @@ class LocalfsFilesImpl(Files):
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
"""Upload a file that can be used across various endpoints."""
if not self.sql_store:


@ -290,13 +290,13 @@ class LlamaGuardShield:
else:
shield_input_message = self.build_text_shield_input(messages)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
response = await self.inference_api.chat_completion(
model_id=self.model,
response = await self.inference_api.openai_chat_completion(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
content = response.completion_message.content
content = response.choices[0].message.content
content = content.strip()
return self.get_shield_response(content)


@ -30,7 +30,7 @@ class TelemetryConfig(BaseModel):
description="The service name to use for telemetry",
)
sinks: list[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
default=[TelemetrySink.SQLITE],
description="List of telemetry sinks to enable (possible values: otel_trace, otel_metric, sqlite, console)",
)
sqlite_db_path: str = Field(
@ -49,7 +49,7 @@ class TelemetryConfig(BaseModel):
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:=\u200b}",
"sinks": "${env.TELEMETRY_SINKS:=console,sqlite}",
"sinks": "${env.TELEMETRY_SINKS:=sqlite}",
"sqlite_db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
"otel_exporter_otlp_endpoint": "${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}",
}
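
This is the source of the repeated run-config changes above: the console sink is no longer part of the default, and the `${env.TELEMETRY_SINKS:=sqlite}` substitution means deployments opt back in via the environment. A sketch; the import path is an assumption based on the provider layout, while the `TelemetryConfig` and `TelemetrySink` names come from this hunk:

```python
# Sketch only; the import path is assumed from the inline provider layout.
import os

from llama_stack.providers.inline.telemetry.meta_reference.config import (
    TelemetryConfig,
    TelemetrySink,
)

# The shipped default no longer includes the console sink.
assert TelemetryConfig.model_fields["sinks"].default == [TelemetrySink.SQLITE]

# The run configs substitute ${env.TELEMETRY_SINKS:=sqlite}, so exporting the
# variable before `llama stack run` restores the previous console+sqlite setup.
os.environ["TELEMETRY_SINKS"] = "console,sqlite"
```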


@ -10,7 +10,7 @@ from typing import Annotated, Any
import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile
from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
@ -195,7 +196,7 @@ class S3FilesImpl(Files):
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"


@ -44,8 +44,8 @@ client.initialize()
The following example shows how to create a chat completion for an NVIDIA NIM.
```python
response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{
"role": "system",
@ -57,11 +57,9 @@ response = client.inference.chat_completion(
},
],
stream=False,
sampling_params={
"max_tokens": 50,
},
max_tokens=50,
)
print(f"Response: {response.completion_message.content}")
print(f"Response: {response.choices[0].message.content}")
```
### Tool Calling Example ###
@ -89,15 +87,15 @@ tool_definition = ToolDefinition(
},
)
tool_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
tool_response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
tools=[tool_definition],
)
print(f"Tool Response: {tool_response.completion_message.content}")
if tool_response.completion_message.tool_calls:
for tool_call in tool_response.completion_message.tool_calls:
print(f"Tool Response: {tool_response.choices[0].message.content}")
if tool_response.choices[0].message.tool_calls:
for tool_call in tool_response.choices[0].message.tool_calls:
print(f"Tool Called: {tool_call.tool_name}")
print(f"Arguments: {tool_call.arguments}")
```
@ -123,8 +121,8 @@ response_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema, json_schema=person_schema
)
structured_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
structured_response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{
"role": "user",
@ -134,7 +132,7 @@ structured_response = client.inference.chat_completion(
response_format=response_format,
)
print(f"Structured Response: {structured_response.completion_message.content}")
print(f"Structured Response: {structured_response.choices[0].message.content}")
```
### Create Embeddings
@ -167,8 +165,8 @@ def load_image_as_base64(image_path):
image_path = {path_to_the_image}
demo_image_b64 = load_image_as_base64(image_path)
vlm_response = client.inference.chat_completion(
model_id="nvidia/vila",
vlm_response = client.chat.completions.create(
model="nvidia/vila",
messages=[
{
"role": "user",
@ -188,5 +186,5 @@ vlm_response = client.inference.chat_completion(
],
)
print(f"VLM Response: {vlm_response.completion_message.content}")
print(f"VLM Response: {vlm_response.choices[0].message.content}")
```


@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from fastapi import Request
from pydantic import BaseModel, ValidationError
from llama_stack.apis.files import ExpiresAfter
async def parse_pydantic_from_form[T: BaseModel](request: Request, field_name: str, model_class: type[T]) -> T | None:
"""
Generic parser to extract a Pydantic model from multipart form data.
Handles both bracket notation (field[attr1], field[attr2]) and JSON string format.
Args:
request: The FastAPI request object
field_name: The name of the field in the form data (e.g., "expires_after")
model_class: The Pydantic model class to parse into
Returns:
An instance of model_class if parsing succeeds, None otherwise
Example:
expires_after = await parse_pydantic_from_form(
request, "expires_after", ExpiresAfter
)
"""
form = await request.form()
# Check for bracket notation first (e.g., expires_after[anchor], expires_after[seconds])
bracket_data = {}
prefix = f"{field_name}["
for key in form.keys():
if key.startswith(prefix) and key.endswith("]"):
# Extract the attribute name from field_name[attr]
attr = key[len(prefix) : -1]
bracket_data[attr] = form[key]
if bracket_data:
try:
return model_class(**bracket_data)
except (ValidationError, TypeError):
pass
# Check for JSON string format
if field_name in form:
value = form[field_name]
if isinstance(value, str):
try:
data = json.loads(value)
return model_class(**data)
except (json.JSONDecodeError, TypeError, ValidationError):
pass
return None
async def parse_expires_after(request: Request) -> ExpiresAfter | None:
"""
Dependency to parse expires_after from multipart form data.
Handles both bracket notation (expires_after[anchor], expires_after[seconds])
and JSON string format.
"""
return await parse_pydantic_from_form(request, "expires_after", ExpiresAfter)
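
Because `parse_expires_after` accepts both bracket notation and a JSON string, existing clients can keep sending OpenAI-style multipart fields. A hedged sketch of the two encodings; the files route and base URL are assumptions, and the `anchor`/`seconds` field names follow the OpenAI `expires_after` object:

```python
# Illustrative multipart uploads exercising both encodings accepted by
# parse_expires_after. Route and base URL are assumed, not taken from this diff.
import json

import httpx

BASE_URL = "http://localhost:8321"     # assumed local deployment
FILES_ROUTE = "/v1/openai/v1/files"    # assumed files route

with httpx.Client(base_url=BASE_URL) as client:
    upload = {"file": ("notes.txt", b"hello world", "text/plain")}

    # 1) Bracket notation: expires_after[anchor] / expires_after[seconds]
    client.post(
        FILES_ROUTE,
        files=upload,
        data={
            "purpose": "assistants",
            "expires_after[anchor]": "created_at",
            "expires_after[seconds]": "86400",
        },
    )

    # 2) JSON string form: a single expires_after field
    client.post(
        FILES_ROUTE,
        files=upload,
        data={
            "purpose": "assistants",
            "expires_after": json.dumps({"anchor": "created_at", "seconds": 86400}),
        },
    )
```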


@ -24,11 +24,13 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategyStatic,
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreFileBatchObject,
VectorStoreFileContentsResponse,
VectorStoreFileCounts,
VectorStoreFileDeleteResponse,
VectorStoreFileLastError,
VectorStoreFileObject,
VectorStoreFilesListInBatchResponse,
VectorStoreFileStatus,
VectorStoreListFilesResponse,
VectorStoreListResponse,
@ -107,7 +109,11 @@ class OpenAIVectorStoreMixin(ABC):
self.openai_vector_stores.pop(store_id, None)
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
self,
store_id: str,
file_id: str,
file_info: dict[str, Any],
file_contents: list[dict[str, Any]],
) -> None:
"""Save vector store file metadata to persistent storage."""
assert self.kvstore
@ -301,7 +307,10 @@ class OpenAIVectorStoreMixin(ABC):
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
before_index = next(
(i for i, store in enumerate(all_stores) if store["id"] == before),
len(all_stores),
)
all_stores = all_stores[:before_index]
# Apply limit
@ -397,7 +406,9 @@ class OpenAIVectorStoreMixin(ABC):
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
search_mode: (
str | None
) = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
) -> VectorStoreSearchResponsePage:
"""Search for chunks in a vector store."""
max_num_results = max_num_results or 10
@ -685,7 +696,10 @@ class OpenAIVectorStoreMixin(ABC):
file_objects = file_objects[after_index + 1 :]
if before:
before_index = next((i for i, file in enumerate(file_objects) if file.id == before), len(file_objects))
before_index = next(
(i for i, file in enumerate(file_objects) if file.id == before),
len(file_objects),
)
file_objects = file_objects[:before_index]
# Apply limit
@ -805,3 +819,42 @@ class OpenAIVectorStoreMixin(ABC):
id=file_id,
deleted=True,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileBatchObject:
"""Create a vector store file batch."""
raise NotImplementedError("openai_create_vector_store_file_batch is not implemented yet")
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
"""Returns a list of vector store files in a batch."""
raise NotImplementedError("openai_list_files_in_vector_store_file_batch is not implemented yet")
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Retrieve a vector store file batch."""
raise NotImplementedError("openai_retrieve_vector_store_file_batch is not implemented yet")
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Cancel a vector store file batch."""
raise NotImplementedError("openai_cancel_vector_store_file_batch is not implemented yet")
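
Providers that use `OpenAIVectorStoreMixin` inherit these stubs until they implement real batch handling. A hypothetical override sketch; the `VectorStoreFileCounts` field names are assumed to mirror the OpenAI `file_counts` object:

```python
# Hypothetical provider override of the stub above. OpenAIVectorStoreMixin is
# the mixin defined in this file; the file_counts field names are assumptions.
import time
import uuid
from typing import Any

from llama_stack.apis.vector_io import (
    VectorStoreChunkingStrategy,
    VectorStoreFileBatchObject,
    VectorStoreFileCounts,
)


class MyVectorIOAdapter(OpenAIVectorStoreMixin):
    async def openai_create_vector_store_file_batch(
        self,
        vector_store_id: str,
        file_ids: list[str],
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileBatchObject:
        # Register the batch as in-progress; real ingestion would run
        # asynchronously and update the counts as files complete.
        return VectorStoreFileBatchObject(
            id=f"vs_batch_{uuid.uuid4().hex}",
            created_at=int(time.time()),
            vector_store_id=vector_store_id,
            status="in_progress",
            file_counts=VectorStoreFileCounts(
                completed=0,
                cancelled=0,
                failed=0,
                in_progress=len(file_ids),
                total=len(file_ids),
            ),
        )
```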