mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Merge branch 'main' into remove-deprecated-chat-completion
This commit is contained in:
commit
ee6a502289
209 changed files with 109297 additions and 8828 deletions
|
@ -772,6 +772,12 @@ class Agents(Protocol):
|
|||
#
|
||||
# Both of these APIs are inherently stateful.
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/responses/{response_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_openai_response(
|
||||
self,
|
||||
|
@ -784,6 +790,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_openai_response(
|
||||
self,
|
||||
|
@ -809,6 +816,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
|
@ -828,10 +836,9 @@ class Agents(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(
|
||||
route="/responses/{response_id}/input_items",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
route="/openai/v1/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
|
@ -853,6 +860,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
"""Delete an OpenAI response by its ID.
|
||||
|
|
|
@ -43,6 +43,7 @@ class Batches(Protocol):
|
|||
Note: This API is currently under active development and may undergo changes.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_batch(
|
||||
self,
|
||||
|
@ -63,6 +64,7 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch.
|
||||
|
@ -72,6 +74,7 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress.
|
||||
|
@ -81,6 +84,7 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_batches(
|
||||
self,
|
||||
|
|
|
@ -105,6 +105,7 @@ class OpenAIFileDeleteResponse(BaseModel):
|
|||
@trace_protocol
|
||||
class Files(Protocol):
|
||||
# OpenAI Files API Endpoints
|
||||
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
|
@ -127,6 +128,7 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_list_files(
|
||||
self,
|
||||
|
@ -146,6 +148,7 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_file(
|
||||
self,
|
||||
|
@ -159,6 +162,7 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def openai_delete_file(
|
||||
self,
|
||||
|
@ -172,6 +176,7 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_file_content(
|
||||
self,
|
||||
|
|
|
@ -27,14 +27,12 @@ from llama_stack.models.llama.datatypes import (
|
|||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
register_schema(ToolCall)
|
||||
register_schema(ToolParamDefinition)
|
||||
register_schema(ToolDefinition)
|
||||
|
||||
from enum import StrEnum
|
||||
|
@ -1027,6 +1025,7 @@ class InferenceProvider(Protocol):
|
|||
raise NotImplementedError("Reranking is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_completion(
|
||||
self,
|
||||
|
@ -1078,6 +1077,7 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
@ -1134,6 +1134,7 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
@ -1163,6 +1164,7 @@ class Inference(InferenceProvider):
|
|||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
|
@ -1181,6 +1183,9 @@ class Inference(InferenceProvider):
|
|||
"""
|
||||
raise NotImplementedError("List chat completions is not implemented")
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
"""Describe a chat completion by its ID.
|
||||
|
|
|
@ -111,6 +111,14 @@ class Models(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
"""List models using the OpenAI API.
|
||||
|
||||
:returns: A OpenAIListModelsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_model(
|
||||
self,
|
||||
|
|
|
@ -114,6 +114,7 @@ class Safety(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
"""Classifies if text and/or image inputs are potentially harmful.
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
|
@ -19,59 +19,23 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
from .rag_tool import RAGToolRuntime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolParameter(BaseModel):
|
||||
"""Parameter definition for a tool.
|
||||
|
||||
:param name: Name of the parameter
|
||||
:param parameter_type: Type of the parameter (e.g., string, integer)
|
||||
:param description: Human-readable description of what the parameter does
|
||||
:param required: Whether this parameter is required for tool invocation
|
||||
:param items: Type of the elements when parameter_type is array
|
||||
:param title: (Optional) Title of the parameter
|
||||
:param default: (Optional) Default value for the parameter if not provided
|
||||
"""
|
||||
|
||||
name: str
|
||||
parameter_type: str
|
||||
description: str
|
||||
required: bool = Field(default=True)
|
||||
items: dict | None = None
|
||||
title: str | None = None
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
"""A tool that can be invoked by agents.
|
||||
|
||||
:param type: Type of resource, always 'tool'
|
||||
:param toolgroup_id: ID of the tool group this tool belongs to
|
||||
:param description: Human-readable description of what the tool does
|
||||
:param parameters: List of parameters this tool accepts
|
||||
:param metadata: (Optional) Additional metadata about the tool
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||
toolgroup_id: str
|
||||
description: str
|
||||
parameters: list[ToolParameter]
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDef(BaseModel):
|
||||
"""Tool definition used in runtime contexts.
|
||||
|
||||
:param name: Name of the tool
|
||||
:param description: (Optional) Human-readable description of what the tool does
|
||||
:param parameters: (Optional) List of parameters this tool accepts
|
||||
:param input_schema: (Optional) JSON Schema for tool inputs (MCP inputSchema)
|
||||
:param output_schema: (Optional) JSON Schema for tool outputs (MCP outputSchema)
|
||||
:param metadata: (Optional) Additional metadata about the tool
|
||||
:param toolgroup_id: (Optional) ID of the tool group this tool belongs to
|
||||
"""
|
||||
|
||||
toolgroup_id: str | None = None
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: list[ToolParameter] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
output_schema: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
|
@ -122,7 +86,7 @@ class ToolInvocationResult(BaseModel):
|
|||
|
||||
|
||||
class ToolStore(Protocol):
|
||||
async def get_tool(self, tool_name: str) -> Tool: ...
|
||||
async def get_tool(self, tool_name: str) -> ToolDef: ...
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
|
||||
|
||||
|
||||
|
@ -135,15 +99,6 @@ class ListToolGroupsResponse(BaseModel):
|
|||
data: list[ToolGroup]
|
||||
|
||||
|
||||
class ListToolsResponse(BaseModel):
|
||||
"""Response containing a list of tools.
|
||||
|
||||
:param data: List of tools
|
||||
"""
|
||||
|
||||
data: list[Tool]
|
||||
|
||||
|
||||
class ListToolDefsResponse(BaseModel):
|
||||
"""Response containing a list of tool definitions.
|
||||
|
||||
|
@ -194,11 +149,11 @@ class ToolGroups(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(route="/tools", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
"""List tools with optional tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to list tools for.
|
||||
:returns: A ListToolsResponse.
|
||||
:returns: A ListToolDefsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -206,11 +161,11 @@ class ToolGroups(Protocol):
|
|||
async def get_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
) -> Tool:
|
||||
) -> ToolDef:
|
||||
"""Get a tool by its name.
|
||||
|
||||
:param tool_name: The name of the tool to get.
|
||||
:returns: A Tool.
|
||||
:returns: A ToolDef.
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
|
@ -512,6 +512,7 @@ class VectorIO(Protocol):
|
|||
...
|
||||
|
||||
# OpenAI Vector Stores API endpoints
|
||||
@webmethod(route="/openai/v1/vector_stores", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
|
@ -538,6 +539,7 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/vector_stores", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
|
@ -556,6 +558,9 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
|
@ -568,6 +573,9 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}",
|
||||
method="POST",
|
||||
|
@ -590,6 +598,9 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}",
|
||||
method="DELETE",
|
||||
|
@ -606,6 +617,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/search",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/search",
|
||||
method="POST",
|
||||
|
@ -638,6 +655,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files",
|
||||
method="POST",
|
||||
|
@ -660,6 +683,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files",
|
||||
method="GET",
|
||||
|
@ -686,6 +715,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="GET",
|
||||
|
@ -704,6 +739,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}/content",
|
||||
method="GET",
|
||||
|
@ -722,6 +763,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="POST",
|
||||
|
@ -742,6 +789,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="DELETE",
|
||||
|
@ -765,6 +818,12 @@ class VectorIO(Protocol):
|
|||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
async def openai_create_vector_store_file_batch(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
@ -787,6 +846,12 @@ class VectorIO(Protocol):
|
|||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
async def openai_retrieve_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
|
@ -800,6 +865,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
|
||||
method="GET",
|
||||
|
@ -828,6 +899,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
|
||||
method="POST",
|
||||
|
|
|
@ -22,7 +22,7 @@ from llama_stack.apis.safety import Safety
|
|||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
from llama_stack.apis.shields import Shield, ShieldInput
|
||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
|
@ -84,15 +84,11 @@ class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
|
|||
pass
|
||||
|
||||
|
||||
class ToolWithOwner(Tool, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | ToolGroup
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
ModelWithOwner
|
||||
|
@ -101,7 +97,6 @@ RoutableObjectWithProvider = Annotated[
|
|||
| DatasetWithOwner
|
||||
| ScoringFnWithOwner
|
||||
| BenchmarkWithOwner
|
||||
| ToolWithOwner
|
||||
| ToolGroupWithOwner,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolsResponse,
|
||||
ListToolDefsResponse,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
|
@ -86,6 +86,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolsResponse:
|
||||
) -> ListToolDefsResponse:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.list_tools(tool_group_id)
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.errors import ToolGroupNotFoundError
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
|
||||
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -27,7 +27,7 @@ def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name
|
|||
|
||||
|
||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||
toolgroups_to_tools: dict[str, list[Tool]] = {}
|
||||
toolgroups_to_tools: dict[str, list[ToolDef]] = {}
|
||||
tool_to_toolgroup: dict[str, str] = {}
|
||||
|
||||
# overridden
|
||||
|
@ -43,7 +43,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
routing_key = self.tool_to_toolgroup[routing_key]
|
||||
return await super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
if toolgroup_id:
|
||||
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
|
||||
toolgroup_id = group_id
|
||||
|
@ -68,30 +68,19 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
continue
|
||||
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
||||
|
||||
return ListToolsResponse(data=all_tools)
|
||||
return ListToolDefsResponse(data=all_tools)
|
||||
|
||||
async def _index_tools(self, toolgroup: ToolGroup):
|
||||
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
|
||||
|
||||
# TODO: kill this Tool vs ToolDef distinction
|
||||
tooldefs = tooldefs_response.data
|
||||
tools = []
|
||||
for t in tooldefs:
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=t.name,
|
||||
toolgroup_id=toolgroup.identifier,
|
||||
description=t.description or "",
|
||||
parameters=t.parameters or [],
|
||||
metadata=t.metadata,
|
||||
provider_id=toolgroup.provider_id,
|
||||
)
|
||||
)
|
||||
t.toolgroup_id = toolgroup.identifier
|
||||
|
||||
self.toolgroups_to_tools[toolgroup.identifier] = tools
|
||||
for tool in tools:
|
||||
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
|
||||
self.toolgroups_to_tools[toolgroup.identifier] = tooldefs
|
||||
for tool in tooldefs:
|
||||
self.tool_to_toolgroup[tool.name] = toolgroup.identifier
|
||||
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||
|
@ -102,12 +91,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
raise ToolGroupNotFoundError(toolgroup_id)
|
||||
return tool_group
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
async def get_tool(self, tool_name: str) -> ToolDef:
|
||||
if tool_name in self.tool_to_toolgroup:
|
||||
toolgroup_id = self.tool_to_toolgroup[tool_name]
|
||||
tools = self.toolgroups_to_tools[toolgroup_id]
|
||||
for tool in tools:
|
||||
if tool.identifier == tool_name:
|
||||
if tool.name == tool_name:
|
||||
return tool
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
|
@ -132,7 +121,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
# baked in some of the code and tests right now.
|
||||
if not toolgroup.mcp_endpoint:
|
||||
await self._index_tools(toolgroup)
|
||||
return toolgroup
|
||||
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
await self.unregister_object(await self.get_tool_group(toolgroup_id))
|
||||
|
|
|
@ -257,7 +257,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
return result
|
||||
except Exception as e:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
if logger.isEnabledFor(logging.INFO):
|
||||
logger.exception(f"Error executing endpoint {route=} {method=}")
|
||||
else:
|
||||
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
|
||||
|
|
|
@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
|
|||
|
||||
|
||||
REGISTER_PREFIX = "distributions:registry"
|
||||
KEY_VERSION = "v9"
|
||||
KEY_VERSION = "v10"
|
||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
|
|
|
@ -81,7 +81,7 @@ def tool_chat_page():
|
|||
|
||||
for toolgroup_id in toolgroup_selection:
|
||||
tools = client.tools.list(toolgroup_id=toolgroup_id)
|
||||
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
|
||||
grouped_tools[toolgroup_id] = [tool.name for tool in tools]
|
||||
total_tools += len(tools)
|
||||
|
||||
st.markdown(f"Active Tools: 🛠 {total_tools}")
|
||||
|
|
|
@ -31,7 +31,14 @@ CATEGORIES = [
|
|||
"client",
|
||||
"telemetry",
|
||||
"openai_responses",
|
||||
"testing",
|
||||
"providers",
|
||||
"models",
|
||||
"files",
|
||||
"vector_io",
|
||||
"tool_runtime",
|
||||
]
|
||||
UNCATEGORIZED = "uncategorized"
|
||||
|
||||
# Initialize category levels with default level
|
||||
_category_levels: dict[str, int] = dict.fromkeys(CATEGORIES, DEFAULT_LOG_LEVEL)
|
||||
|
@ -165,7 +172,7 @@ def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None
|
|||
|
||||
def filter(self, record):
|
||||
if not hasattr(record, "category"):
|
||||
record.category = "uncategorized" # Default to 'uncategorized' if no category found
|
||||
record.category = UNCATEGORIZED # Default to 'uncategorized' if no category found
|
||||
return True
|
||||
|
||||
# Determine the root logger's level (default to WARNING if not specified)
|
||||
|
@ -255,7 +262,10 @@ def get_logger(
|
|||
log_level = _category_levels[root_category]
|
||||
else:
|
||||
log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
|
||||
logging.warning(f"Unknown logging category: {category}. Falling back to default 'root' level: {log_level}")
|
||||
if category != UNCATEGORIZED:
|
||||
logging.warning(
|
||||
f"Unknown logging category: {category}. Falling back to default 'root' level: {log_level}"
|
||||
)
|
||||
logger.setLevel(log_level)
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
||||
|
|
|
@ -37,14 +37,7 @@ RecursiveType = Primitive | list[Primitive] | dict[str, Primitive]
|
|||
class ToolCall(BaseModel):
|
||||
call_id: str
|
||||
tool_name: BuiltinTool | str
|
||||
# Plan is to deprecate the Dict in favor of a JSON string
|
||||
# that is parsed on the client side instead of trying to manage
|
||||
# the recursive type here.
|
||||
# Making this a union so that client side can start prepping for this change.
|
||||
# Eventually, we will remove both the Dict and arguments_json field,
|
||||
# and arguments will just be a str
|
||||
arguments: str | dict[str, RecursiveType]
|
||||
arguments_json: str | None = None
|
||||
arguments: str
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
@ -88,19 +81,11 @@ class StopReason(Enum):
|
|||
out_of_tokens = "out_of_tokens"
|
||||
|
||||
|
||||
class ToolParamDefinition(BaseModel):
|
||||
param_type: str
|
||||
description: str | None = None
|
||||
required: bool | None = True
|
||||
items: Any | None = None
|
||||
title: str | None = None
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
tool_name: BuiltinTool | str
|
||||
description: str | None = None
|
||||
parameters: dict[str, ToolParamDefinition] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
output_schema: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
|
|
@ -232,8 +232,7 @@ class ChatFormat:
|
|||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
arguments_json=json.dumps(tool_arguments),
|
||||
arguments=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
|
|
@ -18,7 +18,6 @@ from typing import Any
|
|||
from llama_stack.apis.inference import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
|
||||
from .base import PromptTemplate, PromptTemplateGeneratorBase
|
||||
|
@ -101,11 +100,8 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{%- set tprops = t.input_schema.get('properties', {}) -%}
|
||||
{%- set required_params = t.input_schema.get('required', []) -%}
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
|
@ -114,11 +110,11 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{%- for name, param in tparams.items() %}
|
||||
{%- for name, param in tprops.items() %}
|
||||
{
|
||||
"{{name}}": {
|
||||
"type": "object",
|
||||
"description": "{{param.description}}"
|
||||
"description": "{{param.get('description', '')}}"
|
||||
}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
|
@ -143,17 +139,19 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
ToolDefinition(
|
||||
tool_name="trending_songs",
|
||||
description="Returns the trending songs on a Music site",
|
||||
parameters={
|
||||
"n": ToolParamDefinition(
|
||||
param_type="int",
|
||||
description="The number of songs to return",
|
||||
required=True,
|
||||
),
|
||||
"genre": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The genre of the songs to return",
|
||||
required=False,
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"n": {
|
||||
"type": "int",
|
||||
"description": "The number of songs to return",
|
||||
},
|
||||
"genre": {
|
||||
"type": "str",
|
||||
"description": "The genre of the songs to return",
|
||||
},
|
||||
},
|
||||
"required": ["n"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
@ -170,11 +168,14 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
{#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set modified_params = t.parameters.copy() -%}
|
||||
{%- for key, value in modified_params.items() -%}
|
||||
{%- if 'default' in value -%}
|
||||
{%- set _ = value.pop('default', None) -%}
|
||||
{%- set tprops = t.input_schema.get('properties', {}) -%}
|
||||
{%- set modified_params = {} -%}
|
||||
{%- for key, value in tprops.items() -%}
|
||||
{%- set param_copy = value.copy() -%}
|
||||
{%- if 'default' in param_copy -%}
|
||||
{%- set _ = param_copy.pop('default', None) -%}
|
||||
{%- endif -%}
|
||||
{%- set _ = modified_params.update({key: param_copy}) -%}
|
||||
{%- endfor -%}
|
||||
{%- set tparams = modified_params | tojson -%}
|
||||
Use the function '{{ tname }}' to '{{ tdesc }}':
|
||||
|
@ -205,17 +206,19 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
ToolDefinition(
|
||||
tool_name="trending_songs",
|
||||
description="Returns the trending songs on a Music site",
|
||||
parameters={
|
||||
"n": ToolParamDefinition(
|
||||
param_type="int",
|
||||
description="The number of songs to return",
|
||||
required=True,
|
||||
),
|
||||
"genre": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The genre of the songs to return",
|
||||
required=False,
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"n": {
|
||||
"type": "int",
|
||||
"description": "The number of songs to return",
|
||||
},
|
||||
"genre": {
|
||||
"type": "str",
|
||||
"description": "The genre of the songs to return",
|
||||
},
|
||||
},
|
||||
"required": ["n"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
@ -255,11 +258,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{%- set tprops = (t.input_schema or {}).get('properties', {}) -%}
|
||||
{%- set required_params = (t.input_schema or {}).get('required', []) -%}
|
||||
{
|
||||
"name": "{{tname}}",
|
||||
"description": "{{tdesc}}",
|
||||
|
@ -267,11 +267,11 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
"type": "dict",
|
||||
"required": {{ required_params | tojson }},
|
||||
"properties": {
|
||||
{%- for name, param in tparams.items() %}
|
||||
{%- for name, param in tprops.items() %}
|
||||
"{{name}}": {
|
||||
"type": "{{param.param_type}}",
|
||||
"description": "{{param.description}}"{% if param.default %},
|
||||
"default": "{{param.default}}"{% endif %}
|
||||
"type": "{{param.get('type', 'string')}}",
|
||||
"description": "{{param.get('description', '')}}"{% if param.get('default') %},
|
||||
"default": "{{param.get('default')}}"{% endif %}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
}
|
||||
|
@ -299,18 +299,20 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get weather info for places",
|
||||
parameters={
|
||||
"city": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The name of the city to get the weather for",
|
||||
required=True,
|
||||
),
|
||||
"metric": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for",
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
@ -220,17 +220,18 @@ class ToolUtils:
|
|||
|
||||
@staticmethod
|
||||
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
||||
args = json.loads(t.arguments)
|
||||
if t.tool_name == BuiltinTool.brave_search:
|
||||
q = t.arguments["query"]
|
||||
q = args["query"]
|
||||
return f'brave_search.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
||||
q = t.arguments["query"]
|
||||
q = args["query"]
|
||||
return f'wolfram_alpha.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.photogen:
|
||||
q = t.arguments["query"]
|
||||
q = args["query"]
|
||||
return f'photogen.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.code_interpreter:
|
||||
return t.arguments["code"]
|
||||
return args["code"]
|
||||
else:
|
||||
fname = t.tool_name
|
||||
|
||||
|
@ -239,12 +240,11 @@ class ToolUtils:
|
|||
{
|
||||
"type": "function",
|
||||
"name": fname,
|
||||
"parameters": t.arguments,
|
||||
"parameters": args,
|
||||
}
|
||||
)
|
||||
elif tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
args = json.dumps(t.arguments)
|
||||
return f"<function={fname}>{args}</function>"
|
||||
return f"<function={fname}>{t.arguments}</function>"
|
||||
|
||||
elif tool_prompt_format == ToolPromptFormat.python_list:
|
||||
|
||||
|
@ -260,7 +260,7 @@ class ToolUtils:
|
|||
else:
|
||||
raise ValueError(f"Unsupported type: {type(value)}")
|
||||
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in args.items())
|
||||
return f"[{fname}({args_str})]"
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
|
@ -184,7 +185,7 @@ def usecases() -> list[UseCase | str]:
|
|||
ToolCall(
|
||||
call_id="tool_call_id",
|
||||
tool_name=BuiltinTool.wolfram_alpha,
|
||||
arguments={"query": "100th decimal of pi"},
|
||||
arguments=json.dumps({"query": "100th decimal of pi"}),
|
||||
)
|
||||
],
|
||||
),
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
|
@ -185,7 +186,7 @@ def usecases() -> list[UseCase | str]:
|
|||
ToolCall(
|
||||
call_id="tool_call_id",
|
||||
tool_name=BuiltinTool.wolfram_alpha,
|
||||
arguments={"query": "100th decimal of pi"},
|
||||
arguments=json.dumps({"query": "100th decimal of pi"}),
|
||||
)
|
||||
],
|
||||
),
|
||||
|
|
|
@ -298,8 +298,7 @@ class ChatFormat:
|
|||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
arguments_json=json.dumps(tool_arguments),
|
||||
arguments=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
|
||||
import textwrap
|
||||
|
||||
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.apis.inference import ToolDefinition
|
||||
from llama_stack.models.llama.llama3.prompt_templates.base import (
|
||||
PromptTemplate,
|
||||
PromptTemplateGeneratorBase,
|
||||
|
@ -81,11 +81,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{%- set tprops = t.input_schema.get('properties', {}) -%}
|
||||
{%- set required_params = t.input_schema.get('required', []) -%}
|
||||
{
|
||||
"name": "{{tname}}",
|
||||
"description": "{{tdesc}}",
|
||||
|
@ -93,11 +90,11 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
"type": "dict",
|
||||
"required": {{ required_params | tojson }},
|
||||
"properties": {
|
||||
{%- for name, param in tparams.items() %}
|
||||
{%- for name, param in tprops.items() %}
|
||||
"{{name}}": {
|
||||
"type": "{{param.param_type}}",
|
||||
"description": "{{param.description}}"{% if param.default %},
|
||||
"default": "{{param.default}}"{% endif %}
|
||||
"type": "{{param.get('type', 'string')}}",
|
||||
"description": "{{param.get('description', '')}}"{% if param.get('default') %},
|
||||
"default": "{{param.get('default')}}"{% endif %}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
}
|
||||
|
@ -119,18 +116,20 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get weather info for places",
|
||||
parameters={
|
||||
"city": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The name of the city to get the weather for",
|
||||
required=True,
|
||||
),
|
||||
"metric": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for",
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
@ -60,7 +60,6 @@ from llama_stack.apis.inference import (
|
|||
StopReason,
|
||||
SystemMessage,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
|
@ -866,20 +865,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
|
||||
# Use input_schema from ToolDef directly
|
||||
tool_name_to_def[tool_def.name] = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
items=param.items,
|
||||
title=param.title,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
input_schema=tool_def.input_schema,
|
||||
)
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
|
@ -889,44 +880,34 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
||||
)
|
||||
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
||||
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
|
||||
if input_tool_name is not None and not any(tool.name == input_tool_name for tool in tools.data):
|
||||
raise ValueError(
|
||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.name for tool in tools.data])}"
|
||||
)
|
||||
|
||||
for tool_def in tools.data:
|
||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||
identifier: str | BuiltinTool | None = tool_def.identifier
|
||||
identifier: str | BuiltinTool | None = tool_def.name
|
||||
if identifier == "web_search":
|
||||
identifier = BuiltinTool.brave_search
|
||||
else:
|
||||
identifier = BuiltinTool(identifier)
|
||||
else:
|
||||
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
|
||||
if input_tool_name in (None, tool_def.identifier):
|
||||
identifier = tool_def.identifier
|
||||
if input_tool_name in (None, tool_def.name):
|
||||
identifier = tool_def.name
|
||||
else:
|
||||
identifier = None
|
||||
|
||||
if tool_name_to_def.get(identifier, None):
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
if identifier:
|
||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||
tool_name_to_def[identifier] = ToolDefinition(
|
||||
tool_name=identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
items=param.items,
|
||||
title=param.title,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
input_schema=tool_def.input_schema,
|
||||
)
|
||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
||||
self.tool_defs, self.tool_name_to_args = (
|
||||
list(tool_name_to_def.values()),
|
||||
|
@ -970,12 +951,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_name_str = tool_name
|
||||
|
||||
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||
|
||||
try:
|
||||
args = json.loads(tool_call.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e
|
||||
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**args,
|
||||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
)
|
||||
|
|
|
@ -41,7 +41,7 @@ from .utils import (
|
|||
convert_response_text_to_chat_response_format,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="openai::responses")
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
|
|
|
@ -62,22 +62,13 @@ def convert_tooldef_to_chat_tool(tool_def):
|
|||
ChatCompletionToolParam suitable for OpenAI chat completion
|
||||
"""
|
||||
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
internal_tool_def = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
items=param.items,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
input_schema=tool_def.input_schema,
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(internal_tool_def)
|
||||
|
||||
|
@ -528,23 +519,15 @@ class StreamingResponseOrchestrator:
|
|||
"""Process all tools and emit appropriate streaming events."""
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from llama_stack.apis.tools import Tool
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
|
||||
def make_openai_tool(tool_name: str, tool: ToolDef) -> ChatCompletionToolParam:
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
input_schema=tool.input_schema,
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
|
||||
|
@ -631,16 +614,11 @@ class StreamingResponseOrchestrator:
|
|||
MCPListToolsTool(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema={
|
||||
input_schema=t.input_schema
|
||||
or {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"type": p.parameter_type,
|
||||
"description": p.description,
|
||||
}
|
||||
for p in t.parameters
|
||||
},
|
||||
"required": [p.name for p in t.parameters if p.required],
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
|
@ -68,9 +68,7 @@ public class FunctionTagCustomToolGenerator {
|
|||
{
|
||||
"name": "{{t.tool_name}}",
|
||||
"description": "{{t.description}}",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"properties": { {{t.parameters}} }
|
||||
"input_schema": { {{t.input_schema}} }
|
||||
}
|
||||
|
||||
{{/let}}
|
||||
|
|
|
@ -33,7 +33,6 @@ from llama_stack.apis.tools import (
|
|||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import (
|
||||
|
@ -301,13 +300,16 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
ToolDef(
|
||||
name="knowledge_search",
|
||||
description="Search for information in a database.",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for. Can be a natural language sentence or keywords.",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for. Can be a natural language sentence or keywords.",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
|
|
@ -82,8 +82,7 @@ def _convert_to_vllm_tool_calls_in_response(
|
|||
ToolCall(
|
||||
call_id=call.id,
|
||||
tool_name=call.function.name,
|
||||
arguments=json.loads(call.function.arguments),
|
||||
arguments_json=call.function.arguments,
|
||||
arguments=call.function.arguments,
|
||||
)
|
||||
for call in tool_calls
|
||||
]
|
||||
|
@ -93,18 +92,6 @@ def _convert_to_vllm_tools_in_request(tools: list[ToolDefinition]) -> list[dict]
|
|||
compat_tools = []
|
||||
|
||||
for tool in tools:
|
||||
properties = {}
|
||||
compat_required = []
|
||||
if tool.parameters:
|
||||
for tool_key, tool_param in tool.parameters.items():
|
||||
properties[tool_key] = {"type": tool_param.param_type}
|
||||
if tool_param.description:
|
||||
properties[tool_key]["description"] = tool_param.description
|
||||
if tool_param.default:
|
||||
properties[tool_key]["default"] = tool_param.default
|
||||
if tool_param.required:
|
||||
compat_required.append(tool_key)
|
||||
|
||||
# The tool.tool_name can be a str or a BuiltinTool enum. If
|
||||
# it's the latter, convert to a string.
|
||||
tool_name = tool.tool_name
|
||||
|
@ -116,10 +103,11 @@ def _convert_to_vllm_tools_in_request(tools: list[ToolDefinition]) -> list[dict]
|
|||
"function": {
|
||||
"name": tool_name,
|
||||
"description": tool.description,
|
||||
"parameters": {
|
||||
"parameters": tool.input_schema
|
||||
or {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": compat_required,
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -154,7 +142,6 @@ def _process_vllm_chat_completion_end_of_stream(
|
|||
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
|
||||
args_str = tool_call_buf.arguments or "{}"
|
||||
try:
|
||||
args = json.loads(args_str)
|
||||
chunks.append(
|
||||
ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
@ -163,8 +150,7 @@ def _process_vllm_chat_completion_end_of_stream(
|
|||
tool_call=ToolCall(
|
||||
call_id=tool_call_buf.call_id,
|
||||
tool_name=tool_call_buf.tool_name,
|
||||
arguments=args,
|
||||
arguments_json=args_str,
|
||||
arguments=args_str,
|
||||
),
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
|
|
|
@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
|
|||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
|
@ -57,13 +56,16 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
|
|||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web using Bing Search API",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
@ -14,7 +14,6 @@ from llama_stack.apis.tools import (
|
|||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
|
@ -56,13 +55,16 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
|
|||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web for information",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
built_in_type=BuiltinTool.brave_search,
|
||||
)
|
||||
]
|
||||
|
|
|
@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
|
|||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
|
@ -56,13 +55,16 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web for information",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
|
|||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
|
@ -57,13 +56,16 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
ToolDef(
|
||||
name="wolfram_alpha",
|
||||
description="Query WolframAlpha for computational knowledge",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to compute",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to compute",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
@ -22,7 +22,7 @@ from ..sqlstore.api import ColumnDefinition, ColumnType
|
|||
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="inference_store")
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class InferenceStore:
|
||||
|
|
|
@ -125,7 +125,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_content_to_url,
|
||||
|
@ -537,18 +536,13 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
|
||||
# arguments_json can be None, so attempt it first and fall back to arguments
|
||||
if hasattr(tc, "arguments_json") and tc.arguments_json:
|
||||
arguments = tc.arguments_json
|
||||
else:
|
||||
arguments = json.dumps(tc.arguments)
|
||||
result["tool_calls"].append(
|
||||
{
|
||||
"id": tc.call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
"arguments": tc.arguments,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
@ -641,7 +635,7 @@ async def convert_message_to_openai_dict_new(
|
|||
id=tool.call_id,
|
||||
function=OpenAIFunction(
|
||||
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
|
||||
arguments=json.dumps(tool.arguments),
|
||||
arguments=tool.arguments, # Already a JSON string, don't double-encode
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
|
@ -684,8 +678,7 @@ def convert_tool_call(
|
|||
valid_tool_call = ToolCall(
|
||||
call_id=tool_call.id,
|
||||
tool_name=tool_call.function.name,
|
||||
arguments=json.loads(tool_call.function.arguments),
|
||||
arguments_json=tool_call.function.arguments,
|
||||
arguments=tool_call.function.arguments,
|
||||
)
|
||||
except Exception:
|
||||
return UnparseableToolCall(
|
||||
|
@ -745,14 +738,8 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
ToolDefinition:
|
||||
tool_name: str | BuiltinTool
|
||||
description: Optional[str]
|
||||
parameters: Optional[Dict[str, ToolParamDefinition]]
|
||||
|
||||
ToolParamDefinition:
|
||||
param_type: str
|
||||
description: Optional[str]
|
||||
required: Optional[bool]
|
||||
default: Optional[Any]
|
||||
|
||||
input_schema: Optional[Dict[str, Any]] # JSON Schema
|
||||
output_schema: Optional[Dict[str, Any]] # JSON Schema (not used by OpenAI)
|
||||
|
||||
OpenAI spec -
|
||||
|
||||
|
@ -761,20 +748,11 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
"function": {
|
||||
"name": tool_name,
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
param_name: {
|
||||
"type": param_type,
|
||||
"description": description,
|
||||
"default": default,
|
||||
},
|
||||
...
|
||||
},
|
||||
"required": [param_name, ...],
|
||||
},
|
||||
"parameters": {<JSON Schema>},
|
||||
},
|
||||
}
|
||||
|
||||
NOTE: OpenAI does not support output_schema, so it is dropped here.
|
||||
"""
|
||||
out = {
|
||||
"type": "function",
|
||||
|
@ -783,37 +761,19 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
function = out["function"]
|
||||
|
||||
if isinstance(tool.tool_name, BuiltinTool):
|
||||
function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
|
||||
function["name"] = tool.tool_name.value
|
||||
else:
|
||||
function.update(name=tool.tool_name)
|
||||
function["name"] = tool.tool_name
|
||||
|
||||
if tool.description:
|
||||
function.update(description=tool.description)
|
||||
function["description"] = tool.description
|
||||
|
||||
if tool.parameters:
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}
|
||||
properties = parameters["properties"]
|
||||
required = []
|
||||
for param_name, param in tool.parameters.items():
|
||||
properties[param_name] = to_openai_param_type(param.param_type)
|
||||
if param.description:
|
||||
properties[param_name].update(description=param.description)
|
||||
if param.default:
|
||||
properties[param_name].update(default=param.default)
|
||||
if param.items:
|
||||
properties[param_name].update(items=param.items)
|
||||
if param.title:
|
||||
properties[param_name].update(title=param.title)
|
||||
if param.required:
|
||||
required.append(param_name)
|
||||
if tool.input_schema:
|
||||
# Pass through the entire JSON Schema as-is
|
||||
function["parameters"] = tool.input_schema
|
||||
|
||||
if required:
|
||||
parameters.update(required=required)
|
||||
|
||||
function.update(parameters=parameters)
|
||||
# NOTE: OpenAI does not support output_schema, so we drop it here
|
||||
# It's stored in LlamaStack for validation and other provider usage
|
||||
|
||||
return out
|
||||
|
||||
|
@ -874,22 +834,12 @@ def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) ->
|
|||
tool_fn = tool.get("function", {})
|
||||
tool_name = tool_fn.get("name", None)
|
||||
tool_desc = tool_fn.get("description", None)
|
||||
|
||||
tool_params = tool_fn.get("parameters", None)
|
||||
lls_tool_params = {}
|
||||
if tool_params is not None:
|
||||
tool_param_properties = tool_params.get("properties", {})
|
||||
for tool_param_key, tool_param_value in tool_param_properties.items():
|
||||
tool_param_def = ToolParamDefinition(
|
||||
param_type=str(tool_param_value.get("type", None)),
|
||||
description=tool_param_value.get("description", None),
|
||||
)
|
||||
lls_tool_params[tool_param_key] = tool_param_def
|
||||
|
||||
lls_tool = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool_desc,
|
||||
parameters=lls_tool_params,
|
||||
input_schema=tool_params, # Pass through entire JSON Schema
|
||||
)
|
||||
lls_tools.append(lls_tool)
|
||||
return lls_tools
|
||||
|
@ -939,8 +889,7 @@ def _convert_openai_tool_calls(
|
|||
ToolCall(
|
||||
call_id=call.id,
|
||||
tool_name=call.function.name,
|
||||
arguments=json.loads(call.function.arguments),
|
||||
arguments_json=call.function.arguments,
|
||||
arguments=call.function.arguments,
|
||||
)
|
||||
for call in tool_calls
|
||||
]
|
||||
|
@ -1222,12 +1171,10 @@ async def convert_openai_chat_completion_stream(
|
|||
)
|
||||
|
||||
try:
|
||||
arguments = json.loads(buffer["arguments"])
|
||||
tool_call = ToolCall(
|
||||
call_id=buffer["call_id"],
|
||||
tool_name=buffer["name"],
|
||||
arguments=arguments,
|
||||
arguments_json=buffer["arguments"],
|
||||
arguments=buffer["arguments"],
|
||||
)
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
@ -1390,7 +1337,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
openai_tool_call = OpenAIChoiceDeltaToolCall(
|
||||
index=0,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
arguments=tool_call.arguments_json,
|
||||
arguments=tool_call.arguments,
|
||||
),
|
||||
)
|
||||
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
|
||||
|
|
|
@ -286,34 +286,34 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
|
||||
messages = [await _localize_image_url(m) for m in messages]
|
||||
|
||||
resp = await self.client.chat.completions.create(
|
||||
**await prepare_openai_completion_params(
|
||||
model=await self._get_provider_model_id(model),
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=await self._get_provider_model_id(model),
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
resp = await self.client.chat.completions.create(**params)
|
||||
|
||||
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
|
||||
|
||||
async def openai_embeddings(
|
||||
|
|
|
@ -25,7 +25,7 @@ from ..sqlstore.api import ColumnDefinition, ColumnType
|
|||
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="responses_store")
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
|
||||
class ResponsesStore:
|
||||
|
|
|
@ -20,7 +20,6 @@ from llama_stack.apis.tools import (
|
|||
ListToolDefsResponse,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
)
|
||||
from llama_stack.core.datatypes import AuthenticationRequiredError
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -113,24 +112,12 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
|
|||
async with client_wrapper(endpoint, headers) as session:
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=param_name,
|
||||
parameter_type=param_schema.get("type", "string"),
|
||||
description=param_schema.get("description", ""),
|
||||
required="default" not in param_schema,
|
||||
items=param_schema.get("items", None),
|
||||
title=param_schema.get("title", None),
|
||||
default=param_schema.get("default", None),
|
||||
)
|
||||
)
|
||||
tools.append(
|
||||
ToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=parameters,
|
||||
input_schema=tool.inputSchema,
|
||||
output_schema=getattr(tool, "outputSchema", None),
|
||||
metadata={
|
||||
"endpoint": endpoint,
|
||||
},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue