mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-21 12:09:40 +00:00
Merge branch 'main' into nvidia-e2e-notebook
This commit is contained in:
commit
f5cb965f0f
226 changed files with 16519 additions and 8666 deletions
|
@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
|
|||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.tools import Tool
|
||||
from llama_stack.apis.tools import ToolGroup
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
|
|||
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
|
||||
|
||||
|
||||
class ToolsProtocolPrivate(Protocol):
|
||||
async def register_tool(self, tool: Tool) -> None: ...
|
||||
class ToolGroupsProtocolPrivate(Protocol):
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None: ...
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -20,9 +20,12 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
Order,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
@ -39,6 +42,7 @@ from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
|||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
@ -66,15 +70,17 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.tool_groups_api = tool_groups_api
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.openai_responses_impl = None
|
||||
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
self.responses_store = ResponsesStore(self.config.responses_store)
|
||||
await self.responses_store.initialize()
|
||||
self.openai_responses_impl = OpenAIResponsesImpl(
|
||||
self.persistence_store,
|
||||
inference_api=self.inference_api,
|
||||
tool_groups_api=self.tool_groups_api,
|
||||
tool_runtime_api=self.tool_runtime_api,
|
||||
responses_store=self.responses_store,
|
||||
)
|
||||
|
||||
async def create_agent(
|
||||
|
@ -305,14 +311,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
# OpenAI responses
|
||||
async def get_openai_response(
|
||||
self,
|
||||
id: str,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.get_openai_response(id)
|
||||
return await self.openai_responses_impl.get_openai_response(response_id)
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
|
@ -320,5 +327,27 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
input, model, previous_response_id, store, stream, temperature, tools
|
||||
input, model, instructions, previous_response_id, store, stream, temperature, tools
|
||||
)
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
|
||||
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
return await self.openai_responses_impl.list_openai_response_input_items(
|
||||
response_id, after, before, include, limit, order
|
||||
)
|
||||
|
|
|
@ -10,10 +10,12 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.providers.utils.kvstore import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
|
||||
|
||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
||||
persistence_store: KVStoreConfig
|
||||
responses_store: SqlStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
|
@ -21,5 +23,9 @@ class MetaReferenceAgentsImplConfig(BaseModel):
|
|||
"persistence_store": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="agents_store.db",
|
||||
)
|
||||
),
|
||||
"responses_store": SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="responses_store.db",
|
||||
),
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any, cast
|
||||
|
@ -12,24 +13,29 @@ from typing import Any, cast
|
|||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputItemList,
|
||||
OpenAIResponseInputMessageContent,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseObjectStreamResponseOutputTextDelta,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
|
@ -49,11 +55,12 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
|
@ -162,41 +169,43 @@ async def _get_message_type_by_role(role: str):
|
|||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
input_items: OpenAIResponseInputItemList
|
||||
input_items: ListOpenAIResponseInputItem
|
||||
response: OpenAIResponseObject
|
||||
|
||||
|
||||
class ChatCompletionContext(BaseModel):
|
||||
model: str
|
||||
messages: list[OpenAIMessageParam]
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
|
||||
stream: bool
|
||||
temperature: float | None
|
||||
|
||||
|
||||
class OpenAIResponsesImpl:
|
||||
def __init__(
|
||||
self,
|
||||
persistence_store: KVStore,
|
||||
inference_api: Inference,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
responses_store: ResponsesStore,
|
||||
):
|
||||
self.persistence_store = persistence_store
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
|
||||
async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{id}"
|
||||
response_json = await self.persistence_store.get(key=key)
|
||||
if response_json is None:
|
||||
raise ValueError(f"OpenAI response with id '{id}' not found")
|
||||
return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)
|
||||
self.responses_store = responses_store
|
||||
|
||||
async def _prepend_previous_response(
|
||||
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
|
||||
):
|
||||
if previous_response_id:
|
||||
previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)
|
||||
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
|
||||
|
||||
# previous response input items
|
||||
new_input_items = previous_response_with_input.input_items.data
|
||||
new_input_items = previous_response_with_input.input
|
||||
|
||||
# previous response output items
|
||||
new_input_items.extend(previous_response_with_input.response.output)
|
||||
new_input_items.extend(previous_response_with_input.output)
|
||||
|
||||
# new input items from the current request
|
||||
if isinstance(input, str):
|
||||
|
@ -208,99 +217,60 @@ class OpenAIResponsesImpl:
|
|||
|
||||
return input
|
||||
|
||||
async def _prepend_instructions(self, messages, instructions):
|
||||
if instructions:
|
||||
messages.insert(0, OpenAISystemMessageParam(content=instructions))
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
id: str,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self._get_previous_response_with_input(id)
|
||||
return response_with_input.response
|
||||
response_with_input = await self.responses_store.get_response_object(response_id)
|
||||
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
|
||||
|
||||
async def create_openai_response(
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
):
|
||||
stream = False if stream is None else stream
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
return await self.responses_store.list_responses(after, limit, model, order)
|
||||
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await _convert_response_input_to_chat_messages(input)
|
||||
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
|
||||
chat_response = await self.inference_api.openai_chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=chat_tools,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
)
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
"""List input items for a given OpenAI response.
|
||||
|
||||
if stream:
|
||||
# TODO: refactor this into a separate method that handles streaming
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
# TODO: these chunk_ fields are hacky and only take the last chunk into account
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
async for chunk in chat_response:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
for chunk_choice in chunk.choices:
|
||||
# TODO: this only works for text content
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
|
||||
# Aggregate tool call arguments across chunks, using their index as the aggregation key
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
if response_tool_call:
|
||||
response_tool_call.function.arguments += tool_call.function.arguments
|
||||
else:
|
||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||
# Ensure we don't have any empty type field in the tool call dict.
|
||||
# The OpenAI client used by providers often returns a type=None here.
|
||||
tool_call_dict.pop("type", None)
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
|
||||
if chat_response_tool_calls:
|
||||
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content="".join(chat_response_content),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
chat_response = OpenAIChatCompletion(
|
||||
id=chat_response_id,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=assistant_message,
|
||||
finish_reason=chunk_finish_reason,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=chunk_created,
|
||||
model=chunk_model,
|
||||
)
|
||||
else:
|
||||
# dump and reload to map to our pydantic types
|
||||
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
|
||||
:param response_id: The ID of the response to retrieve input items for.
|
||||
:param after: An item ID to list items after, used for pagination.
|
||||
:param before: An item ID to list items before, used for pagination.
|
||||
:param include: Additional fields to include in the response.
|
||||
:param limit: A limit on the number of objects to be returned.
|
||||
:param order: The order to return the input items in.
|
||||
:returns: An ListOpenAIResponseInputItem.
|
||||
"""
|
||||
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
|
||||
|
||||
async def _process_response_choices(
|
||||
self,
|
||||
chat_response: OpenAIChatCompletion,
|
||||
ctx: ChatCompletionContext,
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
) -> list[OpenAIResponseOutput]:
|
||||
"""Handle tool execution and response message creation."""
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
# Execute tool calls if any
|
||||
for choice in chat_response.choices:
|
||||
if choice.message.tool_calls and tools:
|
||||
# Assume if the first tool is a function, all tools are functions
|
||||
if isinstance(tools[0], OpenAIResponseInputToolFunction):
|
||||
if tools[0].type == "function":
|
||||
for tool_call in choice.message.tool_calls:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
|
@ -312,11 +282,133 @@ class OpenAIResponsesImpl:
|
|||
)
|
||||
)
|
||||
else:
|
||||
output_messages.extend(
|
||||
await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
|
||||
)
|
||||
tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
|
||||
output_messages.extend(tool_messages)
|
||||
else:
|
||||
output_messages.append(await _convert_chat_choice_to_response_message(choice))
|
||||
|
||||
return output_messages
|
||||
|
||||
async def _store_response(
|
||||
self,
|
||||
response: OpenAIResponseObject,
|
||||
original_input: str | list[OpenAIResponseInput],
|
||||
) -> None:
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(original_input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=original_input)
|
||||
input_content_item = OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[input_content],
|
||||
id=new_input_id,
|
||||
)
|
||||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in original_input:
|
||||
if isinstance(input_item, OpenAIResponseMessage):
|
||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||
input_item_dict = input_item.model_dump()
|
||||
if "id" not in input_item_dict:
|
||||
input_item_dict["id"] = new_input_id
|
||||
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
|
||||
else:
|
||||
input_items_data.append(input_item)
|
||||
|
||||
await self.responses_store.store_response_object(
|
||||
response_object=response,
|
||||
input=input_items_data,
|
||||
)
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
):
|
||||
stream = False if stream is None else stream
|
||||
original_input = input # Keep reference for storage
|
||||
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
|
||||
# Input preprocessing
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await _convert_response_input_to_chat_messages(input)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Tool setup
|
||||
chat_tools, mcp_tool_to_server, mcp_list_message = (
|
||||
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
|
||||
)
|
||||
if mcp_list_message:
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
ctx = ChatCompletionContext(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=chat_tools,
|
||||
mcp_tool_to_server=mcp_tool_to_server,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
inference_result = await self.inference_api.openai_chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=chat_tools,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._create_streaming_response(
|
||||
inference_result=inference_result,
|
||||
ctx=ctx,
|
||||
output_messages=output_messages,
|
||||
original_input=original_input,
|
||||
model=model,
|
||||
store=store,
|
||||
tools=tools,
|
||||
)
|
||||
else:
|
||||
return await self._create_non_streaming_response(
|
||||
inference_result=inference_result,
|
||||
ctx=ctx,
|
||||
output_messages=output_messages,
|
||||
original_input=original_input,
|
||||
model=model,
|
||||
store=store,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
async def _create_non_streaming_response(
|
||||
self,
|
||||
inference_result: Any,
|
||||
ctx: ChatCompletionContext,
|
||||
output_messages: list[OpenAIResponseOutput],
|
||||
original_input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
store: bool | None,
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
) -> OpenAIResponseObject:
|
||||
chat_response = OpenAIChatCompletion(**inference_result.model_dump())
|
||||
|
||||
# Process response choices (tool execution and message creation)
|
||||
output_messages.extend(
|
||||
await self._process_response_choices(
|
||||
chat_response=chat_response,
|
||||
ctx=ctx,
|
||||
tools=tools,
|
||||
)
|
||||
)
|
||||
|
||||
response = OpenAIResponseObject(
|
||||
created_at=chat_response.created,
|
||||
id=f"resp-{uuid.uuid4()}",
|
||||
|
@ -327,57 +419,168 @@ class OpenAIResponsesImpl:
|
|||
)
|
||||
logger.debug(f"OpenAI Responses response: {response}")
|
||||
|
||||
# Store response if requested
|
||||
if store:
|
||||
# Store in kvstore
|
||||
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||
input_content_item = OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[input_content],
|
||||
id=new_input_id,
|
||||
)
|
||||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseMessage):
|
||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||
input_item_dict = input_item.model_dump()
|
||||
if "id" not in input_item_dict:
|
||||
input_item_dict["id"] = new_input_id
|
||||
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
|
||||
else:
|
||||
input_items_data.append(input_item)
|
||||
|
||||
input_items = OpenAIResponseInputItemList(data=input_items_data)
|
||||
prev_response = OpenAIResponsePreviousResponseWithInputItems(
|
||||
input_items=input_items,
|
||||
await self._store_response(
|
||||
response=response,
|
||||
original_input=original_input,
|
||||
)
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
|
||||
await self.persistence_store.set(
|
||||
key=key,
|
||||
value=prev_response.model_dump_json(),
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
||||
async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# TODO: response created should actually get emitted much earlier in the process
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=response)
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=response)
|
||||
|
||||
return async_response()
|
||||
|
||||
return response
|
||||
|
||||
async def _create_streaming_response(
|
||||
self,
|
||||
inference_result: Any,
|
||||
ctx: ChatCompletionContext,
|
||||
output_messages: list[OpenAIResponseOutput],
|
||||
original_input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
store: bool | None,
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Create initial response and emit response.created immediately
|
||||
response_id = f"resp-{uuid.uuid4()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
initial_response = OpenAIResponseObject(
|
||||
created_at=created_at,
|
||||
id=response_id,
|
||||
model=model,
|
||||
object="response",
|
||||
status="in_progress",
|
||||
output=output_messages.copy(),
|
||||
)
|
||||
|
||||
# Emit response.created immediately
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||
|
||||
# For streaming, inference_result is an async iterator of chunks
|
||||
# Stream chunks and emit delta events as they arrive
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
sequence_number = 0
|
||||
|
||||
# Create a placeholder message item for delta events
|
||||
message_item_id = f"msg_{uuid.uuid4()}"
|
||||
|
||||
async for chunk in inference_result:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=0,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=0,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
# Collect content for final response
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
|
||||
# Aggregate tool call arguments across chunks, using their index as the aggregation key
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
if response_tool_call:
|
||||
response_tool_call.function.arguments += tool_call.function.arguments
|
||||
else:
|
||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||
tool_call_dict.pop("type", None)
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Convert collected chunks to complete response
|
||||
if chat_response_tool_calls:
|
||||
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content="".join(chat_response_content),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
chat_response_obj = OpenAIChatCompletion(
|
||||
id=chat_response_id,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=assistant_message,
|
||||
finish_reason=chunk_finish_reason,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=chunk_created,
|
||||
model=chunk_model,
|
||||
)
|
||||
|
||||
# Process response choices (tool execution and message creation)
|
||||
output_messages.extend(
|
||||
await self._process_response_choices(
|
||||
chat_response=chat_response_obj,
|
||||
ctx=ctx,
|
||||
tools=tools,
|
||||
)
|
||||
)
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
created_at=created_at,
|
||||
id=response_id,
|
||||
model=model,
|
||||
object="response",
|
||||
status="completed",
|
||||
output=output_messages,
|
||||
)
|
||||
|
||||
if store:
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
original_input=original_input,
|
||||
)
|
||||
|
||||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
async def _convert_response_tools_to_chat_tools(
|
||||
self, tools: list[OpenAIResponseInputTool]
|
||||
) -> list[ChatCompletionToolParam]:
|
||||
) -> tuple[
|
||||
list[ChatCompletionToolParam],
|
||||
dict[str, OpenAIResponseInputToolMCP],
|
||||
OpenAIResponseOutput | None,
|
||||
]:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
MCPListToolsTool,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import Tool
|
||||
|
||||
mcp_tool_to_server = {}
|
||||
|
||||
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
|
||||
mcp_list_message = None
|
||||
chat_tools: list[ChatCompletionToolParam] = []
|
||||
for input_tool in tools:
|
||||
# TODO: Handle other tool types
|
||||
|
@ -386,91 +589,95 @@ class OpenAIResponsesImpl:
|
|||
elif input_tool.type == "web_search":
|
||||
tool_name = "web_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "mcp":
|
||||
always_allowed = None
|
||||
never_allowed = None
|
||||
if input_tool.allowed_tools:
|
||||
if isinstance(input_tool.allowed_tools, list):
|
||||
always_allowed = input_tool.allowed_tools
|
||||
elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
|
||||
always_allowed = input_tool.allowed_tools.always
|
||||
never_allowed = input_tool.allowed_tools.never
|
||||
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=input_tool.server_url,
|
||||
headers=input_tool.headers or {},
|
||||
)
|
||||
chat_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
chat_tools.append(chat_tool)
|
||||
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
status="completed",
|
||||
server_label=input_tool.server_label,
|
||||
tools=[],
|
||||
)
|
||||
for t in tool_defs.data:
|
||||
if never_allowed and t.name in never_allowed:
|
||||
continue
|
||||
if not always_allowed or t.name in always_allowed:
|
||||
chat_tools.append(make_openai_tool(t.name, t))
|
||||
if t.name in mcp_tool_to_server:
|
||||
raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
|
||||
mcp_tool_to_server[t.name] = input_tool
|
||||
mcp_list_message.tools.append(
|
||||
MCPListToolsTool(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"type": p.parameter_type,
|
||||
"description": p.description,
|
||||
}
|
||||
for p in t.parameters
|
||||
},
|
||||
"required": [p.name for p in t.parameters if p.required],
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
return chat_tools
|
||||
return chat_tools, mcp_tool_to_server, mcp_list_message
|
||||
|
||||
async def _execute_tool_and_return_final_output(
|
||||
self,
|
||||
model_id: str,
|
||||
stream: bool,
|
||||
choice: OpenAIChoice,
|
||||
messages: list[OpenAIMessageParam],
|
||||
temperature: float,
|
||||
ctx: ChatCompletionContext,
|
||||
) -> list[OpenAIResponseOutput]:
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
|
||||
# If the choice is not an assistant message, we don't need to execute any tools
|
||||
if not isinstance(choice.message, OpenAIAssistantMessageParam):
|
||||
return output_messages
|
||||
|
||||
# If the assistant message doesn't have any tool calls, we don't need to execute any tools
|
||||
if not choice.message.tool_calls:
|
||||
return output_messages
|
||||
|
||||
# Copy the messages list to avoid mutating the original list
|
||||
messages = messages.copy()
|
||||
next_turn_messages = ctx.messages.copy()
|
||||
|
||||
# Add the assistant message with tool_calls response to the messages list
|
||||
messages.append(choice.message)
|
||||
next_turn_messages.append(choice.message)
|
||||
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
|
||||
# If for some reason the tool call doesn't have a function or id, we can't execute it
|
||||
if not function or not tool_call_id:
|
||||
continue
|
||||
|
||||
# TODO: telemetry spans for tool calls
|
||||
result = await self._execute_tool_call(function)
|
||||
|
||||
# Handle tool call failure
|
||||
if not result:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="failed",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="completed",
|
||||
),
|
||||
)
|
||||
|
||||
result_content = ""
|
||||
# TODO: handle other result content types and lists
|
||||
if isinstance(result.content, str):
|
||||
result_content = result.content
|
||||
messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
|
||||
tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
if further_input:
|
||||
next_turn_messages.append(further_input)
|
||||
|
||||
tool_results_chat_response = await self.inference_api.openai_chat_completion(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
model=ctx.model,
|
||||
messages=next_turn_messages,
|
||||
stream=ctx.stream,
|
||||
temperature=ctx.temperature,
|
||||
)
|
||||
# type cast to appease mypy
|
||||
# type cast to appease mypy: this is needed because we don't handle streaming properly :)
|
||||
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
|
||||
|
||||
# Huge TODO: these are NOT the final outputs, we must keep the loop going
|
||||
tool_final_outputs = [
|
||||
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
|
||||
]
|
||||
|
@ -480,15 +687,86 @@ class OpenAIResponsesImpl:
|
|||
|
||||
async def _execute_tool_call(
|
||||
self,
|
||||
function: OpenAIChatCompletionToolCallFunction,
|
||||
) -> ToolInvocationResult | None:
|
||||
if not function.name:
|
||||
return None
|
||||
function_args = json.loads(function.arguments) if function.arguments else {}
|
||||
logger.info(f"executing tool call: {function.name} with args: {function_args}")
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function.name,
|
||||
kwargs=function_args,
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
ctx: ChatCompletionContext,
|
||||
) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
logger.debug(f"tool call {function.name} completed with result: {result}")
|
||||
return result
|
||||
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
|
||||
if not function or not tool_call_id or not function.name:
|
||||
return None, None
|
||||
|
||||
error_exc = None
|
||||
result = None
|
||||
try:
|
||||
if function.name in ctx.mcp_tool_to_server:
|
||||
mcp_tool = ctx.mcp_tool_to_server[function.name]
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function.name,
|
||||
kwargs=json.loads(function.arguments) if function.arguments else {},
|
||||
)
|
||||
else:
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function.name,
|
||||
kwargs=json.loads(function.arguments) if function.arguments else {},
|
||||
)
|
||||
except Exception as e:
|
||||
error_exc = e
|
||||
|
||||
if function.name in ctx.mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=tool_call_id,
|
||||
arguments=function.arguments,
|
||||
name=function.name,
|
||||
server_label=ctx.mcp_tool_to_server[function.name].server_label,
|
||||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
elif (result.error_code and result.error_code > 0) or result.error_message:
|
||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||
elif result.content:
|
||||
message.output = interleaved_content_as_str(result.content)
|
||||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="completed",
|
||||
)
|
||||
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
|
||||
message.status = "failed"
|
||||
else:
|
||||
raise ValueError(f"Unknown tool {function.name} called")
|
||||
|
||||
input_message = None
|
||||
if result and result.content:
|
||||
if isinstance(result.content, str):
|
||||
content = result.content
|
||||
elif isinstance(result.content, list):
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
|
||||
content = []
|
||||
for item in result.content:
|
||||
if isinstance(item, TextContentItem):
|
||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||
elif isinstance(item, ImageContentItem):
|
||||
if item.image.data:
|
||||
url = f"data:image;base64,{item.image.data}"
|
||||
else:
|
||||
url = item.image.url
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||
content.append(part)
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
|
||||
return message, input_message
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
@ -455,9 +456,9 @@ class MetaReferenceInferenceImpl(
|
|||
first = token_results[0]
|
||||
if not first.finished and not first.ignore_token:
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
|
||||
cprint(first.text, "cyan", end="")
|
||||
cprint(first.text, color="cyan", end="", file=sys.stderr)
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||
cprint(f"<{first.token}>", "magenta", end="")
|
||||
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
|
||||
|
||||
for result in token_results:
|
||||
idx = result.batch_idx
|
||||
|
@ -519,9 +520,9 @@ class MetaReferenceInferenceImpl(
|
|||
for token_results in self.generator.chat_completion([request]):
|
||||
token_result = token_results[0]
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
|
||||
cprint(token_result.text, "cyan", end="")
|
||||
cprint(token_result.text, color="cyan", end="", file=sys.stderr)
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||
cprint(f"<{token_result.token}>", "magenta", end="")
|
||||
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
|
||||
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
|
|
|
@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource
|
|||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
Event,
|
||||
|
@ -44,6 +45,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
|
|||
)
|
||||
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
|
||||
|
@ -146,7 +148,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=event.type,
|
||||
name=event.type.value,
|
||||
attributes={
|
||||
"message": event.message,
|
||||
"severity": event.severity.value,
|
||||
|
@ -206,6 +208,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
event.attributes = {}
|
||||
event.attributes["__ttl__"] = ttl_seconds
|
||||
|
||||
# Extract these W3C trace context attributes so they are not written to
|
||||
# underlying storage, as we just need them to propagate the trace context.
|
||||
traceparent = event.attributes.pop("traceparent", None)
|
||||
tracestate = event.attributes.pop("tracestate", None)
|
||||
if traceparent:
|
||||
# If we have a traceparent header value, we're not the root span.
|
||||
for root_attribute in ROOT_SPAN_MARKERS:
|
||||
event.attributes.pop(root_attribute, None)
|
||||
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
|
@ -216,8 +227,12 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
parent_span_id = int(event.payload.parent_span_id, 16)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
else:
|
||||
event.attributes["__root_span__"] = "true"
|
||||
elif traceparent:
|
||||
carrier = {
|
||||
"traceparent": traceparent,
|
||||
"tracestate": tracestate,
|
||||
}
|
||||
context = TraceContextTextMapPropagator().extract(carrier=carrier)
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
|
|
|
@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
|
|||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
RAGToolRuntime,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
|
@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
|
|||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||
|
||||
|
||||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
config: RagToolRuntimeConfig,
|
||||
|
@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
async def insert(
|
||||
|
@ -122,6 +122,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
query=query,
|
||||
params={
|
||||
"max_chunks": query_config.max_chunks,
|
||||
"mode": query_config.mode,
|
||||
},
|
||||
)
|
||||
for vector_db_id in vector_db_ids
|
||||
|
@ -146,7 +147,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
for i, chunk in enumerate(chunks):
|
||||
metadata = chunk.metadata
|
||||
tokens += metadata["token_count"]
|
||||
tokens += metadata["metadata_token_count"]
|
||||
tokens += metadata.get("metadata_token_count", 0)
|
||||
|
||||
if tokens > query_config.max_tokens_in_context:
|
||||
log.error(
|
||||
|
|
|
@ -99,9 +99,13 @@ class FaissIndex(EmbeddingIndex):
|
|||
# Save updated index
|
||||
await self._save_index()
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for d, i in zip(distances[0], indices[0], strict=False):
|
||||
|
@ -112,6 +116,14 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in FAISS")
|
||||
|
||||
|
||||
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
|
||||
|
|
|
@ -24,6 +24,11 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Specifying search mode is dependent on the VectorIO provider.
|
||||
VECTOR_SEARCH = "vector"
|
||||
KEYWORD_SEARCH = "keyword"
|
||||
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
|
||||
|
||||
|
||||
def serialize_vector(vector: list[float]) -> bytes:
|
||||
"""Serialize a list of floats into a compact binary representation."""
|
||||
|
@ -45,6 +50,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
Two tables are used:
|
||||
- A metadata table (chunks_{bank_id}) that holds the chunk JSON.
|
||||
- A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
|
||||
- An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int, db_path: str, bank_id: str):
|
||||
|
@ -53,6 +59,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
self.bank_id = bank_id
|
||||
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
|
||||
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
|
||||
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dimension: int, db_path: str, bank_id: str):
|
||||
|
@ -78,6 +85,14 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||
""")
|
||||
connection.commit()
|
||||
# FTS5 table (for keyword search) - creating both the tables by default. Will use the relevant one
|
||||
# based on query. Implementation of the change on client side will allow passing the search_mode option
|
||||
# during initialization to make it easier to create the table that is required.
|
||||
cur.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
|
||||
USING fts5(id, content);
|
||||
""")
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
@ -91,6 +106,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
try:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
|
@ -104,6 +120,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
For each chunk, we insert its JSON into the metadata table and then insert its
|
||||
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
|
||||
If any insert fails, the transaction is rolled back to maintain consistency.
|
||||
Also inserts chunk content into FTS table for keyword search support.
|
||||
"""
|
||||
assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
|
||||
|
||||
|
@ -112,18 +129,16 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
cur = connection.cursor()
|
||||
|
||||
try:
|
||||
# Start transaction a single transcation for all batches
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
batch_chunks = chunks[i : i + batch_size]
|
||||
batch_embeddings = embeddings[i : i + batch_size]
|
||||
# Prepare metadata inserts
|
||||
|
||||
# Insert metadata
|
||||
metadata_data = [
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
|
||||
for chunk in batch_chunks
|
||||
if isinstance(chunk.content, str)
|
||||
]
|
||||
# Insert metadata (ON CONFLICT to avoid duplicates)
|
||||
cur.executemany(
|
||||
f"""
|
||||
INSERT INTO {self.metadata_table} (id, chunk)
|
||||
|
@ -132,21 +147,43 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
""",
|
||||
metadata_data,
|
||||
)
|
||||
# Prepare embeddings inserts
|
||||
|
||||
# Insert vector embeddings
|
||||
embedding_data = [
|
||||
(
|
||||
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
|
||||
serialize_vector(emb.tolist()),
|
||||
(
|
||||
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
|
||||
serialize_vector(emb.tolist()),
|
||||
)
|
||||
)
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
if isinstance(chunk.content, str)
|
||||
]
|
||||
# Insert embeddings in batch
|
||||
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
cur.executemany(
|
||||
f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
|
||||
embedding_data,
|
||||
)
|
||||
|
||||
# Insert FTS content
|
||||
fts_data = [
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content)
|
||||
for chunk in batch_chunks
|
||||
]
|
||||
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
||||
cur.executemany(
|
||||
f"DELETE FROM {self.fts_table} WHERE id = ?;",
|
||||
[(row[0],) for row in fts_data],
|
||||
)
|
||||
|
||||
# INSERT new entries
|
||||
cur.executemany(
|
||||
f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
|
||||
fts_data,
|
||||
)
|
||||
|
||||
connection.commit()
|
||||
|
||||
except sqlite3.Error as e:
|
||||
connection.rollback() # Rollback on failure
|
||||
connection.rollback()
|
||||
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
||||
raise
|
||||
|
||||
|
@ -154,22 +191,25 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
cur.close()
|
||||
connection.close()
|
||||
|
||||
# Process all batches in a single thread
|
||||
# Run batch insertion in a background thread
|
||||
await asyncio.to_thread(_execute_all_batch_inserts)
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
"""
|
||||
Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query
|
||||
against the virtual table. The SQL joins the metadata table to recover the chunk JSON.
|
||||
Performs vector-based search using a virtual table for vector similarity.
|
||||
"""
|
||||
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
|
||||
emb_blob = serialize_vector(emb_list)
|
||||
|
||||
def _execute_query():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
|
||||
try:
|
||||
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
|
||||
emb_blob = serialize_vector(emb_list)
|
||||
query_sql = f"""
|
||||
SELECT m.id, m.chunk, v.distance
|
||||
FROM {self.vector_table} AS v
|
||||
|
@ -184,17 +224,66 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
connection.close()
|
||||
|
||||
rows = await asyncio.to_thread(_execute_query)
|
||||
|
||||
chunks, scores = [], []
|
||||
for _id, chunk_json, distance in rows:
|
||||
for row in rows:
|
||||
_id, chunk_json, distance = row
|
||||
score = 1.0 / distance if distance != 0 else float("inf")
|
||||
if score < score_threshold:
|
||||
continue
|
||||
try:
|
||||
chunk = Chunk.model_validate_json(chunk_json)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||
continue
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
||||
"""
|
||||
if query_string is None:
|
||||
raise ValueError("query_string is required for keyword search.")
|
||||
|
||||
def _execute_query():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
query_sql = f"""
|
||||
SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
|
||||
FROM {self.fts_table} AS f
|
||||
JOIN {self.metadata_table} AS m ON m.id = f.id
|
||||
WHERE f.content MATCH ?
|
||||
ORDER BY score ASC
|
||||
LIMIT ?;
|
||||
"""
|
||||
cur.execute(query_sql, (query_string, k))
|
||||
return cur.fetchall()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
rows = await asyncio.to_thread(_execute_query)
|
||||
chunks, scores = [], []
|
||||
for row in rows:
|
||||
_id, chunk_json, score = row
|
||||
# BM25 scores returned by sqlite-vec are NEGATED (i.e., more relevant = more negative).
|
||||
# This design is intentional to simplify sorting by ascending score.
|
||||
# Reference: https://alexgarcia.xyz/blog/2024/sqlite-vec-hybrid-search/index.html
|
||||
if score > -score_threshold:
|
||||
continue
|
||||
try:
|
||||
chunk = Chunk.model_validate_json(chunk_json)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||
continue
|
||||
chunks.append(chunk)
|
||||
# Mimic the Faiss scoring: score = 1/distance (avoid division by zero)
|
||||
score = 1.0 / distance if distance != 0 else float("inf")
|
||||
scores.append(score)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
|
|
@ -63,4 +63,14 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.safety.sambanova",
|
||||
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -80,8 +80,9 @@ def available_providers() -> list[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="model-context-protocol",
|
||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
|
||||
pip_packages=["mcp"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -92,8 +92,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
if prompt_logprobs is not None:
|
||||
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
|
||||
|
||||
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||
if model_id.startswith("openai/"):
|
||||
model_id = model_id[len("openai/") :]
|
||||
params = await prepare_openai_completion_params(
|
||||
model=(await self.model_store.get_model(model)).provider_resource_id,
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
|
@ -139,8 +142,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||
if model_id.startswith("openai/"):
|
||||
model_id = model_id[len("openai/") :]
|
||||
params = await prepare_openai_completion_params(
|
||||
model=(await self.model_store.get_model(model)).provider_resource_id,
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
|
|
|
@ -4,8 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
@ -24,11 +25,27 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
default="fake",
|
||||
description="The API token",
|
||||
)
|
||||
tls_verify: bool = Field(
|
||||
tls_verify: bool | str = Field(
|
||||
default=True,
|
||||
description="Whether to verify TLS certificates",
|
||||
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
|
||||
)
|
||||
|
||||
@field_validator("tls_verify")
|
||||
@classmethod
|
||||
def validate_tls_verify(cls, v):
|
||||
if isinstance(v, str):
|
||||
# Check if it's a boolean string
|
||||
if v.lower() in ("true", "false"):
|
||||
return v.lower() == "true"
|
||||
# Otherwise, treat it as a cert path
|
||||
cert_path = Path(v).expanduser().resolve()
|
||||
if not cert_path.exists():
|
||||
raise ValueError(f"TLS certificate file does not exist: {v}")
|
||||
if not cert_path.is_file():
|
||||
raise ValueError(f"TLS certificate path is not a file: {v}")
|
||||
return v
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
|
|
|
@ -313,7 +313,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
return AsyncOpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token,
|
||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
||||
http_client=httpx.AsyncClient(verify=self.config.tls_verify),
|
||||
)
|
||||
|
||||
async def completion(
|
||||
|
|
|
@ -224,7 +224,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
Parameters:
|
||||
training_config: TrainingConfig - Configuration for training
|
||||
model: str - Model identifier
|
||||
model: str - NeMo Customizer configuration name
|
||||
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
|
||||
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
|
||||
job_uuid: str - Unique identifier for the job, ignored atm
|
||||
|
@ -299,9 +299,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
User is informed about unsupported parameters via warnings.
|
||||
"""
|
||||
# Map model to nvidia model name
|
||||
# See `_MODEL_ENTRIES` for supported models
|
||||
nvidia_model = self.get_provider_model_id(model)
|
||||
|
||||
# Check for unsupported method parameters
|
||||
unsupported_method_params = []
|
||||
|
@ -347,7 +344,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
# Prepare base job configuration
|
||||
job_config = {
|
||||
"config": nvidia_model,
|
||||
"config": model,
|
||||
"dataset": {
|
||||
"name": training_config["data_config"]["dataset_id"],
|
||||
"namespace": self.config.dataset_namespace,
|
||||
|
|
18
llama_stack/providers/remote/safety/sambanova/__init__.py
Normal file
18
llama_stack/providers/remote/safety/sambanova/__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SambaNovaSafetyConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any:
|
||||
from .sambanova import SambaNovaSafetyAdapter
|
||||
|
||||
impl = SambaNovaSafetyAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
37
llama_stack/providers/remote/safety/sambanova/config.py
Normal file
37
llama_stack/providers/remote/safety/sambanova/config.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class SambaNovaProviderDataValidator(BaseModel):
|
||||
sambanova_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="Sambanova Cloud API key",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SambaNovaSafetyConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The SambaNova cloud API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
100
llama_stack/providers/remote/safety/sambanova/sambanova.py
Normal file
100
llama_stack/providers/remote/safety/sambanova/sambanova.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
|
||||
|
||||
from .config import SambaNovaSafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||
|
||||
|
||||
class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
|
||||
def __init__(self, config: SambaNovaSafetyConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
config_api_key = self.config.api_key if self.config.api_key else None
|
||||
if config_api_key:
|
||||
return config_api_key.get_secret_value()
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.sambanova_api_key:
|
||||
raise ValueError(
|
||||
'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
|
||||
)
|
||||
return provider_data.sambanova_api_key
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
list_models_url = self.config.url + "/models"
|
||||
try:
|
||||
response = requests.get(list_models_url)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RuntimeError(f"Request to {list_models_url} failed") from e
|
||||
available_models = [model.get("id") for model in response.json().get("data", {})]
|
||||
if (
|
||||
len(available_models) == 0
|
||||
or "guard" not in shield.provider_resource_id.lower()
|
||||
or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
|
||||
):
|
||||
raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova")
|
||||
|
||||
async def run_shield(
|
||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
shield_params = shield.params
|
||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
|
||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
|
||||
response = litellm.completion(
|
||||
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
|
||||
)
|
||||
shield_message = response.choices[0].message.content
|
||||
|
||||
if "unsafe" in shield_message.lower():
|
||||
user_message = CANNED_RESPONSE_TEXT
|
||||
violation_type = shield_message.split("\n")[-1]
|
||||
metadata = {"violation_type": violation_type}
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return RunShieldResponse()
|
|
@ -12,19 +12,19 @@ import httpx
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
|
||||
from .config import BingSearchToolConfig
|
||||
|
||||
|
||||
class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: BingSearchToolConfig):
|
||||
self.config = config
|
||||
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
||||
|
@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
|
|
|
@ -11,30 +11,30 @@ import httpx
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
|
||||
from .config import BraveSearchToolConfig
|
||||
|
||||
|
||||
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: BraveSearchToolConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
|
|
|
@ -4,18 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import ModelContextProtocolConfig
|
||||
from .config import MCPProviderConfig
|
||||
|
||||
|
||||
class ModelContextProtocolToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
|
||||
async def get_adapter_impl(config: MCPProviderConfig, _deps):
|
||||
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
||||
|
||||
impl = ModelContextProtocolToolRuntimeImpl(config)
|
||||
impl = ModelContextProtocolToolRuntimeImpl(config, _deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -9,7 +9,12 @@ from typing import Any
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelContextProtocolConfig(BaseModel):
|
||||
class MCPProviderDataValidator(BaseModel):
|
||||
# mcp_endpoint => dict of headers to send
|
||||
mcp_headers: dict[str, dict[str, str]] | None = None
|
||||
|
||||
|
||||
class MCPProviderConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -7,61 +7,45 @@
|
|||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
|
||||
|
||||
from .config import ModelContextProtocolConfig
|
||||
from .config import MCPProviderConfig
|
||||
|
||||
logger = get_logger(__name__, category="tools")
|
||||
|
||||
|
||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||
def __init__(self, config: ModelContextProtocolConfig):
|
||||
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
# this endpoint should be retrieved by getting the tool group right?
|
||||
if mcp_endpoint is None:
|
||||
raise ValueError("mcp_endpoint is required")
|
||||
|
||||
tools = []
|
||||
async with sse_client(mcp_endpoint.uri) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=param_name,
|
||||
parameter_type=param_schema.get("type", "string"),
|
||||
description=param_schema.get("description", ""),
|
||||
)
|
||||
)
|
||||
tools.append(
|
||||
ToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": mcp_endpoint.uri,
|
||||
},
|
||||
)
|
||||
)
|
||||
return ListToolDefsResponse(data=tools)
|
||||
headers = await self.get_headers_from_request(mcp_endpoint.uri)
|
||||
return await list_mcp_tools(mcp_endpoint.uri, headers)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
|
@ -71,12 +55,19 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||
|
||||
async with sse_client(endpoint) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
result = await session.call_tool(tool.identifier, kwargs)
|
||||
headers = await self.get_headers_from_request(endpoint)
|
||||
return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content="\n".join([result.model_dump_json() for result in result.content]),
|
||||
error_code=1 if result.isError else 0,
|
||||
)
|
||||
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
|
||||
def canonicalize_uri(uri: str) -> str:
|
||||
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
|
||||
|
||||
headers = {}
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data and provider_data.mcp_headers:
|
||||
for uri, values in provider_data.mcp_headers.items():
|
||||
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
||||
continue
|
||||
headers.update(values)
|
||||
return headers
|
||||
|
|
|
@ -12,29 +12,29 @@ import httpx
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
|
||||
from .config import TavilySearchToolConfig
|
||||
|
||||
|
||||
class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: TavilySearchToolConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
|
|
|
@ -12,19 +12,19 @@ import httpx
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
|
||||
from .config import WolframAlphaToolConfig
|
||||
|
||||
|
||||
class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: WolframAlphaToolConfig):
|
||||
self.config = config
|
||||
self.url = "https://api.wolframalpha.com/v2/query"
|
||||
|
@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
|
|
|
@ -84,6 +84,14 @@ class ChromaIndex(EmbeddingIndex):
|
|||
async def delete(self):
|
||||
await maybe_await(self.client.delete_collection(self.collection.name))
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Chroma")
|
||||
|
||||
|
||||
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
|
|
|
@ -73,7 +73,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||
raise e
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
collection_name=self.collection_name,
|
||||
|
@ -86,6 +86,14 @@ class MilvusIndex(EmbeddingIndex):
|
|||
scores = [res["distance"] for res in search_res[0]]
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Milvus")
|
||||
|
||||
|
||||
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
|
|
|
@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(
|
||||
f"""
|
||||
|
@ -120,6 +120,14 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in PGVector")
|
||||
|
||||
async def delete(self):
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
|
|
@ -68,7 +68,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
|
||||
await self.client.upsert(collection_name=self.collection_name, points=points)
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
results = (
|
||||
await self.client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
|
@ -95,6 +95,14 @@ class QdrantIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Qdrant")
|
||||
|
||||
async def delete(self):
|
||||
await self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
# TODO: make this async friendly
|
||||
collection.data.insert_many(data_objects)
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
collection = self.client.collections.get(self.collection_name)
|
||||
|
||||
results = collection.query.near_vector(
|
||||
|
@ -84,6 +84,14 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
collection = self.client.collections.get(self.collection_name)
|
||||
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Weaviate")
|
||||
|
||||
|
||||
class WeaviateVectorIOAdapter(
|
||||
VectorIO,
|
||||
|
|
123
llama_stack/providers/utils/inference/inference_store.py
Normal file
123
llama_stack/providers/utils/inference/inference_store.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.inference import (
|
||||
ListOpenAIChatCompletionResponse,
|
||||
OpenAIChatCompletion,
|
||||
OpenAICompletionWithInputMessages,
|
||||
OpenAIMessageParam,
|
||||
Order,
|
||||
)
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
||||
|
||||
|
||||
class InferenceStore:
|
||||
def __init__(self, sql_store_config: SqlStoreConfig):
|
||||
if not sql_store_config:
|
||||
sql_store_config = SqliteSqlStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
)
|
||||
self.sql_store_config = sql_store_config
|
||||
self.sql_store = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
self.sql_store = sqlstore_impl(self.sql_store_config)
|
||||
await self.sql_store.create_table(
|
||||
"chat_completions",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"created": ColumnType.INTEGER,
|
||||
"model": ColumnType.STRING,
|
||||
"choices": ColumnType.JSON,
|
||||
"input_messages": ColumnType.JSON,
|
||||
},
|
||||
)
|
||||
|
||||
async def store_chat_completion(
|
||||
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
|
||||
) -> None:
|
||||
if not self.sql_store:
|
||||
raise ValueError("Inference store is not initialized")
|
||||
|
||||
data = chat_completion.model_dump()
|
||||
|
||||
await self.sql_store.insert(
|
||||
"chat_completions",
|
||||
{
|
||||
"id": data["id"],
|
||||
"created": data["created"],
|
||||
"model": data["model"],
|
||||
"choices": data["choices"],
|
||||
"input_messages": [message.model_dump() for message in input_messages],
|
||||
},
|
||||
)
|
||||
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIChatCompletionResponse:
|
||||
"""
|
||||
List chat completions from the database.
|
||||
|
||||
:param after: The ID of the last chat completion to return.
|
||||
:param limit: The maximum number of chat completions to return.
|
||||
:param model: The model to filter by.
|
||||
:param order: The order to sort the chat completions by.
|
||||
"""
|
||||
if not self.sql_store:
|
||||
raise ValueError("Inference store is not initialized")
|
||||
|
||||
# TODO: support after
|
||||
if after:
|
||||
raise NotImplementedError("After is not supported for SQLite")
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
rows = await self.sql_store.fetch_all(
|
||||
"chat_completions",
|
||||
where={"model": model} if model else None,
|
||||
order_by=[("created", order.value)],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
data = [
|
||||
OpenAICompletionWithInputMessages(
|
||||
id=row["id"],
|
||||
created=row["created"],
|
||||
model=row["model"],
|
||||
choices=row["choices"],
|
||||
input_messages=row["input_messages"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
return ListOpenAIChatCompletionResponse(
|
||||
data=data,
|
||||
# TODO: implement has_more
|
||||
has_more=False,
|
||||
first_id=data[0].id if data else "",
|
||||
last_id=data[-1].id if data else "",
|
||||
)
|
||||
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
if not self.sql_store:
|
||||
raise ValueError("Inference store is not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("chat_completions", where={"id": completion_id})
|
||||
if not row:
|
||||
raise ValueError(f"Chat completion with id {completion_id} not found") from None
|
||||
return OpenAICompletionWithInputMessages(
|
||||
id=row["id"],
|
||||
created=row["created"],
|
||||
model=row["model"],
|
||||
choices=row["choices"],
|
||||
input_messages=row["input_messages"],
|
||||
)
|
|
@ -1402,9 +1402,8 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
|
||||
):
|
||||
id = f"chatcmpl-{uuid.uuid4()}"
|
||||
for outstanding_response in outstanding_responses:
|
||||
for i, outstanding_response in enumerate(outstanding_responses):
|
||||
response = await outstanding_response
|
||||
i = 0
|
||||
async for chunk in response:
|
||||
event = chunk.event
|
||||
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
|
||||
|
@ -1459,7 +1458,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
)
|
||||
i = i + 1
|
||||
|
||||
async def _process_non_stream_response(
|
||||
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
|
||||
|
|
129
llama_stack/providers/utils/inference/stream_utils.py
Normal file
129
llama_stack/providers/utils/inference/stream_utils.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceLogprobs,
|
||||
OpenAIMessageParam,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
|
||||
|
||||
async def stream_and_store_openai_completion(
|
||||
provider_stream: AsyncIterator[OpenAIChatCompletionChunk],
|
||||
model: str,
|
||||
store: InferenceStore,
|
||||
input_messages: list[OpenAIMessageParam],
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""
|
||||
Wraps a provider's stream, yields chunks, and stores the full completion at the end.
|
||||
"""
|
||||
id = None
|
||||
created = None
|
||||
choices_data: dict[int, dict[str, Any]] = {}
|
||||
|
||||
try:
|
||||
async for chunk in provider_stream:
|
||||
if id is None and chunk.id:
|
||||
id = chunk.id
|
||||
if created is None and chunk.created:
|
||||
created = chunk.created
|
||||
|
||||
if chunk.choices:
|
||||
for choice_delta in chunk.choices:
|
||||
idx = choice_delta.index
|
||||
if idx not in choices_data:
|
||||
choices_data[idx] = {
|
||||
"content_parts": [],
|
||||
"tool_calls_builder": {},
|
||||
"finish_reason": None,
|
||||
"logprobs_content_parts": [],
|
||||
}
|
||||
current_choice_data = choices_data[idx]
|
||||
|
||||
if choice_delta.delta:
|
||||
delta = choice_delta.delta
|
||||
if delta.content:
|
||||
current_choice_data["content_parts"].append(delta.content)
|
||||
if delta.tool_calls:
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
tc_idx = tool_call_delta.index
|
||||
if tc_idx not in current_choice_data["tool_calls_builder"]:
|
||||
# Initialize with correct structure for _ToolCallBuilderData
|
||||
current_choice_data["tool_calls_builder"][tc_idx] = {
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function_name_parts": [],
|
||||
"function_arguments_parts": [],
|
||||
}
|
||||
builder = current_choice_data["tool_calls_builder"][tc_idx]
|
||||
if tool_call_delta.id:
|
||||
builder["id"] = tool_call_delta.id
|
||||
if tool_call_delta.type:
|
||||
builder["type"] = tool_call_delta.type
|
||||
if tool_call_delta.function:
|
||||
if tool_call_delta.function.name:
|
||||
builder["function_name_parts"].append(tool_call_delta.function.name)
|
||||
if tool_call_delta.function.arguments:
|
||||
builder["function_arguments_parts"].append(tool_call_delta.function.arguments)
|
||||
if choice_delta.finish_reason:
|
||||
current_choice_data["finish_reason"] = choice_delta.finish_reason
|
||||
if choice_delta.logprobs and choice_delta.logprobs.content:
|
||||
# Ensure that we are extending with the correct type
|
||||
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
|
||||
yield chunk
|
||||
finally:
|
||||
if id:
|
||||
assembled_choices: list[OpenAIChoice] = []
|
||||
for choice_idx, choice_data in choices_data.items():
|
||||
content_str = "".join(choice_data["content_parts"])
|
||||
assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
|
||||
if choice_data["tool_calls_builder"]:
|
||||
for tc_build_data in choice_data["tool_calls_builder"].values():
|
||||
if tc_build_data["id"]:
|
||||
func_name = "".join(tc_build_data["function_name_parts"])
|
||||
func_args = "".join(tc_build_data["function_arguments_parts"])
|
||||
assembled_tool_calls.append(
|
||||
OpenAIChatCompletionToolCall(
|
||||
id=tc_build_data["id"],
|
||||
type=tc_build_data["type"], # No or "function" needed, already set
|
||||
function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args),
|
||||
)
|
||||
)
|
||||
message = OpenAIAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content_str if content_str else None,
|
||||
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
|
||||
)
|
||||
logprobs_content = choice_data["logprobs_content_parts"]
|
||||
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
|
||||
|
||||
assembled_choices.append(
|
||||
OpenAIChoice(
|
||||
finish_reason=choice_data["finish_reason"],
|
||||
index=choice_idx,
|
||||
message=message,
|
||||
logprobs=final_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
final_response = OpenAIChatCompletion(
|
||||
id=id,
|
||||
choices=assembled_choices,
|
||||
created=created or int(datetime.now(timezone.utc).timestamp()),
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
)
|
||||
await store.store_chat_completion(final_response, input_messages)
|
|
@ -177,7 +177,11 @@ class EmbeddingIndex(ABC):
|
|||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
|
@ -210,9 +214,12 @@ class VectorDBWithIndex:
|
|||
if params is None:
|
||||
params = {}
|
||||
k = params.get("max_chunks", 3)
|
||||
mode = params.get("mode")
|
||||
score_threshold = params.get("score_threshold", 0.0)
|
||||
|
||||
query_str = interleaved_content_as_str(query)
|
||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_str])
|
||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||
return await self.index.query(query_vector, k, score_threshold)
|
||||
query_string = interleaved_content_as_str(query)
|
||||
if mode == "keyword":
|
||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||
else:
|
||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||
return await self.index.query_vector(query_vector, k, score_threshold)
|
||||
|
|
135
llama_stack/providers/utils/responses/responses_store.py
Normal file
135
llama_stack/providers/utils/responses/responses_store.py
Normal file
|
@ -0,0 +1,135 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.agents import (
|
||||
Order,
|
||||
)
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectWithInput,
|
||||
)
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
||||
|
||||
|
||||
class ResponsesStore:
|
||||
def __init__(self, sql_store_config: SqlStoreConfig):
|
||||
if not sql_store_config:
|
||||
sql_store_config = SqliteSqlStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
)
|
||||
self.sql_store = sqlstore_impl(sql_store_config)
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
await self.sql_store.create_table(
|
||||
"openai_responses",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"response_object": ColumnType.JSON,
|
||||
"model": ColumnType.STRING,
|
||||
},
|
||||
)
|
||||
|
||||
async def store_response_object(
|
||||
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||
) -> None:
|
||||
data = response_object.model_dump()
|
||||
data["input"] = [input_item.model_dump() for input_item in input]
|
||||
|
||||
await self.sql_store.insert(
|
||||
"openai_responses",
|
||||
{
|
||||
"id": data["id"],
|
||||
"created_at": data["created_at"],
|
||||
"model": data["model"],
|
||||
"response_object": data,
|
||||
},
|
||||
)
|
||||
|
||||
async def list_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
"""
|
||||
List responses from the database.
|
||||
|
||||
:param after: The ID of the last response to return.
|
||||
:param limit: The maximum number of responses to return.
|
||||
:param model: The model to filter by.
|
||||
:param order: The order to sort the responses by.
|
||||
"""
|
||||
# TODO: support after
|
||||
if after:
|
||||
raise NotImplementedError("After is not supported for SQLite")
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
rows = await self.sql_store.fetch_all(
|
||||
"openai_responses",
|
||||
where={"model": model} if model else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in rows]
|
||||
return ListOpenAIResponseObject(
|
||||
data=data,
|
||||
# TODO: implement has_more
|
||||
has_more=False,
|
||||
first_id=data[0].id if data else "",
|
||||
last_id=data[-1].id if data else "",
|
||||
)
|
||||
|
||||
async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
|
||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
|
||||
if not row:
|
||||
raise ValueError(f"Response with id {response_id} not found") from None
|
||||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
||||
|
||||
async def list_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
"""
|
||||
List input items for a given response.
|
||||
|
||||
:param response_id: The ID of the response to retrieve input items for.
|
||||
:param after: An item ID to list items after, used for pagination.
|
||||
:param before: An item ID to list items before, used for pagination.
|
||||
:param include: Additional fields to include in the response.
|
||||
:param limit: A limit on the number of objects to be returned.
|
||||
:param order: The order to return the input items in.
|
||||
"""
|
||||
# TODO: support after/before pagination
|
||||
if after or before:
|
||||
raise NotImplementedError("After/before pagination is not supported yet")
|
||||
if include:
|
||||
raise NotImplementedError("Include is not supported yet")
|
||||
|
||||
response_with_input = await self.get_response_object(response_id)
|
||||
input_items = response_with_input.input
|
||||
|
||||
if order == Order.desc:
|
||||
input_items = list(reversed(input_items))
|
||||
|
||||
if limit is not None and len(input_items) > limit:
|
||||
input_items = input_items[:limit]
|
||||
|
||||
return ListOpenAIResponseInputItem(data=input_items)
|
90
llama_stack/providers/utils/sqlstore/api.py
Normal file
90
llama_stack/providers/utils/sqlstore/api.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ColumnType(Enum):
|
||||
INTEGER = "INTEGER"
|
||||
STRING = "STRING"
|
||||
TEXT = "TEXT"
|
||||
FLOAT = "FLOAT"
|
||||
BOOLEAN = "BOOLEAN"
|
||||
JSON = "JSON"
|
||||
DATETIME = "DATETIME"
|
||||
|
||||
|
||||
class ColumnDefinition(BaseModel):
|
||||
type: ColumnType
|
||||
primary_key: bool = False
|
||||
nullable: bool = True
|
||||
default: Any = None
|
||||
|
||||
|
||||
class SqlStore(Protocol):
|
||||
"""
|
||||
A protocol for a SQL store.
|
||||
"""
|
||||
|
||||
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
|
||||
"""
|
||||
Create a table.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||
"""
|
||||
Insert a row into a table.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all rows from a table.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def fetch_one(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Fetch one row from a table.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def update(
|
||||
self,
|
||||
table: str,
|
||||
data: Mapping[str, Any],
|
||||
where: Mapping[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Update a row in a table.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Delete a row from a table.
|
||||
"""
|
||||
pass
|
161
llama_stack/providers/utils/sqlstore/sqlite/sqlite.py
Normal file
161
llama_stack/providers/utils/sqlstore/sqlite/sqlite.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
Integer,
|
||||
MetaData,
|
||||
String,
|
||||
Table,
|
||||
Text,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from ..api import ColumnDefinition, ColumnType, SqlStore
|
||||
from ..sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||
ColumnType.INTEGER: Integer,
|
||||
ColumnType.STRING: String,
|
||||
ColumnType.FLOAT: Float,
|
||||
ColumnType.BOOLEAN: Boolean,
|
||||
ColumnType.DATETIME: DateTime,
|
||||
ColumnType.TEXT: Text,
|
||||
ColumnType.JSON: JSON,
|
||||
}
|
||||
|
||||
|
||||
class SqliteSqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqliteSqlStoreConfig):
|
||||
self.engine = create_async_engine(config.engine_str)
|
||||
self.metadata = MetaData()
|
||||
|
||||
async def create_table(
|
||||
self,
|
||||
table: str,
|
||||
schema: Mapping[str, ColumnType | ColumnDefinition],
|
||||
) -> None:
|
||||
if not schema:
|
||||
raise ValueError(f"No columns defined for table '{table}'.")
|
||||
|
||||
sqlalchemy_columns: list[Column] = []
|
||||
|
||||
for col_name, col_props in schema.items():
|
||||
col_type = None
|
||||
is_primary_key = False
|
||||
is_nullable = True # Default to nullable
|
||||
|
||||
if isinstance(col_props, ColumnType):
|
||||
col_type = col_props
|
||||
elif isinstance(col_props, ColumnDefinition):
|
||||
col_type = col_props.type
|
||||
is_primary_key = col_props.primary_key
|
||||
is_nullable = col_props.nullable
|
||||
|
||||
sqlalchemy_type = TYPE_MAPPING.get(col_type)
|
||||
if not sqlalchemy_type:
|
||||
raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.")
|
||||
|
||||
sqlalchemy_columns.append(
|
||||
Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
|
||||
)
|
||||
|
||||
# Check if table already exists in metadata, otherwise define it
|
||||
if table not in self.metadata.tables:
|
||||
sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
|
||||
else:
|
||||
sqlalchemy_table = self.metadata.tables[table]
|
||||
|
||||
# Create the table in the database if it doesn't exist
|
||||
# checkfirst=True ensures it doesn't try to recreate if it's already there
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.execute(self.metadata.tables[table].insert(), data)
|
||||
await conn.commit()
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self.engine.begin() as conn:
|
||||
query = select(self.metadata.tables[table])
|
||||
if where:
|
||||
for key, value in where.items():
|
||||
query = query.where(self.metadata.tables[table].c[key] == value)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
if order_by:
|
||||
if not isinstance(order_by, list):
|
||||
raise ValueError(
|
||||
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
|
||||
)
|
||||
for order in order_by:
|
||||
if not isinstance(order, tuple):
|
||||
raise ValueError(
|
||||
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
|
||||
)
|
||||
name, order_type = order
|
||||
if order_type == "asc":
|
||||
query = query.order_by(self.metadata.tables[table].c[name].asc())
|
||||
elif order_type == "desc":
|
||||
query = query.order_by(self.metadata.tables[table].c[name].desc())
|
||||
else:
|
||||
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
|
||||
result = await conn.execute(query)
|
||||
if result.rowcount == 0:
|
||||
return []
|
||||
return [dict(row._mapping) for row in result]
|
||||
|
||||
async def fetch_one(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
rows = await self.fetch_all(table, where, limit=1, order_by=order_by)
|
||||
if not rows:
|
||||
return None
|
||||
return rows[0]
|
||||
|
||||
async def update(
|
||||
self,
|
||||
table: str,
|
||||
data: Mapping[str, Any],
|
||||
where: Mapping[str, Any],
|
||||
) -> None:
|
||||
if not where:
|
||||
raise ValueError("where is required for update")
|
||||
|
||||
async with self.engine.begin() as conn:
|
||||
stmt = self.metadata.tables[table].update()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
await conn.execute(stmt, data)
|
||||
await conn.commit()
|
||||
|
||||
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
|
||||
if not where:
|
||||
raise ValueError("where is required for delete")
|
||||
|
||||
async with self.engine.begin() as conn:
|
||||
stmt = self.metadata.tables[table].delete()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
await conn.execute(stmt)
|
||||
await conn.commit()
|
72
llama_stack/providers/utils/sqlstore/sqlstore.py
Normal file
72
llama_stack/providers/utils/sqlstore/sqlstore.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
from .api import SqlStore
|
||||
|
||||
|
||||
class SqlStoreType(Enum):
|
||||
sqlite = "sqlite"
|
||||
postgres = "postgres"
|
||||
|
||||
|
||||
class SqliteSqlStoreConfig(BaseModel):
|
||||
type: Literal["sqlite"] = SqlStoreType.sqlite.value
|
||||
db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
|
||||
)
|
||||
|
||||
@property
|
||||
def engine_str(self) -> str:
|
||||
return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix()
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
|
||||
return cls(
|
||||
type="sqlite",
|
||||
db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
|
||||
)
|
||||
|
||||
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["sqlalchemy[asyncio]"]
|
||||
|
||||
|
||||
class PostgresSqlStoreConfig(BaseModel):
|
||||
type: Literal["postgres"] = SqlStoreType.postgres.value
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
raise NotImplementedError("Postgres is not implemented yet")
|
||||
|
||||
|
||||
SqlStoreConfig = Annotated[
|
||||
SqliteSqlStoreConfig | PostgresSqlStoreConfig,
|
||||
Field(discriminator="type", default=SqlStoreType.sqlite.value),
|
||||
]
|
||||
|
||||
|
||||
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
||||
if config.type == SqlStoreType.sqlite.value:
|
||||
from .sqlite.sqlite import SqliteSqlStoreImpl
|
||||
|
||||
impl = SqliteSqlStoreImpl(config)
|
||||
elif config.type == SqlStoreType.postgres.value:
|
||||
raise NotImplementedError("Postgres is not implemented yet")
|
||||
else:
|
||||
raise ValueError(f"Unknown sqlstore type {config.type}")
|
||||
|
||||
return impl
|
|
@ -34,6 +34,8 @@ logger = get_logger(__name__, category="core")
|
|||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
|
||||
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
|
||||
|
||||
|
||||
def trace_id_to_str(trace_id: int) -> str:
|
||||
"""Convenience trace ID formatting method
|
||||
|
@ -178,7 +180,8 @@ async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceCont
|
|||
|
||||
trace_id = generate_trace_id()
|
||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||
attributes = {marker: True for marker in ROOT_SPAN_MARKERS} | (attributes or {})
|
||||
context.push_span(name, attributes)
|
||||
|
||||
CURRENT_TRACE_CONTEXT.set(context)
|
||||
return context
|
||||
|
|
100
llama_stack/providers/utils/tools/mcp.py
Normal file
100
llama_stack/providers/utils/tools/mcp.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
# for python < 3.11
|
||||
import exceptiongroup
|
||||
|
||||
BaseExceptionGroup = exceptiongroup.BaseExceptionGroup
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession
|
||||
from mcp import types as mcp_types
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="tools")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
|
||||
try:
|
||||
async with sse_client(endpoint, headers=headers) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
except BaseException as e:
|
||||
if isinstance(e, BaseExceptionGroup):
|
||||
for exc in e.exceptions:
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(exc) from exc
|
||||
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(e) from e
|
||||
|
||||
raise
|
||||
|
||||
|
||||
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
||||
tools = []
|
||||
async with sse_client_wrapper(endpoint, headers) as session:
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=param_name,
|
||||
parameter_type=param_schema.get("type", "string"),
|
||||
description=param_schema.get("description", ""),
|
||||
)
|
||||
)
|
||||
tools.append(
|
||||
ToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": endpoint,
|
||||
},
|
||||
)
|
||||
)
|
||||
return ListToolDefsResponse(data=tools)
|
||||
|
||||
|
||||
async def invoke_mcp_tool(
|
||||
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
async with sse_client_wrapper(endpoint, headers) as session:
|
||||
result = await session.call_tool(tool_name, kwargs)
|
||||
|
||||
content: list[InterleavedContentItem] = []
|
||||
for item in result.content:
|
||||
if isinstance(item, mcp_types.TextContent):
|
||||
content.append(TextContentItem(text=item.text))
|
||||
elif isinstance(item, mcp_types.ImageContent):
|
||||
content.append(ImageContentItem(image=item.data))
|
||||
elif isinstance(item, mcp_types.EmbeddedResource):
|
||||
logger.warning(f"EmbeddedResource is not supported: {item}")
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {type(item)}")
|
||||
return ToolInvocationResult(
|
||||
content=content,
|
||||
error_code=1 if result.isError else 0,
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue