Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-12 04:00:42 +00:00
Merge origin/main into add-missing-provider-data-impls
Resolved conflicts in:
- benchmarking/k8s-benchmark/stack_run_config.yaml (accepted new storage schema)
- llama_stack/providers/remote/inference/cerebras/cerebras.py (kept provider data support)
- llama_stack/providers/remote/inference/cerebras/config.py (kept provider data support)
- llama_stack/providers/remote/inference/nvidia/config.py (kept provider data support)
- llama_stack/providers/remote/inference/runpod/config.py (merged imports)
- pyproject.toml (kept databricks-sdk dependency)
This commit is contained in: commit 9eb9a37ee4
1880 changed files with 804868 additions and 70533 deletions
@@ -17,7 +17,7 @@ from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolGroup
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.schema_utils import json_schema_type
@@ -68,10 +68,10 @@ class ShieldsProtocolPrivate(Protocol):
async def unregister_shield(self, identifier: str) -> None: ...

class VectorDBsProtocolPrivate(Protocol):
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
class VectorStoresProtocolPrivate(Protocol):
async def register_vector_store(self, vector_store: VectorStore) -> None: ...

async def unregister_vector_db(self, vector_db_id: str) -> None: ...
async def unregister_vector_store(self, vector_store_id: str) -> None: ...

class DatasetsProtocolPrivate(Protocol):
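The hunks above rename the private vector-DB registration protocol to its vector-store equivalent. As a hedged illustration only (the import path and the `identifier` attribute are assumptions, not shown in this diff), an inline provider conforming to the renamed protocol might look like:

# Hypothetical sketch, not part of this commit.
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate  # module path assumed


class InMemoryVectorStoreProvider(VectorStoresProtocolPrivate):
    """Toy provider that tracks registered vector stores in a dict."""

    def __init__(self) -> None:
        self._stores: dict[str, VectorStore] = {}

    async def register_vector_store(self, vector_store: VectorStore) -> None:
        # keyed by identifier (attribute name assumed); real providers would also provision storage
        self._stores[vector_store.identifier] = vector_store

    async def unregister_vector_store(self, vector_store_id: str) -> None:
        self._stores.pop(vector_store_id, None)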
@@ -11,7 +11,12 @@ from llama_stack.core.datatypes import AccessRule, Api
from .config import MetaReferenceAgentsImplConfig

async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
async def get_provider_impl(
config: MetaReferenceAgentsImplConfig,
deps: dict[Api, Any],
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
from .agents import MetaReferenceAgentsImpl

impl = MetaReferenceAgentsImpl(
@@ -21,8 +26,9 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
deps[Api.safety],
deps[Api.tool_runtime],
deps[Api.tool_groups],
deps[Api.conversations],
policy,
Api.telemetry in deps,
telemetry_enabled,
)
await impl.initialize()
return impl
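With the widened entry point, the `Api.telemetry in deps` check moves out of the constructor call and into whoever invokes `get_provider_impl`. A minimal, hedged sketch of a caller (the surrounding `config`, `deps`, and `policy` objects are placeholders):

# Hypothetical caller, assuming the new signature shown above.
impl = await get_provider_impl(
    config=config,
    deps=deps,
    policy=policy,
    telemetry_enabled=Api.telemetry in deps,  # previously computed inside the constructor call
)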
@@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
Inference,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
@@ -66,6 +67,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,
@@ -77,7 +79,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
convert_tooldef_to_openai_tool,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.telemetry import tracing

from .persistence import AgentPersistence
from .safety import SafetyException, ShieldRunnerMixin
@@ -582,7 +583,7 @@ class ChatAgent(ShieldRunnerMixin):
max_tokens = getattr(sampling_params, "max_tokens", None)

# Use OpenAI chat completion
openai_stream = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.agent_config.model,
messages=openai_messages,
tools=openai_tools if openai_tools else None,
@@ -593,6 +594,7 @@ class ChatAgent(ShieldRunnerMixin):
max_tokens=max_tokens,
stream=True,
)
openai_stream = await self.inference_api.openai_chat_completion(params)

# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
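Read together, the two ChatAgent hunks switch the inference call from keyword arguments to a single request object. A consolidated before/after view of the same change, reconstructed from the hunks above for readability (not new behavior):

# Before: arguments passed directly to openai_chat_completion.
openai_stream = await self.inference_api.openai_chat_completion(
    model=self.agent_config.model,
    messages=openai_messages,
    tools=openai_tools if openai_tools else None,
    max_tokens=max_tokens,
    stream=True,
)

# After: the same fields travel in an OpenAIChatCompletionRequestWithExtraBody.
params = OpenAIChatCompletionRequestWithExtraBody(
    model=self.agent_config.model,
    messages=openai_messages,
    tools=openai_tools if openai_tools else None,
    max_tokens=max_tokens,
    stream=True,
)
openai_stream = await self.inference_api.openai_chat_completion(params)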
@@ -28,8 +28,10 @@ from llama_stack.apis.agents import (
Session,
Turn,
)
from llama_stack.apis.agents.agents import ResponseGuardrail
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import (
Inference,
ToolConfig,
@@ -63,6 +65,7 @@ class MetaReferenceAgentsImpl(Agents):
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
conversations_api: Conversations,
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
@@ -72,6 +75,7 @@ class MetaReferenceAgentsImpl(Agents):
self.safety_api = safety_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.conversations_api = conversations_api
self.telemetry_enabled = telemetry_enabled

self.in_memory_store = InmemoryKVStoreImpl()
@@ -79,8 +83,8 @@ class MetaReferenceAgentsImpl(Agents):
self.policy = policy

async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
self.responses_store = ResponsesStore(self.config.responses_store, self.policy)
self.persistence_store = await kvstore_impl(self.config.persistence.agent_state)
self.responses_store = ResponsesStore(self.config.persistence.responses, self.policy)
await self.responses_store.initialize()
self.openai_responses_impl = OpenAIResponsesImpl(
inference_api=self.inference_api,
@@ -88,6 +92,8 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api=self.tool_runtime_api,
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
safety_api=self.safety_api,
conversations_api=self.conversations_api,
)

async def create_agent(
@@ -325,6 +331,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@@ -332,13 +339,14 @@ class MetaReferenceAgentsImpl(Agents):
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
shields: list | None = None,
guardrails: list[ResponseGuardrail] | None = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input,
model,
instructions,
previous_response_id,
conversation,
store,
stream,
temperature,
@@ -346,7 +354,7 @@ class MetaReferenceAgentsImpl(Agents):
tools,
include,
max_infer_iters,
shields,
guardrails,
)

async def list_openai_responses(
@@ -8,24 +8,30 @@ from typing import Any

from pydantic import BaseModel

from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference

class AgentPersistenceConfig(BaseModel):
"""Nested persistence configuration for agents."""

agent_state: KVStoreReference
responses: ResponsesStoreReference

class MetaReferenceAgentsImplConfig(BaseModel):
persistence_store: KVStoreConfig
responses_store: SqlStoreConfig
persistence: AgentPersistenceConfig

@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"persistence_store": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="agents_store.db",
),
"responses_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="responses_store.db",
),
"persistence": {
"agent_state": KVStoreReference(
backend="kv_default",
namespace="agents",
).model_dump(exclude_none=True),
"responses": ResponsesStoreReference(
backend="sql_default",
table_name="responses",
).model_dump(exclude_none=True),
}
}
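The agents provider config moves from the flat `persistence_store`/`responses_store` fields to a nested `persistence` block of backend references. A minimal sketch of constructing the new shape in code, assuming the values mirrored from `sample_run_config` above (backend names and namespace/table values vary per distribution):

# Hedged sketch; values mirror sample_run_config and are not mandatory.
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference

config = MetaReferenceAgentsImplConfig(
    persistence=AgentPersistenceConfig(
        agent_state=KVStoreReference(backend="kv_default", namespace="agents"),
        responses=ResponsesStoreReference(backend="sql_default", table_name="responses"),
    )
)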
@@ -11,6 +11,7 @@ from collections.abc import AsyncIterator
from pydantic import BaseModel, TypeAdapter

from llama_stack.apis.agents import Order
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
@@ -24,11 +25,17 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.common.errors import (
InvalidConversationIdError,
)
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.conversations.conversations import ConversationItem
from llama_stack.apis.inference import (
Inference,
OpenAIMessageParam,
OpenAISystemMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
@@ -39,10 +46,11 @@ from llama_stack.providers.utils.responses.responses_store import (

from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext
from .types import ChatCompletionContext, ToolContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
extract_guardrail_ids,
)

logger = get_logger(name=__name__, category="openai_responses")
@@ -61,12 +69,16 @@ class OpenAIResponsesImpl:
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO,  # VectorIO
safety_api: Safety,
conversations_api: Conversations,
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.safety_api = safety_api
self.conversations_api = conversations_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
@@ -91,13 +103,16 @@ class OpenAIResponsesImpl:
async def _process_input_with_previous_response(
self,
input: str | list[OpenAIResponseInput],
tools: list[OpenAIResponseInputTool] | None,
previous_response_id: str | None,
conversation: str | None,
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
"""Process input with optional previous response context.

Returns:
tuple: (all_input for storage, messages for chat completion)
tuple: (all_input for storage, messages for chat completion, tool context)
"""
tool_context = ToolContext(tools)
if previous_response_id:
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
await self.responses_store.get_response_object(previous_response_id)
@@ -108,20 +123,45 @@ class OpenAIResponsesImpl:
# Use stored messages directly and convert only new input
message_adapter = TypeAdapter(list[OpenAIMessageParam])
messages = message_adapter.validate_python(previous_response.messages)
new_messages = await convert_response_input_to_chat_messages(input)
new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
messages.extend(new_messages)
else:
# Backward compatibility: reconstruct from inputs
messages = await convert_response_input_to_chat_messages(all_input)

tool_context.recover_tools_from_previous_response(previous_response)
elif conversation is not None:
conversation_items = await self.conversations_api.list_items(conversation, order="asc")

# Use stored messages as source of truth (like previous_response.messages)
stored_messages = await self.responses_store.get_conversation_messages(conversation)

all_input = input
if not conversation_items.data:
# First turn - just convert the new input
messages = await convert_response_input_to_chat_messages(input)
else:
if not stored_messages:
all_input = conversation_items.data
if isinstance(input, str):
all_input.append(
OpenAIResponseMessage(
role="user", content=[OpenAIResponseInputMessageContentText(text=input)]
)
)
else:
all_input.extend(input)
else:
all_input = input

messages = stored_messages or []
new_messages = await convert_response_input_to_chat_messages(all_input, previous_messages=messages)
messages.extend(new_messages)
else:
all_input = input
messages = await convert_response_input_to_chat_messages(input)
messages = await convert_response_input_to_chat_messages(all_input)

return all_input, messages

async def _prepend_instructions(self, messages, instructions):
if instructions:
messages.insert(0, OpenAISystemMessageParam(content=instructions))
return all_input, messages, tool_context

async def get_openai_response(
self,
@@ -201,6 +241,7 @@ class OpenAIResponsesImpl:
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@@ -208,17 +249,25 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
shields: list | None = None,
guardrails: list[ResponseGuardrailSpec] | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text

# Shields parameter received via extra_body - not yet implemented
if shields is not None:
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []

if conversation is not None:
if previous_response_id is not None:
raise ValueError(
"Mutually exclusive parameters: 'previous_response_id' and 'conversation'. Ensure you are only providing one of these parameters."
)

if not conversation.startswith("conv_"):
raise InvalidConversationIdError(conversation)

stream_gen = self._create_streaming_response(
input=input,
conversation=conversation,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
@@ -227,22 +276,39 @@ class OpenAIResponsesImpl:
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
)

if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
final_response = None
final_event_type = None
failed_response = None

if response is None:
raise ValueError("The response stream never completed")
return response
async for stream_chunk in stream_gen:
if stream_chunk.type in {"response.completed", "response.incomplete"}:
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
f"Earlier response from {final_event_type}"
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response

if failed_response is not None:
error_message = (
failed_response.error.message
if failed_response and failed_response.error
else "Response stream failed without error details"
)
raise RuntimeError(f"OpenAI response failed: {error_message}")

if final_response is None:
raise ValueError("The response stream never reached a terminal state")
return final_response

async def _create_streaming_response(
self,
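The non-streaming path above now drains the generator until a terminal event ("response.completed" or "response.incomplete"), raising if the stream fails or never terminates, and `conversation` is mutually exclusive with `previous_response_id`. A hedged usage sketch (model name and conversation id are placeholders):

# Hypothetical call against the implementation above; stream=False returns the terminal response.
response = await impl.create_openai_response(
    input="Summarize the new storage schema.",
    model="meta-llama/Llama-3.2-3B-Instruct",  # placeholder model id
    conversation="conv_123",  # must start with "conv_"; do not combine with previous_response_id
    store=True,
    stream=False,
)
print(response.status)  # "completed" or "incomplete"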
@@ -250,15 +316,21 @@ class OpenAIResponsesImpl:
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
await self._prepend_instructions(messages, instructions)
all_input, messages, tool_context = await self._process_input_with_previous_response(
input, tools, previous_response_id, conversation
)

if instructions:
messages.insert(0, OpenAISystemMessageParam(content=instructions))

# Structured outputs
response_format = await convert_response_text_to_chat_response_format(text)
@@ -269,11 +341,12 @@ class OpenAIResponsesImpl:
response_tools=tools,
temperature=temperature,
response_format=response_format,
inputs=input,
tool_context=tool_context,
inputs=all_input,
)

# Create orchestrator and delegate streaming logic
response_id = f"resp-{uuid.uuid4()}"
response_id = f"resp_{uuid.uuid4()}"
created_at = int(time.time())

orchestrator = StreamingResponseOrchestrator(
@@ -284,22 +357,68 @@ class OpenAIResponsesImpl:
text=text,
max_infer_iters=max_infer_iters,
tool_executor=self.tool_executor,
safety_api=self.safety_api,
guardrail_ids=guardrail_ids,
instructions=instructions,
)

# Stream the response
final_response = None
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type == "response.completed":
final_response = stream_chunk.response
yield stream_chunk
failed_response = None

# Store the response if requested
if store and final_response:
await self._store_response(
response=final_response,
input=all_input,
messages=orchestrator.final_messages,
)
output_items = []
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type in {"response.completed", "response.incomplete"}:
final_response = stream_chunk.response
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response

if stream_chunk.type == "response.output_item.done":
item = stream_chunk.item
output_items.append(item)

# Store and sync before yielding terminal events
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event
if (
stream_chunk.type in {"response.completed", "response.incomplete"}
and final_response
and failed_response is None
):
messages_to_store = list(
filter(lambda x: not isinstance(x, OpenAISystemMessageParam), orchestrator.final_messages)
)
if store:
# TODO: we really should work off of output_items instead of "final_messages"
await self._store_response(
response=final_response,
input=all_input,
messages=messages_to_store,
)

if conversation:
await self._sync_response_to_conversation(conversation, input, output_items)
await self.responses_store.store_conversation_messages(conversation, messages_to_store)

yield stream_chunk

async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
return await self.responses_store.delete_response_object(response_id)

async def _sync_response_to_conversation(
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
) -> None:
"""Sync content and response messages to the conversation."""
conversation_items = []

if isinstance(input, str):
conversation_items.append(
OpenAIResponseMessage(role="user", content=[OpenAIResponseInputMessageContentText(text=input)])
)
elif isinstance(input, list):
conversation_items.extend(input)

conversation_items.extend(output_items)

adapter = TypeAdapter(list[ConversationItem])
validated_items = adapter.validate_python(conversation_items)
await self.conversations_api.add_items(conversation_id, validated_items)
@@ -13,17 +13,24 @@ from llama_stack.apis.agents.openai_responses import (
ApprovalFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseContentPartReasoningText,
OpenAIResponseContentPartRefusal,
OpenAIResponseError,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseContentPartAdded,
OpenAIResponseObjectStreamResponseContentPartDone,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseFailed,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
OpenAIResponseObjectStreamResponseIncomplete,
OpenAIResponseObjectStreamResponseInProgress,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
@@ -31,24 +38,43 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseOutputItemAdded,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseObjectStreamResponseReasoningTextDelta,
OpenAIResponseObjectStreamResponseReasoningTextDone,
OpenAIResponseObjectStreamResponseRefusalDelta,
OpenAIResponseObjectStreamResponseRefusalDone,
OpenAIResponseOutput,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseUsage,
OpenAIResponseUsageInputTokensDetails,
OpenAIResponseUsageOutputTokensDetails,
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
)
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str

from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
from .utils import (
convert_chat_choice_to_response_message,
is_function_tool_call,
run_guardrails,
)

logger = get_logger(name=__name__, category="agents::meta_reference")
@@ -84,6 +110,9 @@ class StreamingResponseOrchestrator:
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor,  # Will be the tool execution logic from the main class
instructions: str,
safety_api,
guardrail_ids: list[str] | None = None,
):
self.inference_api = inference_api
self.ctx = ctx
@@ -92,118 +121,215 @@ class StreamingResponseOrchestrator:
self.text = text
self.max_infer_iters = max_infer_iters
self.tool_executor = tool_executor
self.safety_api = safety_api
self.guardrail_ids = guardrail_ids or []
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
self.citation_files: dict[str, str] = {}
# Track accumulated usage across all inference calls
self.accumulated_usage: OpenAIResponseUsage | None = None
# Track if we've sent a refusal response
self.violation_detected = False
# system message that is inserted into the model's context
self.instructions = instructions

async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages
output_messages: list[OpenAIResponseOutput] = []
# Create initial response and emit response.created immediately
initial_response = OpenAIResponseObject(
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)

# Create a completed refusal response
refusal_response = OpenAIResponseObject(
id=self.response_id,
created_at=self.created_at,
model=self.ctx.model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)

return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)

def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
cloned: list[OpenAIResponseOutput] = []
for item in outputs:
if hasattr(item, "model_copy"):
cloned.append(item.model_copy(deep=True))
else:
cloned.append(item)
return cloned

def _snapshot_response(
self,
status: str,
outputs: list[OpenAIResponseOutput],
*,
error: OpenAIResponseError | None = None,
) -> OpenAIResponseObject:
return OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="in_progress",
output=output_messages.copy(),
status=status,
output=self._clone_outputs(outputs),
text=self.text,
tools=self.ctx.available_tools(),
error=error,
usage=self.accumulated_usage,
instructions=self.instructions,
)

yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
output_messages: list[OpenAIResponseOutput] = []

# Process all tools (including MCP tools) and emit streaming events
if self.ctx.response_tools:
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
yield stream_event
# Emit response.created followed by response.in_progress to align with OpenAI streaming
yield OpenAIResponseObjectStreamResponseCreated(
response=self._snapshot_response("in_progress", output_messages)
)

self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseInProgress(
response=self._snapshot_response("in_progress", output_messages),
sequence_number=self.sequence_number,
)

# Input safety validation - check messages before processing
if self.guardrail_ids:
combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
input_violation_message = await run_guardrails(self.safety_api, combined_text, self.guardrail_ids)
if input_violation_message:
logger.info(f"Input guardrail violation: {input_violation_message}")
yield await self._create_refusal_response(input_violation_message)
return

async for stream_event in self._process_tools(output_messages):
yield stream_event

n_iter = 0
messages = self.ctx.messages.copy()
final_status = "completed"
last_completion_result: ChatCompletionResult | None = None

while True:
# Text is the default response format for chat completion so don't need to pass it
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
completion_result = await self.inference_api.openai_chat_completion(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
)
try:
while True:
# Text is the default response format for chat completion so don't need to pass it
# (some providers don't support non-empty response_format when tools are present)
response_format = (
None if getattr(self.ctx.response_format, "type", None) == "text" else self.ctx.response_format
)
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")

# Process streaming chunks and build complete response
completion_result_data = None
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
if isinstance(stream_event_or_result, ChatCompletionResult):
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
current_response = self._build_chat_completion(completion_result_data)
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
stream_options={
"include_usage": True,
},
)
completion_result = await self.inference_api.openai_chat_completion(params)

function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
current_response, messages
)
# Process streaming chunks and build complete response
completion_result_data = None
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
if isinstance(stream_event_or_result, ChatCompletionResult):
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result

# add any approval requests required
for tool_call in approvals:
async for evt in self._add_mcp_approval_request(
tool_call.function.name, tool_call.function.arguments, output_messages
# If violation detected, skip the rest of processing since we already sent refusal
if self.violation_detected:
return

if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
last_completion_result = completion_result_data
current_response = self._build_chat_completion(completion_result_data)

(
function_tool_calls,
non_function_tool_calls,
approvals,
next_turn_messages,
) = self._separate_tool_calls(current_response, messages)

# add any approval requests required
for tool_call in approvals:
async for evt in self._add_mcp_approval_request(
tool_call.function.name, tool_call.function.arguments, output_messages
):
yield evt

# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(
await convert_chat_choice_to_response_message(
choice,
self.citation_files,
message_id=completion_result_data.message_item_id,
)
)

# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
function_tool_calls,
non_function_tool_calls,
completion_result_data,
output_messages,
next_turn_messages,
):
yield evt
yield stream_event

# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))
messages = next_turn_messages

# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
function_tool_calls,
non_function_tool_calls,
completion_result_data,
output_messages,
next_turn_messages,
):
yield stream_event
if not function_tool_calls and not non_function_tool_calls:
break

if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break

if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= self.max_infer_iters:
logger.info(
f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"
)
final_status = "incomplete"
break

n_iter += 1
if n_iter >= self.max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
break
if last_completion_result and last_completion_result.finish_reason == "length":
final_status = "incomplete"

messages = next_turn_messages
except Exception as exc:  # noqa: BLE001
self.final_messages = messages.copy()
self.sequence_number += 1
error = OpenAIResponseError(code="internal_error", message=str(exc))
failure_response = self._snapshot_response("failed", output_messages, error=error)
yield OpenAIResponseObjectStreamResponseFailed(
response=failure_response,
sequence_number=self.sequence_number,
)
return

self.final_messages = messages.copy() + [current_response.choices[0].message]
self.final_messages = messages.copy()

# Create final response
final_response = OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="completed",
text=self.text,
output=output_messages,
)

# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
if final_status == "incomplete":
self.sequence_number += 1
final_response = self._snapshot_response("incomplete", output_messages)
yield OpenAIResponseObjectStreamResponseIncomplete(
response=final_response,
sequence_number=self.sequence_number,
)
else:
final_response = self._snapshot_response("completed", output_messages)
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)

def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
"""Separate tool calls into function and non-function categories."""
@@ -232,14 +358,183 @@ class StreamingResponseOrchestrator:
non_function_tool_calls.append(tool_call)
else:
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
next_turn_messages.pop()
else:
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
approvals.append(tool_call)
next_turn_messages.pop()
else:
non_function_tool_calls.append(tool_call)

return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages

def _accumulate_chunk_usage(self, chunk: OpenAIChatCompletionChunk) -> None:
"""Accumulate usage from a streaming chunk into the response usage format."""
if not chunk.usage:
return

if self.accumulated_usage is None:
# Convert from chat completion format to response format
self.accumulated_usage = OpenAIResponseUsage(
input_tokens=chunk.usage.prompt_tokens,
output_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
input_tokens_details=(
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
if chunk.usage.prompt_tokens_details
else None
),
output_tokens_details=(
OpenAIResponseUsageOutputTokensDetails(
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
)
if chunk.usage.completion_tokens_details
else None
),
)
else:
# Accumulate across multiple inference calls
self.accumulated_usage = OpenAIResponseUsage(
input_tokens=self.accumulated_usage.input_tokens + chunk.usage.prompt_tokens,
output_tokens=self.accumulated_usage.output_tokens + chunk.usage.completion_tokens,
total_tokens=self.accumulated_usage.total_tokens + chunk.usage.total_tokens,
# Use latest non-null details
input_tokens_details=(
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
if chunk.usage.prompt_tokens_details
else self.accumulated_usage.input_tokens_details
),
output_tokens_details=(
OpenAIResponseUsageOutputTokensDetails(
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
)
if chunk.usage.completion_tokens_details
else self.accumulated_usage.output_tokens_details
),
)

async def _handle_reasoning_content_chunk(
self,
reasoning_content: str,
reasoning_part_emitted: bool,
reasoning_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Emit content_part.added event for first reasoning chunk
if not reasoning_part_emitted:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=reasoning_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartReasoningText(
text="",  # Will be filled incrementally via reasoning deltas
),
sequence_number=self.sequence_number,
)
# Emit reasoning_text.delta event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseReasoningTextDelta(
content_index=reasoning_content_index,
delta=reasoning_content,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)

async def _handle_refusal_content_chunk(
self,
refusal_content: str,
refusal_part_emitted: bool,
refusal_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Emit content_part.added event for first refusal chunk
if not refusal_part_emitted:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=refusal_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartRefusal(
refusal="",  # Will be filled incrementally via refusal deltas
),
sequence_number=self.sequence_number,
)
# Emit refusal.delta event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseRefusalDelta(
content_index=refusal_content_index,
delta=refusal_content,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)

async def _emit_reasoning_done_events(
self,
reasoning_text_accumulated: list[str],
reasoning_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
final_reasoning_text = "".join(reasoning_text_accumulated)
# Emit reasoning_text.done event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseReasoningTextDone(
content_index=reasoning_content_index,
text=final_reasoning_text,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
# Emit content_part.done for reasoning
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=reasoning_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartReasoningText(
text=final_reasoning_text,
),
sequence_number=self.sequence_number,
)

async def _emit_refusal_done_events(
self,
refusal_text_accumulated: list[str],
refusal_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
final_refusal_text = "".join(refusal_text_accumulated)
# Emit refusal.done event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseRefusalDone(
content_index=refusal_content_index,
refusal=final_refusal_text,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
# Emit content_part.done for refusal
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=refusal_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartRefusal(
refusal=final_refusal_text,
),
sequence_number=self.sequence_number,
)

async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
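_accumulate_chunk_usage folds the per-chunk usage (reported when stream_options includes "include_usage") into one response-level total across every inference call in the loop. A toy illustration of the arithmetic, with invented numbers:

# Invented numbers: a first turn and a later tool-use turn; response usage is the element-wise sum.
first = {"input_tokens": 120, "output_tokens": 30, "total_tokens": 150}
second = {"input_tokens": 250, "output_tokens": 42, "total_tokens": 292}
accumulated = {k: first[k] + second[k] for k in first}
assert accumulated == {"input_tokens": 370, "output_tokens": 72, "total_tokens": 442}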
@@ -257,41 +552,112 @@ class StreamingResponseOrchestrator:
# Track tool call items for streaming events
tool_call_item_ids: dict[int, str] = {}
# Track content parts for streaming events
message_item_added_emitted = False
content_part_emitted = False
reasoning_part_emitted = False
refusal_part_emitted = False
content_index = 0
reasoning_content_index = 1  # reasoning is a separate content part
refusal_content_index = 2  # refusal is a separate content part
message_output_index = len(output_messages)
reasoning_text_accumulated = []
refusal_text_accumulated = []

async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model

# Accumulate usage from chunks (typically in final chunk with stream_options)
self._accumulate_chunk_usage(chunk)

# Track deltas for this specific chunk for guardrail validation
chunk_events: list[OpenAIResponseObjectStream] = []

for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
# Emit output_item.added for the message on first content
if not message_item_added_emitted:
message_item_added_emitted = True
self.sequence_number += 1
message_item = OpenAIResponseMessage(
id=message_item_id,
content=[],
role="assistant",
status="in_progress",
)
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=message_item,
output_index=message_output_index,
sequence_number=self.sequence_number,
)

# Emit content_part.added event for first text chunk
if not content_part_emitted:
content_part_emitted = True
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartOutputText(
text="",  # Will be filled incrementally via text deltas
),
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,

text_delta_event = OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=content_index,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
# Buffer text delta events for guardrail check
if self.guardrail_ids:
chunk_events.append(text_delta_event)
else:
yield text_delta_event

# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason

# Handle reasoning content if present (non-standard field for o1/o3 models)
if hasattr(chunk_choice.delta, "reasoning_content") and chunk_choice.delta.reasoning_content:
async for event in self._handle_reasoning_content_chunk(
reasoning_content=chunk_choice.delta.reasoning_content,
reasoning_part_emitted=reasoning_part_emitted,
reasoning_content_index=reasoning_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
# Buffer reasoning events for guardrail check
if self.guardrail_ids:
chunk_events.append(event)
else:
yield event
reasoning_part_emitted = True
reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)

# Handle refusal content if present
if chunk_choice.delta.refusal:
async for event in self._handle_refusal_content_chunk(
refusal_content=chunk_choice.delta.refusal,
refusal_part_emitted=refusal_part_emitted,
refusal_content_index=refusal_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
refusal_part_emitted = True
refusal_text_accumulated.append(chunk_choice.delta.refusal)

# Aggregate tool call arguments across chunks
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
@@ -310,19 +676,22 @@ class StreamingResponseOrchestrator:

# Emit output_item.added event for the new function call
self.sequence_number += 1
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments="",  # Will be filled incrementally via delta events
call_id=tool_call.id or "",
name=tool_call.function.name if tool_call.function else "",
id=tool_call_item_id,
status="in_progress",
)
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=function_call_item,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
if not is_mcp_tool and tool_call.function.name not in ["web_search", "knowledge_search"]:
# for MCP tools (and even other non-function tools) we emit an output message item later
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments="",  # Will be filled incrementally via delta events
call_id=tool_call.id or "",
name=tool_call.function.name if tool_call.function else "",
id=tool_call_item_id,
status="in_progress",
)
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=function_call_item,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)

# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
if tool_call.function and tool_call.function.arguments:
@@ -354,6 +723,22 @@ class StreamingResponseOrchestrator:
response_tool_call.function.arguments or ""
) + tool_call.function.arguments

# Output Safety Validation for this chunk
if self.guardrail_ids:
# Check guardrails on accumulated text so far
accumulated_text = "".join(chat_response_content)
violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
if violation_message:
logger.info(f"Output guardrail violation: {violation_message}")
chunk_events.clear()
yield await self._create_refusal_response(violation_message)
self.violation_detected = True
return
else:
# No violation detected, emit all content events for this chunk
for event in chunk_events:
yield event

# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call = chat_response_tool_calls[tool_call_index]
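When guardrails are configured, the chunk processor buffers this chunk's delta events, re-checks the accumulated text, and only then either clears the buffer and emits a single refusal or releases the buffered events. A generic, hedged sketch of that buffer-then-flush pattern (names here are illustrative, not the module's API):

# Illustrative pattern only; `check` stands in for run_guardrails and returns a violation message or None.
async def emit_guarded(chunks, check, emit, refuse):
    accumulated = ""
    for chunk in chunks:
        buffered = [("text_delta", chunk)]  # events produced while handling this chunk
        accumulated += chunk
        violation = await check(accumulated)  # validate everything streamed so far
        if violation:
            buffered.clear()  # drop the unreleased deltas
            await refuse(violation)  # emit a single refusal instead
            return
        for event in buffered:  # no violation: release the buffered events
            await emit(event)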
@@ -383,18 +768,66 @@ class StreamingResponseOrchestrator:
final_text = "".join(chat_response_content)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartOutputText(
text=final_text,
),
sequence_number=self.sequence_number,
)

# Emit reasoning done events if reasoning content was streamed
if reasoning_part_emitted:
async for event in self._emit_reasoning_done_events(
reasoning_text_accumulated=reasoning_text_accumulated,
reasoning_content_index=reasoning_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event

# Emit refusal done events if refusal content was streamed
if refusal_part_emitted:
async for event in self._emit_refusal_done_events(
refusal_text_accumulated=refusal_text_accumulated,
refusal_content_index=refusal_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event

# Clear content when there are tool calls (OpenAI spec behavior)
if chat_response_tool_calls:
chat_response_content = []

# Emit output_item.done for message when we have content and no tool calls
if message_item_added_emitted and not chat_response_tool_calls:
content_parts = []
if content_part_emitted:
final_text = "".join(chat_response_content)
content_parts.append(
OpenAIResponseOutputMessageContentOutputText(
text=final_text,
annotations=[],
)
)

self.sequence_number += 1
message_item = OpenAIResponseMessage(
id=message_item_id,
content=content_parts,
role="assistant",
status="completed",
)
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=message_item,
output_index=message_output_index,
sequence_number=self.sequence_number,
)

yield ChatCompletionResult(
response_id=chat_response_id,
content=chat_response_content,
@@ -455,6 +888,36 @@ class StreamingResponseOrchestrator:
if not matching_item_id:
matching_item_id = f"tc_{uuid.uuid4()}"

self.sequence_number += 1
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
item = OpenAIResponseOutputMessageMCPCall(
arguments="",
name=tool_call.function.name,
id=matching_item_id,
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
status="in_progress",
)
elif tool_call.function.name == "web_search":
item = OpenAIResponseOutputMessageWebSearchToolCall(
id=matching_item_id,
status="in_progress",
)
elif tool_call.function.name == "knowledge_search":
item = OpenAIResponseOutputMessageFileSearchToolCall(
id=matching_item_id,
status="in_progress",
queries=[tool_call.function.arguments or ""],
)
else:
raise ValueError(f"Unsupported tool call: {tool_call.function.name}")

yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=item,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)

# Execute tool call with streaming
tool_call_log = None
tool_response_message = None
@ -525,7 +988,7 @@ class StreamingResponseOrchestrator:
|
|||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _process_tools(
|
||||
async def _process_new_tools(
|
||||
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Process all tools and emit appropriate streaming events."""
|
||||
|
|
@ -580,7 +1043,6 @@ class StreamingResponseOrchestrator:
|
|||
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse allowed/never allowed tools
|
||||
always_allowed = None
|
||||
|
|
@ -593,14 +1055,22 @@ class StreamingResponseOrchestrator:
|
|||
never_allowed = mcp_tool.allowed_tools.never
|
||||
|
||||
# Call list_mcp_tools
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
)
|
||||
tool_defs = None
|
||||
list_id = f"mcp_list_{uuid.uuid4()}"
|
||||
attributes = {
|
||||
"server_label": mcp_tool.server_label,
|
||||
"server_url": mcp_tool.server_url,
|
||||
"mcp_list_tools_id": list_id,
|
||||
}
|
||||
async with tracing.span("list_mcp_tools", attributes):
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
)
|
||||
|
||||
# Create the MCP list tools message
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
id=list_id,
|
||||
server_label=mcp_tool.server_label,
|
||||
tools=[],
|
||||
)
|
||||
|
|
@ -634,39 +1104,26 @@ class StreamingResponseOrchestrator:
|
|||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add the MCP list message to output
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
# Emit output_item.added for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit mcp_list_tools.completed
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit output_item.done for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
|
||||
yield stream_event
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Emit mcp_list_tools.failed event if needed
|
||||
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
|
||||
raise
|
||||
|
||||
async def _process_tools(
|
||||
self, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Handle all mcp tool lists from previous response that are still valid:
|
||||
for tool in self.ctx.tool_context.previous_tool_listings:
|
||||
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
|
||||
yield evt
|
||||
# Process all remaining tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.tool_context.tools_to_process:
|
||||
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
|
||||
yield stream_event
|
||||
|
||||
def _approval_required(self, tool_name: str) -> bool:
|
||||
if tool_name not in self.mcp_tool_to_server:
|
||||
return False
|
||||
|
|
@ -701,7 +1158,6 @@ class StreamingResponseOrchestrator:
|
|||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
|
|
@@ -709,3 +1165,64 @@ class StreamingResponseOrchestrator:
|
|||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _add_mcp_list_tools(
|
||||
self, mcp_list_message: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Add the MCP list message to output
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
# Emit output_item.added for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=OpenAIResponseOutputMessageMCPListTools(
|
||||
id=mcp_list_message.id,
|
||||
server_label=mcp_list_message.server_label,
|
||||
tools=[],
|
||||
),
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
# Emit mcp_list_tools.completed
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit output_item.done for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _reuse_mcp_list_tools(
|
||||
self, original: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
for t in original.tools:
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
# convert from input_schema to map of ToolParamDefinitions...
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=t.name,
|
||||
description=t.description,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
# ...then can convert that to openai completions tool
|
||||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
server_label=original.server_label,
|
||||
tools=original.tools,
|
||||
)
|
||||
|
||||
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
|
||||
yield stream_event
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ from collections.abc import AsyncIterator
|
|||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObjectStreamResponseFileSearchCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseFileSearchCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseFileSearchCallSearching,
|
||||
OpenAIResponseObjectStreamResponseMcpCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseMcpCallFailed,
|
||||
OpenAIResponseObjectStreamResponseMcpCallInProgress,
|
||||
|
|
@ -34,6 +37,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .types import ChatCompletionContext, ToolExecutionResult
|
||||
|
|
@ -89,7 +93,7 @@ class ToolExecutor:
|
|||
|
||||
# Build result messages from tool execution
|
||||
output_message, input_message = await self._build_result_messages(
|
||||
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
|
||||
function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
|
||||
)
|
||||
|
||||
# Yield the final result
|
||||
|
|
@ -220,7 +224,13 @@ class ToolExecutor:
|
|||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if progress_event:
|
||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
||||
|
|
@ -235,6 +245,16 @@ class ToolExecutor:
|
|||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
# For file search, emit searching event
|
||||
if function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
function_name: str,
|
||||
|
|
@ -251,12 +271,18 @@ class ToolExecutor:
|
|||
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
|
||||
|
||||
mcp_tool = mcp_tool_to_server[function_name]
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
attributes = {
|
||||
"server_label": mcp_tool.server_label,
|
||||
"server_url": mcp_tool.server_url,
|
||||
"tool_name": function_name,
|
||||
}
|
||||
async with tracing.span("invoke_mcp_tool", attributes):
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
response_file_search_tool = next(
|
||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||
|
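A minimal sketch of the tracing pattern this hunk applies around each tool invocation; the span name and the invoke_mcp_tool signature come from the diff, while the endpoint, headers, and argument values are placeholders:

# Sketch only: wrap an async tool invocation in a tracing span with searchable attributes.
attributes = {"server_label": "docs", "server_url": "https://mcp.example.com", "tool_name": "search"}
async with tracing.span("invoke_mcp_tool", attributes):
    result = await invoke_mcp_tool(
        endpoint="https://mcp.example.com",  # placeholder endpoint
        headers={},
        tool_name="search",
        kwargs={"query": "llama stack"},
    )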
|
@ -266,15 +292,20 @@ class ToolExecutor:
|
|||
# Use vector_stores.search API instead of knowledge_search tool
|
||||
# to support filters and ranking_options
|
||||
query = tool_kwargs.get("query", "")
|
||||
result = await self._execute_knowledge_search_via_vector_store(
|
||||
query=query,
|
||||
response_file_search_tool=response_file_search_tool,
|
||||
)
|
||||
async with tracing.span("knowledge_search", {}):
|
||||
result = await self._execute_knowledge_search_via_vector_store(
|
||||
query=query,
|
||||
response_file_search_tool=response_file_search_tool,
|
||||
)
|
||||
else:
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
attributes = {
|
||||
"tool_name": function_name,
|
||||
}
|
||||
async with tracing.span("invoke_tool", attributes):
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
error_exc = e
|
||||
|
||||
|
|
@ -310,7 +341,13 @@ class ToolExecutor:
|
|||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if completion_event:
|
||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
||||
|
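For reference, the knowledge_search path now emits the file-search streaming events in this order (class names as imported above); a sketch of the sequence only:

# Sketch: order of file-search events emitted for a single knowledge_search call.
file_search_event_order = [
    "OpenAIResponseObjectStreamResponseFileSearchCallInProgress",
    "OpenAIResponseObjectStreamResponseFileSearchCallSearching",
    "OpenAIResponseObjectStreamResponseFileSearchCallCompleted",
]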
|
@ -319,6 +356,7 @@ class ToolExecutor:
|
|||
self,
|
||||
function,
|
||||
tool_call_id: str,
|
||||
item_id: str,
|
||||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
error_exc: Exception | None,
|
||||
|
|
@ -338,7 +376,7 @@ class ToolExecutor:
|
|||
)
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=tool_call_id,
|
||||
id=item_id,
|
||||
arguments=function.arguments,
|
||||
name=function.name,
|
||||
server_label=mcp_tool_to_server[function.name].server_label,
|
||||
|
|
@ -352,14 +390,14 @@ class ToolExecutor:
|
|||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
id=item_id,
|
||||
status="completed",
|
||||
)
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
elif function.name == "knowledge_search":
|
||||
message = OpenAIResponseOutputMessageFileSearchToolCall(
|
||||
id=tool_call_id,
|
||||
id=item_id,
|
||||
queries=[tool_kwargs.get("query", "")],
|
||||
status="completed",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,10 +12,18 @@ from pydantic import BaseModel
|
|||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseTool,
|
||||
OpenAIResponseToolMCP,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
|
||||
|
||||
|
|
@@ -55,6 +63,86 @@ class ChatCompletionResult:
|
|||
return bool(self.tool_calls)
|
||||
|
||||
|
||||
class ToolContext(BaseModel):
|
||||
"""Holds information about tools from this and (if relevant)
|
||||
previous response in order to facilitate reuse of previous
|
||||
listings where appropriate."""
|
||||
|
||||
# tools argument passed into current request:
|
||||
current_tools: list[OpenAIResponseInputTool]
|
||||
# reconstructed map of tool -> mcp server from previous response:
|
||||
previous_tools: dict[str, OpenAIResponseInputToolMCP]
|
||||
# reusable mcp-list-tools objects from previous response:
|
||||
previous_tool_listings: list[OpenAIResponseOutputMessageMCPListTools]
|
||||
# tool arguments from current request that still need to be processed:
|
||||
tools_to_process: list[OpenAIResponseInputTool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
current_tools: list[OpenAIResponseInputTool] | None,
|
||||
):
|
||||
super().__init__(
|
||||
current_tools=current_tools or [],
|
||||
previous_tools={},
|
||||
previous_tool_listings=[],
|
||||
tools_to_process=current_tools or [],
|
||||
)
|
||||
|
||||
def recover_tools_from_previous_response(
|
||||
self,
|
||||
previous_response: OpenAIResponseObject,
|
||||
):
|
||||
"""Determine which mcp_list_tools objects from previous response we can reuse."""
|
||||
|
||||
if self.current_tools and previous_response.tools:
|
||||
previous_tools_by_label: dict[str, OpenAIResponseToolMCP] = {}
|
||||
for tool in previous_response.tools:
|
||||
if isinstance(tool, OpenAIResponseToolMCP):
|
||||
previous_tools_by_label[tool.server_label] = tool
|
||||
# collect tool definitions which are the same in current and previous requests:
|
||||
tools_to_process = []
|
||||
matched: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
for tool in self.current_tools:
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
|
||||
previous_tool = previous_tools_by_label[tool.server_label]
|
||||
if previous_tool.allowed_tools == tool.allowed_tools:
|
||||
matched[tool.server_label] = tool
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
# tools that are not the same or were not previously defined need to be processed:
|
||||
self.tools_to_process = tools_to_process
|
||||
# for all matched definitions, get the mcp_list_tools objects from the previous output:
|
||||
self.previous_tool_listings = [
|
||||
obj for obj in previous_response.output if obj.type == "mcp_list_tools" and obj.server_label in matched
|
||||
]
|
||||
# reconstruct the tool to server mappings that can be reused:
|
||||
for listing in self.previous_tool_listings:
|
||||
definition = matched[listing.server_label]
|
||||
for tool in listing.tools:
|
||||
self.previous_tools[tool.name] = definition
|
||||
|
||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||
if not self.current_tools:
|
||||
return []
|
||||
|
||||
def convert_tool(tool: OpenAIResponseInputTool) -> OpenAIResponseTool:
|
||||
if isinstance(tool, OpenAIResponseInputToolWebSearch):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolFileSearch):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolFunction):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP):
|
||||
return OpenAIResponseToolMCP(
|
||||
server_label=tool.server_label,
|
||||
allowed_tools=tool.allowed_tools,
|
||||
)
|
||||
|
||||
return [convert_tool(tool) for tool in self.current_tools]
|
||||
|
||||
|
||||
class ChatCompletionContext(BaseModel):
|
||||
model: str
|
||||
messages: list[OpenAIMessageParam]
|
||||
|
|
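A minimal sketch of how the ToolContext defined above is meant to be used; request_tools and previous_response are assumed inputs (the tools argument of the new request and a previously stored OpenAIResponseObject):

# Sketch only: reuse MCP tool listings from a prior response when definitions match.
ctx = ToolContext(current_tools=request_tools)  # request_tools: assumed input
if previous_response is not None:               # previous_response: assumed input
    ctx.recover_tools_from_previous_response(previous_response)
# ctx.previous_tool_listings can be replayed without re-contacting the MCP servers;
# only ctx.tools_to_process still requires a live list_mcp_tools call.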
@ -62,6 +150,7 @@ class ChatCompletionContext(BaseModel):
|
|||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
||||
tool_context: ToolContext | None
|
||||
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
|
||||
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
|
||||
|
||||
|
|
@ -72,6 +161,7 @@ class ChatCompletionContext(BaseModel):
|
|||
response_tools: list[OpenAIResponseInputTool] | None,
|
||||
temperature: float | None,
|
||||
response_format: OpenAIResponseFormatParam,
|
||||
tool_context: ToolContext,
|
||||
inputs: list[OpenAIResponseInput] | str,
|
||||
):
|
||||
super().__init__(
|
||||
|
|
@ -80,6 +170,7 @@ class ChatCompletionContext(BaseModel):
|
|||
response_tools=response_tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
tool_context=tool_context,
|
||||
)
|
||||
if not isinstance(inputs, str):
|
||||
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
|
||||
|
|
@@ -96,3 +187,8 @@ class ChatCompletionContext(BaseModel):
|
|||
if request.name == tool_name and request.arguments == arguments:
|
||||
return request
|
||||
return None
|
||||
|
||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||
if not self.tool_context:
|
||||
return []
|
||||
return self.tool_context.available_tools()
|
||||
|
|
|
|||
|
|
@ -4,9 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseAnnotationFileCitation,
|
||||
OpenAIResponseInput,
|
||||
|
|
@ -45,10 +47,14 @@ from llama_stack.apis.inference import (
|
|||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
|
||||
async def convert_chat_choice_to_response_message(
|
||||
choice: OpenAIChoice, citation_files: dict[str, str] | None = None
|
||||
choice: OpenAIChoice,
|
||||
citation_files: dict[str, str] | None = None,
|
||||
*,
|
||||
message_id: str | None = None,
|
||||
) -> OpenAIResponseMessage:
|
||||
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
|
||||
output_content = ""
|
||||
|
|
@ -64,7 +70,7 @@ async def convert_chat_choice_to_response_message(
|
|||
annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
|
||||
|
||||
return OpenAIResponseMessage(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
id=message_id or f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
|
|
@ -103,9 +109,13 @@ async def convert_response_content_to_chat_content(
|
|||
|
||||
async def convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_messages: list[OpenAIMessageParam] | None = None,
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
|
||||
:param input: The input to convert
|
||||
:param previous_messages: Optional previous messages to check for function_call references
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
|
|
@ -169,16 +179,53 @@ async def convert_response_input_to_chat_messages(
|
|||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
# Skip user messages that duplicate the last user message in previous_messages
|
||||
# This handles cases where input includes context for function_call_outputs
|
||||
if previous_messages and input_item.role == "user":
|
||||
last_user_msg = None
|
||||
for msg in reversed(previous_messages):
|
||||
if isinstance(msg, OpenAIUserMessageParam):
|
||||
last_user_msg = msg
|
||||
break
|
||||
if last_user_msg:
|
||||
last_user_content = getattr(last_user_msg, "content", None)
|
||||
if last_user_content == content:
|
||||
continue # Skip duplicate user message
|
||||
messages.append(message_type(content=content))
|
||||
if len(tool_call_results):
|
||||
raise ValueError(
|
||||
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
|
||||
)
|
||||
# Check if unpaired function_call_outputs reference function_calls from previous messages
|
||||
if previous_messages:
|
||||
previous_call_ids = _extract_tool_call_ids(previous_messages)
|
||||
for call_id in list(tool_call_results.keys()):
|
||||
if call_id in previous_call_ids:
|
||||
# Valid: this output references a call from previous messages
|
||||
# Add the tool message
|
||||
messages.append(tool_call_results[call_id])
|
||||
del tool_call_results[call_id]
|
||||
|
||||
# If still have unpaired outputs, error
|
||||
if len(tool_call_results):
|
||||
raise ValueError(
|
||||
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
|
||||
)
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
|
||||
"""Extract all tool_call IDs from messages."""
|
||||
call_ids = set()
|
||||
for msg in messages:
|
||||
if isinstance(msg, OpenAIAssistantMessageParam):
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
# tool_call is a Pydantic model, use attribute access
|
||||
call_ids.add(tool_call.id)
|
||||
return call_ids
|
||||
|
||||
|
||||
async def convert_response_text_to_chat_response_format(
|
||||
text: OpenAIResponseText,
|
||||
) -> OpenAIResponseFormatParam:
|
||||
|
|
@ -196,7 +243,8 @@ async def convert_response_text_to_chat_response_format(
|
|||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
||||
async def get_message_type_by_role(role: str):
|
||||
async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
|
||||
"""Get the appropriate OpenAI message parameter type for a given role."""
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
|
|
@@ -263,3 +311,55 @@ def is_function_tool_call(
|
|||
if t.type == "function" and t.name == tool_call.function.name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
"""Run guardrails against messages and return violation message if blocked."""
if not messages:
return None

# Look up shields to get their provider_resource_id (actual model ID)
model_ids = []
shields_list = await safety_api.routing_table.list_shields()

for guardrail_id in guardrail_ids:
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
if matching_shields:
model_id = matching_shields[0].provider_resource_id
model_ids.append(model_id)
else:
raise ValueError(f"No shield found with identifier '{guardrail_id}'")

guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
responses = await asyncio.gather(*guardrail_tasks)

for response in responses:
for result in response.results:
if result.flagged:
message = result.user_message or "Content blocked by safety guardrails"
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
violation_type = result.metadata.get("violation_type", []) if result.metadata else []

if flagged_categories:
message += f" (flagged for: {', '.join(flagged_categories)})"
if violation_type:
message += f" (violation type: {', '.join(violation_type)})"

return message


def extract_guardrail_ids(guardrails: list | None) -> list[str]:
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
if not guardrails:
return []

guardrail_ids = []
for guardrail in guardrails:
if isinstance(guardrail, str):
guardrail_ids.append(guardrail)
elif isinstance(guardrail, ResponseGuardrailSpec):
guardrail_ids.append(guardrail.type)
else:
raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")

return guardrail_ids
|
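A minimal sketch of how these two helpers could be combined in a response handler; safety_api, request_guardrails, and user_text are assumed inputs and the refusal handling is illustrative:

# Sketch only: map guardrail specs to shield model ids, then short-circuit on a violation.
guardrail_ids = extract_guardrail_ids(request_guardrails)      # e.g. ["llama-guard"]
violation = await run_guardrails(safety_api, user_text, guardrail_ids)
if violation is not None:
    print(f"blocked: {violation}")  # real code would emit a refusal response instead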
@@ -8,8 +8,8 @@ import asyncio
|
|||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,10 @@ from llama_stack.apis.files import Files, OpenAIFilePurpose
|
|||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
|
|
@ -178,9 +181,9 @@ class ReferenceBatchesImpl(Batches):
|
|||
|
||||
# TODO: set expiration time for garbage collection
|
||||
|
||||
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
|
||||
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
|
||||
raise ValueError(
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
|
||||
)
|
||||
|
||||
if completion_window != "24h":
|
||||
|
|
@@ -425,18 +428,23 @@ class ReferenceBatchesImpl(Batches):
|
|||
valid = False
|
||||
|
||||
if batch.endpoint == "/v1/chat/completions":
|
||||
required_params = [
|
||||
required_params: list[tuple[str, Any, str]] = [
|
||||
("model", str, "a string"),
|
||||
# messages is specific to /v1/chat/completions
|
||||
# we could skip validating messages here and let inference fail. however,
|
||||
# that would be a very expensive way to find out messages is wrong.
|
||||
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||
]
|
||||
else: # /v1/completions
|
||||
elif batch.endpoint == "/v1/completions":
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
|
||||
]
|
||||
else: # /v1/embeddings
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
("input", (str, list), "a string or array of strings"),
|
||||
]
|
||||
|
||||
for param, expected_type, type_string in required_params:
|
||||
if param not in body:
|
||||
|
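With this change a batch line aimed at /v1/embeddings must provide "model" and "input"; a minimal illustrative body (values are examples, the model id is taken from elsewhere in this commit):

# Sketch: per-request body shape accepted for the new /v1/embeddings batch endpoint.
example_embeddings_body = {
    "model": "nomic-ai/nomic-embed-text-v1.5",
    "input": ["first text to embed", "second text to embed"],
}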
|
@ -601,7 +609,8 @@ class ReferenceBatchesImpl(Batches):
|
|||
# TODO(SECURITY): review body for security issues
|
||||
if request.url == "/v1/chat/completions":
|
||||
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
|
||||
chat_response = await self.inference_api.openai_chat_completion(**request.body)
|
||||
chat_params = OpenAIChatCompletionRequestWithExtraBody(**request.body)
|
||||
chat_response = await self.inference_api.openai_chat_completion(chat_params)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
|
||||
|
|
@ -614,8 +623,9 @@ class ReferenceBatchesImpl(Batches):
|
|||
"body": chat_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
else: # /v1/completions
|
||||
completion_response = await self.inference_api.openai_completion(**request.body)
|
||||
elif request.url == "/v1/completions":
|
||||
completion_params = OpenAICompletionRequestWithExtraBody(**request.body)
|
||||
completion_response = await self.inference_api.openai_completion(completion_params)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(completion_response, "model_dump_json"), (
|
||||
|
|
@@ -630,6 +640,22 @@ class ReferenceBatchesImpl(Batches):
|
|||
"body": completion_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
else: # /v1/embeddings
|
||||
embeddings_response = await self.inference_api.openai_embeddings(
|
||||
OpenAIEmbeddingsRequestWithExtraBody(**request.body)
|
||||
)
|
||||
assert hasattr(embeddings_response, "model_dump_json"), (
|
||||
"Embeddings response must have model_dump_json method"
|
||||
)
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id, # TODO: should this be different?
|
||||
"body": embeddings_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -6,13 +6,13 @@
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class ReferenceBatchesImplConfig(BaseModel):
|
||||
"""Configuration for the Reference Batches implementation."""
|
||||
|
||||
kvstore: KVStoreConfig = Field(
|
||||
kvstore: KVStoreReference = Field(
|
||||
description="Configuration for the key-value store backend.",
|
||||
)
|
||||
|
||||
|
|
@ -33,8 +33,8 @@ class ReferenceBatchesImplConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="batches.db",
|
||||
),
|
||||
"kvstore": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="batches",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
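Under the new storage schema these sample configs emit a reference into the stack's shared stores instead of an inline SQLite config; roughly, for the batches provider above:

# Sketch: approximate result of sample_run_config() after the KVStoreReference change.
sample = {"kvstore": {"backend": "kv_default", "namespace": "batches"}}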
|
|||
|
|
@ -7,20 +7,17 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class LocalFSDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig
|
||||
kvstore: KVStoreReference
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="localfs_datasetio.db",
|
||||
)
|
||||
"kvstore": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="datasetio::localfs",
|
||||
).model_dump(exclude_none=True)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,20 +7,17 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class MetaReferenceEvalConfig(BaseModel):
|
||||
kvstore: KVStoreConfig
|
||||
kvstore: KVStoreReference
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="meta_reference_eval.db",
|
||||
)
|
||||
"kvstore": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="eval",
|
||||
).model_dump(exclude_none=True)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,14 @@ from llama_stack.apis.agents import Agents, StepType
|
|||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
|
|
@ -168,11 +175,12 @@ class MetaReferenceEvalImpl(
|
|||
sampling_params["stop"] = candidate.sampling_params.stop
|
||||
|
||||
input_content = json.loads(x[ColumnName.completion_input.value])
|
||||
response = await self.inference_api.openai_completion(
|
||||
params = OpenAICompletionRequestWithExtraBody(
|
||||
model=candidate.model,
|
||||
prompt=input_content,
|
||||
**sampling_params,
|
||||
)
|
||||
response = await self.inference_api.openai_completion(params)
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
|
|
@ -187,11 +195,12 @@ class MetaReferenceEvalImpl(
|
|||
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
|
||||
messages += input_messages
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=candidate.model,
|
||||
messages=messages,
|
||||
**sampling_params,
|
||||
)
|
||||
response = await self.inference_api.openai_chat_completion(params)
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
|
||||
else:
|
||||
raise ValueError("Invalid input row")
|
||||
|
|
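The recurring migration in this commit is from keyword arguments to a single request object; a minimal illustrative call, with a placeholder model id and an assumed inference_api handle:

# Sketch only: chat completions are now invoked with a request object.
params = OpenAIChatCompletionRequestWithExtraBody(
    model="example/model-id",  # placeholder
    messages=[OpenAIUserMessageParam(content="hello")],
)
response = await inference_api.openai_chat_completion(params)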
|
|||
|
|
@ -8,14 +8,14 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
from llama_stack.core.storage.datatypes import SqlStoreReference
|
||||
|
||||
|
||||
class LocalfsFilesImplConfig(BaseModel):
|
||||
storage_dir: str = Field(
|
||||
description="Directory to store uploaded files",
|
||||
)
|
||||
metadata_store: SqlStoreConfig = Field(
|
||||
metadata_store: SqlStoreReference = Field(
|
||||
description="SQL store configuration for file metadata",
|
||||
)
|
||||
ttl_secs: int = 365 * 24 * 60 * 60 # 1 year
|
||||
|
|
@ -24,8 +24,8 @@ class LocalfsFilesImplConfig(BaseModel):
|
|||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"storage_dir": "${env.FILES_STORAGE_DIR:=" + __distro_dir__ + "/files}",
|
||||
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="files_metadata.db",
|
||||
),
|
||||
"metadata_store": SqlStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="files_metadata",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from llama_stack.apis.files import (
|
|||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
|
|
@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
def _generate_file_id(self) -> str:
|
||||
"""Generate a unique file ID for OpenAI API."""
|
||||
return f"file-{uuid.uuid4().hex}"
|
||||
return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
|
||||
|
||||
def _get_file_path(self, file_id: str) -> Path:
|
||||
"""Get the filesystem path for a file ID."""
|
||||
|
|
@ -95,7 +96,9 @@ class LocalfsFilesImpl(Files):
|
|||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
if expires_after is not None:
|
||||
raise NotImplementedError("File expiration is not supported by this provider")
|
||||
logger.warning(
|
||||
f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
|
||||
)
|
||||
|
||||
file_id = self._generate_file_id()
|
||||
file_path = self._get_file_path(file_id)
|
||||
|
|
|
|||
|
|
@@ -18,7 +18,7 @@ def model_checkpoint_dir(model_id) -> str:

assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
f"If you try to use the native llama model, please download the model using `llama-model download --source meta --model-id {model_id}` (see https://github.com/meta-llama/llama-models). "
f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
)
return str(checkpoint_dir)
|
||||
|
|
|
|||
|
|
@ -6,16 +6,16 @@
|
|||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
InferenceProvider,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAICompletion,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -65,7 +65,10 @@ class MetaReferenceInferenceImpl(
|
|||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
async def openai_completion(self, *args, **kwargs):
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by meta reference provider")
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
|
|
@@ -150,28 +153,6 @@ class MetaReferenceInferenceImpl(
|
|||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
|
||||
|
|
|
|||
|
|
@ -5,17 +5,16 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
InferenceProvider,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -55,11 +54,11 @@ class SentenceTransformersInferenceImpl(
|
|||
async def list_models(self) -> list[Model] | None:
|
||||
return [
|
||||
Model(
|
||||
identifier="all-MiniLM-L6-v2",
|
||||
provider_resource_id="all-MiniLM-L6-v2",
|
||||
identifier="nomic-ai/nomic-embed-text-v1.5",
|
||||
provider_resource_id="nomic-ai/nomic-embed-text-v1.5",
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"embedding_dimension": 768,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
|
|
@ -73,56 +72,12 @@ class SentenceTransformersInferenceImpl(
|
|||
|
||||
async def openai_completion(
|
||||
self,
|
||||
# Standard OpenAI completion parameters
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
# vLLM-specific parameters
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
# for fill-in-the-middle type completion
|
||||
suffix: str | None = None,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")
|
||||
|
|
|
|||
|
|
@ -104,9 +104,10 @@ class LoraFinetuningSingleDevice:
|
|||
if not any(p.exists() for p in paths):
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
hf_repo = model.huggingface_repo or f"meta-llama/{model.descriptor()}"
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
|
||||
f"Please download model using `llama download --model-id {model.descriptor()}`"
|
||||
f"Please download the model using `huggingface-cli download {hf_repo} --local-dir ~/.llama/{model.descriptor()}`"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any
|
|||
if TYPE_CHECKING:
|
||||
from codeshield.cs import CodeShieldScanResult
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
|
|
@ -53,7 +53,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
|
|
@ -101,7 +101,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||
if model is None:
|
||||
raise ValueError("Code scanner moderation requires a model identifier.")
|
||||
|
||||
inputs = input if isinstance(input, list) else [input]
|
||||
results = []
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,12 @@ from string import Template
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.inference import Inference, Message, UserMessage
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
|
|
@ -159,7 +164,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
|
|
@ -169,8 +174,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
# since this might be a tool call, first role might not be user
|
||||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
if len(messages) > 0 and messages[0].role != "user":
|
||||
messages[0] = OpenAIUserMessageParam(content=messages[0].content)
|
||||
|
||||
# Use the inference API's model resolution instead of hardcoded mappings
|
||||
# This allows the shield to work with any registered model
|
||||
|
|
@ -195,14 +200,17 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
|
||||
return await impl.run(messages)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||
if model is None:
|
||||
raise ValueError("Llama Guard moderation requires a model identifier.")
|
||||
|
||||
if isinstance(input, list):
|
||||
messages = input.copy()
|
||||
else:
|
||||
messages = [input]
|
||||
|
||||
# convert to user messages format with role
|
||||
messages = [UserMessage(content=m) for m in messages]
|
||||
messages = [OpenAIUserMessageParam(content=m) for m in messages]
|
||||
|
||||
# Determine safety categories based on the model type
|
||||
# For known Llama Guard models, use specific categories
|
||||
|
|
@ -271,7 +279,7 @@ class LlamaGuardShield:
|
|||
|
||||
return final_categories
|
||||
|
||||
def validate_messages(self, messages: list[Message]) -> None:
|
||||
def validate_messages(self, messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
|
||||
if len(messages) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
if messages[0].role != Role.user.value:
|
||||
|
|
@ -282,7 +290,7 @@ class LlamaGuardShield:
|
|||
|
||||
return messages
|
||||
|
||||
async def run(self, messages: list[Message]) -> RunShieldResponse:
|
||||
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
|
||||
messages = self.validate_messages(messages)
|
||||
|
||||
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
|
||||
|
|
@@ -290,20 +298,21 @@ class LlamaGuardShield:
|
|||
else:
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=False,
|
||||
temperature=0.0, # default is 1, which is too high for safety
|
||||
)
|
||||
response = await self.inference_api.openai_chat_completion(params)
|
||||
content = response.choices[0].message.content
|
||||
content = content.strip()
|
||||
return self.get_shield_response(content)
|
||||
|
||||
def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
|
||||
return UserMessage(content=self.build_prompt(messages))
|
||||
def build_text_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
|
||||
return OpenAIUserMessageParam(content=self.build_prompt(messages))
|
||||
|
||||
def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
|
||||
def build_vision_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
|
||||
conversation = []
|
||||
most_recent_img = None
|
||||
|
||||
|
|
@ -326,7 +335,7 @@ class LlamaGuardShield:
|
|||
else:
|
||||
raise ValueError(f"Unknown content type: {c}")
|
||||
|
||||
conversation.append(UserMessage(content=content))
|
||||
conversation.append(OpenAIUserMessageParam(content=content))
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {m.content}")
|
||||
|
||||
|
|
@ -335,9 +344,9 @@ class LlamaGuardShield:
|
|||
prompt.append(most_recent_img)
|
||||
prompt.append(self.build_prompt(conversation[::-1]))
|
||||
|
||||
return UserMessage(content=prompt)
|
||||
return OpenAIUserMessageParam(content=prompt)
|
||||
|
||||
def build_prompt(self, messages: list[Message]) -> str:
|
||||
def build_prompt(self, messages: list[OpenAIMessageParam]) -> str:
|
||||
categories = self.get_safety_categories()
|
||||
categories_str = "\n".join(categories)
|
||||
conversations_str = "\n\n".join(
|
||||
|
|
@ -370,18 +379,20 @@ class LlamaGuardShield:
|
|||
|
||||
raise ValueError(f"Unexpected response: {response}")
|
||||
|
||||
async def run_moderation(self, messages: list[Message]) -> ModerationObject:
|
||||
async def run_moderation(self, messages: list[OpenAIMessageParam]) -> ModerationObject:
|
||||
if not messages:
|
||||
return self.create_moderation_object(self.model)
|
||||
|
||||
# TODO: Add Image based support for OpenAI Moderations
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=False,
|
||||
temperature=0.0, # default is 1, which is too high for safety
|
||||
)
|
||||
response = await self.inference_api.openai_chat_completion(params)
|
||||
content = response.choices[0].message.content
|
||||
content = content.strip()
|
||||
return self.get_moderation_object(content)
|
||||
|
|
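Both Llama Guard and the code scanner above now require an explicit model for moderation; an illustrative call with a placeholder identifier:

# Sketch only: run_moderation must be given a model that resolves to a registered shield.
result = await safety_api.run_moderation(
    input=["example text to screen"],
    model="llama-guard",  # placeholder identifier
)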
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Any
|
|||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
|
|
@ -22,9 +22,7 @@ from llama_stack.apis.shields import Shield
|
|||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
|
|
@ -56,7 +54,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
|
|
@ -65,7 +63,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
|
||||
return await self.shield.run(messages)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||
raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
|
||||
|
||||
|
||||
|
|
@ -93,7 +91,7 @@ class PromptGuardShield:
|
|||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
|
||||
|
||||
async def run(self, messages: list[Message]) -> RunShieldResponse:
|
||||
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
|
||||
message = messages[-1]
|
||||
text = interleaved_content_as_str(message.content)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
import re
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import Inference, OpenAIChatCompletionRequestWithExtraBody
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
|
@ -55,7 +55,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
|||
generated_answer=generated_answer,
|
||||
)
|
||||
|
||||
judge_response = await self.inference_api.openai_chat_completion(
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=fn_def.params.judge_model,
|
||||
messages=[
|
||||
{
|
||||
|
|
@ -64,6 +64,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
|||
}
|
||||
],
|
||||
)
|
||||
judge_response = await self.inference_api.openai_chat_completion(params)
|
||||
content = judge_response.choices[0].message.content
|
||||
rating_regexes = fn_def.params.judge_score_regexes
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import Api
from .config import TelemetryConfig, TelemetrySink
__all__ = ["TelemetryConfig", "TelemetrySink"]
async def get_provider_impl(config: TelemetryConfig, deps: dict[Api, Any]):
from .telemetry import TelemetryAdapter
impl = TelemetryAdapter(config, deps)
await impl.initialize()
return impl

@@ -1,55 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field, field_validator
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(StrEnum):
OTEL_TRACE = "otel_trace"
OTEL_METRIC = "otel_metric"
SQLITE = "sqlite"
CONSOLE = "console"
class TelemetryConfig(BaseModel):
otel_exporter_otlp_endpoint: str | None = Field(
default=None,
description="The OpenTelemetry collector endpoint URL (base URL for traces, metrics, and logs). If not set, the SDK will use OTEL_EXPORTER_OTLP_ENDPOINT environment variable.",
)
service_name: str = Field(
# service name is always the same, use zero-width space to avoid clutter
default="\u200b",
description="The service name to use for telemetry",
)
sinks: list[TelemetrySink] = Field(
default=[TelemetrySink.SQLITE],
description="List of telemetry sinks to enable (possible values: otel_trace, otel_metric, sqlite, console)",
)
sqlite_db_path: str = Field(
default_factory=lambda: (RUNTIME_BASE_DIR / "trace_store.db").as_posix(),
description="The path to the SQLite database to use for storing traces",
)
@field_validator("sinks", mode="before")
@classmethod
def validate_sinks(cls, v):
if isinstance(v, str):
return [TelemetrySink(sink.strip()) for sink in v.split(",")]
return v
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:=\u200b}",
"sinks": "${env.TELEMETRY_SINKS:=sqlite}",
"sqlite_db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
"otel_exporter_otlp_endpoint": "${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}",
}
@ -1,75 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace.export import SpanProcessor
|
||||
from opentelemetry.trace.status import StatusCode
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name="console_span_processor", category="telemetry")
|
||||
|
||||
|
||||
class ConsoleSpanProcessor(SpanProcessor):
|
||||
def __init__(self, print_attributes: bool = False):
|
||||
self.print_attributes = print_attributes
|
||||
|
||||
def on_start(self, span: ReadableSpan, parent_context=None) -> None:
|
||||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
|
||||
if span.status.status_code == StatusCode.ERROR:
|
||||
span_context += " [bold red][ERROR][/bold red]"
|
||||
elif span.status.status_code != StatusCode.UNSET:
|
||||
span_context += f" [{span.status.status_code}]"
|
||||
duration_ms = (span.end_time - span.start_time) / 1e6
|
||||
span_context += f" ({duration_ms:.2f}ms)"
|
||||
logger.info(span_context)
|
||||
|
||||
if self.print_attributes and span.attributes:
|
||||
for key, value in span.attributes.items():
|
||||
if key.startswith("__"):
|
||||
continue
|
||||
str_value = str(value)
|
||||
if len(str_value) > 1000:
|
||||
str_value = str_value[:997] + "..."
|
||||
logger.info(f" [dim]{key}[/dim]: {str_value}")
|
||||
|
||||
for event in span.events:
|
||||
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
if isinstance(message, dict) or isinstance(message, list):
|
||||
message = json.dumps(message, indent=2)
|
||||
severity_color = {
|
||||
"error": "red",
|
||||
"warn": "yellow",
|
||||
"info": "white",
|
||||
"debug": "dim",
|
||||
}.get(severity, "white")
|
||||
logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}")
|
||||
if event.attributes:
|
||||
for key, value in event.attributes.items():
|
||||
if key.startswith("__") or key in ["message", "severity"]:
|
||||
continue
|
||||
logger.info(f"[dim]{key}[/dim]: {value}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the processor."""
|
||||
pass
|
||||
|
||||
def force_flush(self, timeout_millis: float | None = None) -> bool:
|
||||
"""Force flush any pending spans."""
|
||||
return True
|
||||
|
|
@ -1,190 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from opentelemetry.sdk.trace import SpanProcessor
|
||||
from opentelemetry.trace import Span
|
||||
from opentelemetry.trace.span import format_span_id, format_trace_id
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import LOCAL_ROOT_SPAN_MARKER
|
||||
|
||||
|
||||
class SQLiteSpanProcessor(SpanProcessor):
|
||||
def __init__(self, conn_string):
|
||||
"""Initialize the SQLite span processor with a connection string."""
|
||||
self.conn_string = conn_string
|
||||
self._local = threading.local() # Thread-local storage for connections
|
||||
self.setup_database()
|
||||
|
||||
def _get_connection(self):
|
||||
"""Get a thread-local database connection."""
|
||||
if not hasattr(self._local, "conn"):
|
||||
try:
|
||||
self._local.conn = sqlite3.connect(self.conn_string)
|
||||
except Exception as e:
|
||||
print(f"Error connecting to SQLite database: {e}")
|
||||
raise
|
||||
return self._local.conn
|
||||
|
||||
def setup_database(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(self.conn_string), exist_ok=True)
|
||||
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS traces (
|
||||
trace_id TEXT PRIMARY KEY,
|
||||
service_name TEXT,
|
||||
root_span_id TEXT,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS spans (
|
||||
span_id TEXT PRIMARY KEY,
|
||||
trace_id TEXT REFERENCES traces(trace_id),
|
||||
parent_span_id TEXT,
|
||||
name TEXT,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
attributes TEXT,
|
||||
status TEXT,
|
||||
kind TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS span_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
span_id TEXT REFERENCES spans(span_id),
|
||||
name TEXT,
|
||||
timestamp TIMESTAMP,
|
||||
attributes TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_traces_created_at
|
||||
ON traces(created_at)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
def on_start(self, span: Span, parent_context=None):
|
||||
"""Called when a span starts."""
|
||||
pass
|
||||
|
||||
def on_end(self, span: Span):
|
||||
"""Called when a span ends. Export the span data to SQLite."""
|
||||
try:
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
trace_id = format_trace_id(span.get_span_context().trace_id)
|
||||
span_id = format_span_id(span.get_span_context().span_id)
|
||||
service_name = span.resource.attributes.get("service.name", "unknown")
|
||||
|
||||
parent_span_id = None
|
||||
parent_context = span.parent
|
||||
if parent_context:
|
||||
parent_span_id = format_span_id(parent_context.span_id)
|
||||
|
||||
# Insert into traces
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO traces (
|
||||
trace_id, service_name, root_span_id, start_time, end_time
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(trace_id) DO UPDATE SET
|
||||
root_span_id = COALESCE(root_span_id, excluded.root_span_id),
|
||||
start_time = MIN(excluded.start_time, start_time),
|
||||
end_time = MAX(excluded.end_time, end_time)
|
||||
""",
|
||||
(
|
||||
trace_id,
|
||||
service_name,
|
||||
(span_id if span.attributes.get(LOCAL_ROOT_SPAN_MARKER) else None),
|
||||
datetime.fromtimestamp(span.start_time / 1e9, UTC).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, UTC).isoformat(),
|
||||
),
|
||||
)
|
||||
|
||||
# Insert into spans
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO spans (
|
||||
span_id, trace_id, parent_span_id, name,
|
||||
start_time, end_time, attributes, status,
|
||||
kind
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
span_id,
|
||||
trace_id,
|
||||
parent_span_id,
|
||||
span.name,
|
||||
datetime.fromtimestamp(span.start_time / 1e9, UTC).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, UTC).isoformat(),
|
||||
json.dumps(dict(span.attributes)),
|
||||
span.status.status_code.name,
|
||||
span.kind.name,
|
||||
),
|
||||
)
|
||||
|
||||
for event in span.events:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO span_events (
|
||||
span_id, name, timestamp, attributes
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
span_id,
|
||||
event.name,
|
||||
datetime.fromtimestamp(event.timestamp / 1e9, UTC).isoformat(),
|
||||
json.dumps(dict(event.attributes)),
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
print(f"Error exporting span to SQLite: {e}")
|
||||
|
||||
def shutdown(self):
|
||||
"""Cleanup any resources."""
|
||||
# We can't access other threads' connections, so we just close our own
|
||||
if hasattr(self._local, "conn"):
|
||||
try:
|
||||
self._local.conn.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing SQLite connection: {e}")
|
||||
finally:
|
||||
del self._local.conn
|
||||
|
||||
def force_flush(self, timeout_millis=30000):
|
||||
"""Force export of spans."""
|
||||
pass
|
||||
|
|
@ -1,364 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import datetime
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
Event,
|
||||
MetricEvent,
|
||||
MetricLabelMatcher,
|
||||
MetricQueryType,
|
||||
QueryCondition,
|
||||
QueryMetricsResponse,
|
||||
QuerySpanTreeResponse,
|
||||
QueryTracesResponse,
|
||||
Span,
|
||||
SpanEndPayload,
|
||||
SpanStartPayload,
|
||||
SpanStatus,
|
||||
StructuredLogEvent,
|
||||
Telemetry,
|
||||
Trace,
|
||||
UnstructuredLogEvent,
|
||||
)
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
|
||||
ConsoleSpanProcessor,
|
||||
)
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
|
||||
SQLiteSpanProcessor,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
|
||||
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
|
||||
"active_spans": {},
|
||||
"counters": {},
|
||||
"gauges": {},
|
||||
"up_down_counters": {},
|
||||
}
|
||||
_global_lock = threading.Lock()
|
||||
_TRACER_PROVIDER = None
|
||||
|
||||
logger = get_logger(name=__name__, category="telemetry")
|
||||
|
||||
|
||||
def is_tracing_enabled(tracer):
|
||||
with tracer.start_as_current_span("check_tracing") as span:
|
||||
return span.is_recording()
|
||||
|
||||
|
||||
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||
def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = deps.get(Api.datasetio)
|
||||
self.meter = None
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
ResourceAttributes.SERVICE_NAME: self.config.service_name,
|
||||
}
|
||||
)
|
||||
|
||||
global _TRACER_PROVIDER
|
||||
# Initialize the correct span processor based on the provider state.
|
||||
# This is needed since once the span processor is set, it cannot be unset.
|
||||
# Recreating the telemetry adapter multiple times will result in duplicate span processors.
|
||||
# Since the library client can be recreated multiple times in a notebook,
|
||||
# the kernel will hold on to the span processor and cause duplicate spans to be written.
|
||||
if _TRACER_PROVIDER is None:
|
||||
provider = TracerProvider(resource=resource)
|
||||
trace.set_tracer_provider(provider)
|
||||
_TRACER_PROVIDER = provider
|
||||
|
||||
# Use single OTLP endpoint for all telemetry signals
|
||||
if TelemetrySink.OTEL_TRACE in self.config.sinks or TelemetrySink.OTEL_METRIC in self.config.sinks:
|
||||
if self.config.otel_exporter_otlp_endpoint is None:
|
||||
raise ValueError(
|
||||
"otel_exporter_otlp_endpoint is required when OTEL_TRACE or OTEL_METRIC is enabled"
|
||||
)
|
||||
|
||||
# Let OpenTelemetry SDK handle endpoint construction automatically
|
||||
# The SDK will read OTEL_EXPORTER_OTLP_ENDPOINT and construct appropriate URLs
|
||||
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
|
||||
if TelemetrySink.OTEL_TRACE in self.config.sinks:
|
||||
span_exporter = OTLPSpanExporter()
|
||||
span_processor = BatchSpanProcessor(span_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
|
||||
if TelemetrySink.OTEL_METRIC in self.config.sinks:
|
||||
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
|
||||
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
|
||||
if TelemetrySink.SQLITE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(SQLiteSpanProcessor(self.config.sqlite_db_path))
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor(print_attributes=True))
|
||||
|
||||
if TelemetrySink.OTEL_METRIC in self.config.sinks:
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
if TelemetrySink.SQLITE in self.config.sinks:
|
||||
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
|
||||
|
||||
self._lock = _global_lock
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
trace.get_tracer_provider().force_flush()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
self._log_unstructured(event, ttl_seconds)
|
||||
elif isinstance(event, MetricEvent):
|
||||
self._log_metric(event)
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
self._log_structured(event, ttl_seconds)
|
||||
else:
|
||||
raise ValueError(f"Unknown event type: {event}")
|
||||
|
||||
async def query_metrics(
|
||||
self,
|
||||
metric_name: str,
|
||||
start_time: int,
|
||||
end_time: int | None = None,
|
||||
granularity: str | None = None,
|
||||
query_type: MetricQueryType = MetricQueryType.RANGE,
|
||||
label_matchers: list[MetricLabelMatcher] | None = None,
|
||||
) -> QueryMetricsResponse:
|
||||
"""Query metrics from the telemetry store.
|
||||
|
||||
Args:
|
||||
metric_name: The name of the metric to query (e.g., "prompt_tokens")
|
||||
start_time: Start time as Unix timestamp
|
||||
end_time: End time as Unix timestamp (defaults to now if None)
|
||||
granularity: Time granularity for aggregation
|
||||
query_type: Type of query (RANGE or INSTANT)
|
||||
label_matchers: Label filters to apply
|
||||
|
||||
Returns:
|
||||
QueryMetricsResponse with metric time series data
|
||||
"""
|
||||
# Convert timestamps to datetime objects
|
||||
start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC)
|
||||
end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None
|
||||
|
||||
# Use SQLite trace store if available
|
||||
if hasattr(self, "trace_store") and self.trace_store:
|
||||
return await self.trace_store.query_metrics(
|
||||
metric_name=metric_name,
|
||||
start_time=start_dt,
|
||||
end_time=end_dt,
|
||||
granularity=granularity,
|
||||
query_type=query_type,
|
||||
label_matchers=label_matchers,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks"
|
||||
)
|
||||
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
span_id = int(event.span_id, 16)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=event.type.value,
|
||||
attributes={
|
||||
"message": event.message,
|
||||
"severity": event.severity.value,
|
||||
"__ttl__": ttl_seconds,
|
||||
**(event.attributes or {}),
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
else:
|
||||
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
|
||||
|
||||
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["counters"]:
|
||||
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Counter for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["counters"][name]
|
||||
|
||||
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["gauges"]:
|
||||
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Gauge for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
# Add metric as an event to the current span
|
||||
try:
|
||||
with self._lock:
|
||||
# Only try to add to span if we have a valid span_id
|
||||
if event.span_id:
|
||||
try:
|
||||
span_id = int(event.span_id, 16)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=f"metric.{event.metric}",
|
||||
attributes={
|
||||
"value": event.value,
|
||||
"unit": event.unit,
|
||||
**(event.attributes or {}),
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
# Invalid span_id or span not found, but we already logged to console above
|
||||
pass
|
||||
except Exception:
|
||||
# Lock acquisition failed
|
||||
logger.debug("Failed to acquire lock to add metric to span")
|
||||
|
||||
# Log to OpenTelemetry meter if available
|
||||
if self.meter is None:
|
||||
return
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
elif isinstance(event.value, float):
|
||||
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
|
||||
up_down_counter.add(event.value, attributes=event.attributes)
|
||||
|
||||
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
||||
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"UpDownCounter for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["up_down_counters"][name]
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
span_id = int(event.span_id, 16)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
if event.attributes is None:
|
||||
event.attributes = {}
|
||||
event.attributes["__ttl__"] = ttl_seconds
|
||||
|
||||
# Extract these W3C trace context attributes so they are not written to
|
||||
# underlying storage, as we just need them to propagate the trace context.
|
||||
traceparent = event.attributes.pop("traceparent", None)
|
||||
tracestate = event.attributes.pop("tracestate", None)
|
||||
if traceparent:
|
||||
# If we have a traceparent header value, we're not the root span.
|
||||
for root_attribute in ROOT_SPAN_MARKERS:
|
||||
event.attributes.pop(root_attribute, None)
|
||||
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
return
|
||||
|
||||
context = None
|
||||
if event.payload.parent_span_id:
|
||||
parent_span_id = int(event.payload.parent_span_id, 16)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
elif traceparent:
|
||||
carrier = {
|
||||
"traceparent": traceparent,
|
||||
"tracestate": tracestate,
|
||||
}
|
||||
context = TraceContextTextMapPropagator().extract(carrier=carrier)
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
context=context,
|
||||
attributes=event.attributes or {},
|
||||
)
|
||||
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
if span:
|
||||
if event.attributes:
|
||||
span.set_attributes(event.attributes)
|
||||
|
||||
status = (
|
||||
trace.Status(status_code=trace.StatusCode.OK)
|
||||
if event.payload.status == SpanStatus.OK
|
||||
else trace.Status(status_code=trace.StatusCode.ERROR)
|
||||
)
|
||||
span.set_status(status)
|
||||
span.end()
|
||||
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||
else:
|
||||
raise ValueError(f"Unknown structured log event: {event}")
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
limit: int | None = 100,
|
||||
offset: int | None = 0,
|
||||
order_by: list[str] | None = None,
|
||||
) -> QueryTracesResponse:
|
||||
return QueryTracesResponse(
|
||||
data=await self.trace_store.query_traces(
|
||||
attribute_filters=attribute_filters,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
order_by=order_by,
|
||||
)
|
||||
)
|
||||
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
return await self.trace_store.get_trace(trace_id)
|
||||
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||
return await self.trace_store.get_span(trace_id, span_id)
|
||||
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: list[str] | None = None,
|
||||
max_depth: int | None = None,
|
||||
) -> QuerySpanTreeResponse:
|
||||
return QuerySpanTreeResponse(
|
||||
data=await self.trace_store.get_span_tree(
|
||||
span_id=span_id,
|
||||
attributes_to_return=attributes_to_return,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
)
|
||||
|
|
@@ -8,7 +8,7 @@
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
@@ -65,11 +65,12 @@ async def llm_rag_query_generator(
model = config.model
message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=model,
messages=[message],
stream=False,
)
response = await inference_api.openai_chat_completion(params)
query = response.choices[0].message.content

@@ -272,7 +272,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
return RAGQueryResult(
content=picked,
metadata={
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
"document_ids": [c.document_id for c in chunks[: len(picked)]],
"chunks": [c.content for c in chunks[: len(picked)]],
"scores": scores[: len(picked)],
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],

@@ -12,9 +12,7 @@ from .config import ChromaVectorIOConfig
async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaVectorIOAdapter
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()

@@ -8,14 +8,14 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class ChromaVectorIOConfig(BaseModel):
db_path: str
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
persistence: KVStoreReference = Field(description="Config for KV store backend")
@classmethod
def sample_run_config(
@@ -23,8 +23,8 @@ class ChromaVectorIOConfig(BaseModel):
) -> dict[str, Any]:
return {
"db_path": db_path,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="chroma_inline_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::chroma",
).model_dump(exclude_none=True),
}

@@ -16,6 +16,6 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

@@ -8,22 +8,19 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class FaissVectorIOConfig(BaseModel):
kvstore: KVStoreConfig
persistence: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="faiss_store.db",
)
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::faiss",
).model_dump(exclude_none=True)
}
@ -17,33 +17,21 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="vector_io")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"
|
||||
FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
|
||||
|
|
@ -154,12 +142,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
await self._save_index()
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||
chunks = []
|
||||
scores = []
|
||||
|
|
@ -174,12 +157,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError(
|
||||
"Keyword search is not supported - underlying DB FAISS does not support this search mode"
|
||||
)
|
||||
|
|
@ -198,28 +176,28 @@ class FaissIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.cache: dict[str, VectorStoreWithIndex] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
# Load existing banks from kvstore
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier),
|
||||
for vector_store_data in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(vector_store_data)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store,
|
||||
await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
# Load existing OpenAI vector stores into the in-memory cache
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
|
@ -244,45 +222,33 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
assert self.kvstore is not None
|
||||
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=vector_db.model_dump_json(),
|
||||
)
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
|
||||
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
|
||||
|
||||
# Store in cache
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier),
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [i.vector_db for i in self.cache.values()]
|
||||
async def list_vector_stores(self) -> list[VectorStore]:
|
||||
return [i.vector_store for i in self.cache.values()]
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
assert self.kvstore is not None
|
||||
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
if vector_store_id not in self.cache:
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
|
||||
|
|
@ -290,10 +256,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,6 @@ from .config import MilvusVectorIOConfig
|
|||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,25 +8,22 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MilvusVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
persistence: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"db_path": "${env.MILVUS_DB_PATH:=" + __distro_dir__ + "}/" + "milvus.db",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="milvus_registry.db",
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::milvus",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
|
|||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
files_api = deps.get(Api.files)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -9,23 +9,21 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QdrantVectorIOConfig(BaseModel):
|
||||
path: str
|
||||
kvstore: KVStoreConfig
|
||||
persistence: KVStoreReference
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__, db_name="qdrant_registry.db"
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::qdrant",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,6 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
|
|||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
|
||||
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,22 +8,19 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class SQLiteVectorIOConfig(BaseModel):
|
||||
db_path: str = Field(description="Path to the SQLite database file")
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
persistence: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="sqlite_vec_registry.db",
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::sqlite_vec",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,14 +17,10 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
|
@ -32,7 +28,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
RERANKER_TYPE_RRF,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
VectorStoreWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
|
||||
|
||||
|
|
@ -45,7 +41,7 @@ HYBRID_SEARCH = "hybrid"
|
|||
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:sqlite_vec:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:sqlite_vec:{VERSION}::"
|
||||
|
|
@ -174,32 +170,18 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
# Insert vector embeddings
|
||||
embedding_data = [
|
||||
(
|
||||
(
|
||||
chunk.chunk_id,
|
||||
serialize_vector(emb.tolist()),
|
||||
)
|
||||
)
|
||||
((chunk.chunk_id, serialize_vector(emb.tolist())))
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
]
|
||||
cur.executemany(
|
||||
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
|
||||
embedding_data,
|
||||
)
|
||||
cur.executemany(f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
|
||||
# Insert FTS content
|
||||
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
|
||||
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
||||
cur.executemany(
|
||||
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
|
||||
[(row[0],) for row in fts_data],
|
||||
)
|
||||
cur.executemany(f"DELETE FROM [{self.fts_table}] WHERE id = ?;", [(row[0],) for row in fts_data])
|
||||
|
||||
# INSERT new entries
|
||||
cur.executemany(
|
||||
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
|
||||
fts_data,
|
||||
)
|
||||
cur.executemany(f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);", fts_data)
|
||||
|
||||
connection.commit()
|
||||
|
||||
|
|
@ -215,12 +197,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# Run batch insertion in a background thread
|
||||
await asyncio.to_thread(_execute_all_batch_inserts)
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs vector-based search using a virtual table for vector similarity.
|
||||
"""
|
||||
|
|
@ -260,12 +237,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
scores.append(score)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
||||
"""
|
||||
|
|
@ -402,33 +374,32 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
await asyncio.to_thread(_delete_chunks)
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
"""
|
||||
A VectorIO implementation using SQLite + sqlite_vec.
|
||||
This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
|
||||
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
This class handles vector database registration (with metadata stored in a table named `vector_stores`)
|
||||
and creates a cache of VectorStoreWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
"""
|
||||
|
||||
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.cache: dict[str, VectorStoreWithIndex] = {}
|
||||
self.vector_store_table = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
for db_json in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(db_json)
|
||||
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
|
||||
for db_json in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(db_json)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
|
||||
# Load existing OpenAI vector stores into the in-memory cache
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
|
@ -437,67 +408,64 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [v.vector_db for v in self.cache.values()]
|
||||
async def list_vector_stores(self) -> list[VectorStore]:
|
||||
return [v.vector_store for v in self.cache.values()]
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
if self.vector_store_table is None:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
vector_db = self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
vector_store = self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=SQLiteVecIndex(
|
||||
dimension=vector_db.embedding_dimension,
|
||||
dimension=vector_store.embedding_dimension,
|
||||
db_path=self.config.db_path,
|
||||
bank_id=vector_db.identifier,
|
||||
bank_id=vector_store.identifier,
|
||||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
if vector_store_id not in self.cache:
|
||||
return
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
|
||||
# The VectorStoreWithIndex helper is expected to compute embeddings via the inference_api
|
||||
# and then call our index's add_chunks.
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a sqlite_vec index."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
|
|
|
|||
|
|
@@ -32,12 +32,9 @@ def available_providers() -> list[ProviderSpec]:
Api.inference,
Api.safety,
Api.vector_io,
Api.vector_dbs,
Api.tool_runtime,
Api.tool_groups,
],
optional_api_dependencies=[
Api.telemetry,
Api.conversations,
],
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
),
@@ -43,6 +43,12 @@ def available_providers() -> list[ProviderSpec]:
pip_packages=[
"torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu",
"sentence-transformers --no-deps",
# required by some SentenceTransformers architectures for tensor rearrange/merge ops
"einops",
# fast HF tokenization backend used by SentenceTransformers models
"tokenizers",
# safe and fast file format for storing and loading tensors
"safetensors",
],
module="llama_stack.providers.inline.inference.sentence_transformers",
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
@@ -275,7 +281,7 @@ Available Models:
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
),
RemoteProviderSpec(
@@ -1,29 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from llama_stack.providers.datatypes import (
Api,
InlineProviderSpec,
ProviderSpec,
)


def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.telemetry,
provider_type="inline::meta-reference",
pip_packages=[
"opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http",
],
optional_api_dependencies=[Api.datasetio],
module="llama_stack.providers.inline.telemetry.meta_reference",
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
description="Meta's reference implementation of telemetry and observability using OpenTelemetry.",
),
]
@@ -26,7 +26,7 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="Meta's reference implementation of a vector database.",
),
InlineProviderSpec(

@@ -36,7 +36,7 @@ def available_providers() -> list[ProviderSpec]:
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.

@@ -89,7 +89,7 @@ more details about Faiss in general.
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly within an SQLite database.

@@ -297,7 +297,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
Please refer to the sqlite-vec provider documentation.
""",

@@ -310,7 +310,7 @@ Please refer to the sqlite-vec provider documentation.
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.

@@ -352,7 +352,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.

@@ -396,7 +396,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.

@@ -508,7 +508,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
It allows you to store and query vectors directly within a Weaviate database.

@@ -548,7 +548,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
module="llama_stack.providers.inline.vector_io.qdrant",
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description=r"""
[Qdrant](https://qdrant.tech/documentation/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.

@@ -601,7 +601,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
Please refer to the inline provider documentation.
""",

@@ -614,7 +614,7 @@ Please refer to the inline provider documentation.
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database.

@@ -820,7 +820,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
Please refer to the remote provider documentation.
""",
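Each registry entry above now lists `Api.models` alongside `Api.files` under `optional_api_dependencies`. As a hedged sketch of what that distinction means at wiring time, assuming the usual `deps: dict[Api, Any]` passed to a provider factory: required APIs are indexed directly, optional ones are looked up with `.get()` and may be `None`. The `Api` enum and `FakeFaissImpl` below are stand-ins, not the real classes.

```python
# Illustrative only: how a vector_io provider factory might consume required vs
# optional API dependencies after Api.models becomes an optional dependency.
from enum import Enum
from typing import Any


class Api(Enum):
    inference = "inference"
    files = "files"
    models = "models"


class FakeFaissImpl:
    def __init__(self, inference: Any, files: Any | None, models: Any | None) -> None:
        self.inference = inference
        self.files = files    # optional: may be None
        self.models = models  # optional: may be None


def build_impl(deps: dict[Api, Any]) -> FakeFaissImpl:
    # Required dependencies are indexed directly; optional ones use .get()
    return FakeFaissImpl(
        inference=deps[Api.inference],
        files=deps.get(Api.files),
        models=deps.get(Api.models),
    )


impl = build_impl({Api.inference: object(), Api.models: object()})
print(impl.files is None, impl.models is not None)  # True True
```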
@@ -7,20 +7,17 @@ from typing import Any

from pydantic import BaseModel

from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference


class HuggingfaceDatasetIOConfig(BaseModel):
kvstore: KVStoreConfig
kvstore: KVStoreReference

@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="huggingface_datasetio.db",
)
"kvstore": KVStoreReference(
backend="kv_default",
namespace="datasetio::huggingface",
).model_dump(exclude_none=True)
}
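The config change above swaps an inline `KVStoreConfig` for a `KVStoreReference` that points at a named storage backend. A small sketch of the run-config shape `sample_run_config` now produces; the two-field `KVStoreReference` below is a stand-in for the real model in `llama_stack.core.storage.datatypes`.

```python
# Sketch of the resulting run-config fragment; the stand-in model only carries the
# two fields the diff uses (backend, namespace).
from pydantic import BaseModel


class KVStoreReference(BaseModel):
    backend: str
    namespace: str


sample = {
    "kvstore": KVStoreReference(
        backend="kv_default",
        namespace="datasetio::huggingface",
    ).model_dump(exclude_none=True)
}
print(sample)
# {'kvstore': {'backend': 'kv_default', 'namespace': 'datasetio::huggingface'}}
```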
@@ -20,7 +20,7 @@ This provider enables dataset management using NVIDIA's NeMo Customizer service.
Build the NVIDIA environment:

```bash
llama stack build --distro nvidia --image-type venv
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```

### Basic Usage using the LlamaStack Python Client
@@ -8,7 +8,7 @@ from typing import Any

from pydantic import BaseModel, Field

from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
from llama_stack.core.storage.datatypes import SqlStoreReference


class S3FilesImplConfig(BaseModel):

@@ -24,7 +24,7 @@ class S3FilesImplConfig(BaseModel):
auto_create_bucket: bool = Field(
default=False, description="Automatically create the S3 bucket if it doesn't exist"
)
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
metadata_store: SqlStoreReference = Field(description="SQL store configuration for file metadata")

@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:

@@ -35,8 +35,8 @@ class S3FilesImplConfig(BaseModel):
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="s3_files_metadata.db",
),
"metadata_store": SqlStoreReference(
backend="sql_default",
table_name="s3_files_metadata",
).model_dump(exclude_none=True),
}

@@ -23,6 +23,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore

@@ -198,7 +199,7 @@ class S3FilesImpl(Files):
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")

filename = getattr(file, "filename", None) or "uploaded_file"
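The upload path above now routes file IDs through `generate_object_id("file", ...)` instead of minting the UUID inline. The helper's internals are not shown in this diff; the stand-in below only illustrates the kind-plus-fallback-factory calling convention and is an assumption, not the actual implementation.

```python
# Hedged stand-in for llama_stack.core.id_generation.generate_object_id: assume it
# falls back to the supplied factory when no custom ID scheme is configured for `kind`.
import uuid
from collections.abc import Callable


def generate_object_id(kind: str, fallback: Callable[[], str]) -> str:
    # Assumption: the factory mints the identifier unless a custom generator overrides it.
    return fallback()


file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
print(file_id)  # e.g. file-3f9c0b6c4e...
```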
@@ -29,9 +29,6 @@ class AnthropicInferenceAdapter(OpenAIMixin):
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }

def get_api_key(self) -> str:
return self.config.api_key or ""

def get_base_url(self):
return "https://api.anthropic.com/v1"

@@ -21,11 +21,6 @@ class AnthropicProviderDataValidator(BaseModel):

@json_schema_type
class AnthropicConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Anthropic models",
)

@classmethod
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
@@ -16,9 +16,6 @@ class AzureInferenceAdapter(OpenAIMixin):

provider_data_api_key_field: str = "azure_api_key"

def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()

def get_base_url(self) -> str:
"""
Get the Azure API base URL.

@@ -32,9 +32,6 @@ class AzureProviderDataValidator(BaseModel):

@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
api_key: SecretStr = Field(
description="Azure API key for Azure",
)
api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
)
@@ -6,21 +6,21 @@

import json
from collections.abc import AsyncIterator
from typing import Any

from botocore.client import BaseClient

from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client

@@ -125,66 +125,18 @@ class BedrockInferenceAdapter(

async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")

async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
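A recurring theme in the adapter hunks above and below is collapsing long keyword-argument lists into a single `params` request object such as `OpenAIChatCompletionRequestWithExtraBody`. The sketch below approximates that pattern with a stand-in pydantic model; the real request classes live in `llama_stack.apis.inference` and carry more fields than shown here, so treat the field set as an assumption.

```python
# Illustrative sketch of the "one params object" calling convention; unknown fields
# (e.g. nvext) are tolerated and surface via model_extra, approximating extra_body.
import asyncio

from pydantic import BaseModel, ConfigDict


class ChatCompletionRequestSketch(BaseModel):
    model_config = ConfigDict(extra="allow")  # provider-specific params ride along in model_extra

    model: str
    messages: list[dict[str, str]]
    stream: bool | None = None


async def openai_chat_completion(params: ChatCompletionRequestSketch) -> str:
    # Adapters now receive the whole request object and read fields off it,
    # e.g. params.model, params.messages, params.model_extra for extra body.
    return f"would call {params.model} with {len(params.messages)} messages, extra={params.model_extra}"


req = ChatCompletionRequestSketch(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "hello"}],
    nvext={"guided_json": {"type": "object"}},  # unknown field, captured as extra body
)
print(asyncio.run(openai_chat_completion(req)))
```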
@ -25,8 +25,9 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
|
|||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: SecretStr = Field(
|
||||
default=SecretStr(None), # type: ignore[arg-type]
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,11 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
from llama_stack.apis.inference import OpenAICompletion
|
||||
from llama_stack.apis.inference import OpenAICompletion, OpenAICompletionRequestWithExtraBody
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -29,9 +28,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
|
|||
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_token.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/serving-endpoints"
|
||||
|
||||
|
|
@ -45,25 +41,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
|
|||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
|
@ -18,10 +18,6 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
|
|||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The Fireworks.ai API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -23,8 +23,5 @@ class FireworksInferenceAdapter(OpenAIMixin):
|
|||
|
||||
provider_data_api_key_field: str = "fireworks_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else None # type: ignore[return-value]
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return "https://api.fireworks.ai/inference/v1"
|
||||
|
|
|
|||
|
|
@ -21,11 +21,6 @@ class GeminiProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class GeminiConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Gemini models",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
|
@ -14,11 +22,61 @@ class GeminiInferenceAdapter(OpenAIMixin):
|
|||
|
||||
provider_data_api_key_field: str = "gemini_api_key"
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
"models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
"models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048},
|
||||
}
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""
|
||||
Override embeddings method to handle Gemini's missing usage statistics.
|
||||
Gemini's embedding API doesn't return usage information, so we provide default values.
|
||||
"""
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"model": await self._get_provider_model_id(params.model),
|
||||
"input": params.input,
|
||||
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
|
||||
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
|
||||
"user": params.user if params.user is not None else NOT_GIVEN,
|
||||
}
|
||||
|
||||
# Add extra_body if present
|
||||
extra_body = params.model_extra
|
||||
if extra_body:
|
||||
request_params["extra_body"] = extra_body
|
||||
|
||||
# Call OpenAI embeddings API with properly typed parameters
|
||||
response = await self.client.embeddings.create(**request_params)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
# Gemini doesn't return usage statistics - use default values
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
else:
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=params.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,12 +21,6 @@ class GroqProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class GroqConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
# The Groq client library loads the GROQ_API_KEY environment variable by default
|
||||
default=None,
|
||||
description="The Groq API key",
|
||||
)
|
||||
|
||||
url: str = Field(
|
||||
default="https://api.groq.com",
|
||||
description="The URL for the Groq AI server",
|
||||
|
|
|
|||
|
|
@ -14,8 +14,5 @@ class GroqInferenceAdapter(OpenAIMixin):
|
|||
|
||||
provider_data_api_key_field: str = "groq_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/openai/v1"
|
||||
|
|
|
|||
|
|
@ -21,11 +21,6 @@ class LlamaProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class LlamaCompatConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Llama API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.llama.com/compat/v1/",
|
||||
description="The URL for the Llama API server",
|
||||
|
|
|
|||
|
|
@ -3,9 +3,13 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference.inference import OpenAICompletion, OpenAIEmbeddingsResponse
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
|
@ -21,9 +25,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
|
|||
Llama API Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
|
@ -34,35 +35,12 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
|
|||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ This provider enables running inference using NVIDIA NIM.
|
|||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
|
@ -45,7 +45,7 @@ The following example shows how to create a chat completion for an NVIDIA NIM.
|
|||
|
||||
```python
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
model="nvidia/meta/llama-3.1-8b-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -67,37 +67,40 @@ print(f"Response: {response.choices[0].message.content}")
|
|||
The following example shows how to do tool calling for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
|
||||
tool_definition = ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get current weather information for a location",
|
||||
parameters={
|
||||
"location": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The city and state, e.g. San Francisco, CA",
|
||||
required=True,
|
||||
),
|
||||
"unit": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="Temperature unit (celsius or fahrenheit)",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
tool_definition = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather information for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "Temperature unit (celsius or fahrenheit)",
|
||||
"default": "celsius",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
tool_response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
model="nvidia/meta/llama-3.1-8b-instruct",
|
||||
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
|
||||
tools=[tool_definition],
|
||||
)
|
||||
|
||||
print(f"Tool Response: {tool_response.choices[0].message.content}")
|
||||
print(f"Response content: {tool_response.choices[0].message.content}")
|
||||
if tool_response.choices[0].message.tool_calls:
|
||||
for tool_call in tool_response.choices[0].message.tool_calls:
|
||||
print(f"Tool Called: {tool_call.tool_name}")
|
||||
print(f"Arguments: {tool_call.arguments}")
|
||||
print(f"Tool Called: {tool_call.function.name}")
|
||||
print(f"Arguments: {tool_call.function.arguments}")
|
||||
```
|
||||
|
||||
### Structured Output Example
|
||||
|
|
@ -105,33 +108,26 @@ if tool_response.choices[0].message.tool_calls:
|
|||
The following example shows how to do structured output for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
|
||||
|
||||
person_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"age": {"type": "number"},
|
||||
"occupation": {"type": "string"},
|
||||
},
|
||||
"required": ["name", "age", "occupation"],
|
||||
}
|
||||
|
||||
response_format = JsonSchemaResponseFormat(
|
||||
type=ResponseFormatType.json_schema, json_schema=person_schema
|
||||
)
|
||||
|
||||
structured_response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
model="nvidia/meta/llama-3.1-8b-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
|
||||
}
|
||||
],
|
||||
response_format=response_format,
|
||||
extra_body={"nvext": {"guided_json": person_schema}},
|
||||
)
|
||||
|
||||
print(f"Structured Response: {structured_response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
|
|
@ -139,16 +135,13 @@ print(f"Structured Response: {structured_response.choices[0].message.content}")
|
|||
|
||||
The following example shows how to create embeddings for an NVIDIA NIM.
|
||||
|
||||
> [!NOTE]
|
||||
> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
|
||||
|
||||
```python
|
||||
response = client.inference.embeddings(
|
||||
model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
contents=["What is the capital of France?"],
|
||||
task_type="query",
|
||||
response = client.embeddings.create(
|
||||
model="nvidia/nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
input=["What is the capital of France?"],
|
||||
extra_body={"input_type": "query"},
|
||||
)
|
||||
print(f"Embeddings: {response.embeddings}")
|
||||
print(f"Embeddings: {response.data}")
|
||||
```
|
||||
|
||||
### Vision Language Models Example
|
||||
|
|
@ -166,15 +159,15 @@ image_path = {path_to_the_image}
|
|||
demo_image_b64 = load_image_as_base64(image_path)
|
||||
|
||||
vlm_response = client.chat.completions.create(
|
||||
model="nvidia/vila",
|
||||
model="nvidia/meta/llama-3.2-11b-vision-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": {
|
||||
"data": demo_image_b64,
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{demo_image_b64}",
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from .config import NVIDIAConfig
|
|||
|
||||
|
||||
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
|
||||
# import dynamically so `llama stack build` does not fail due to missing dependencies
|
||||
# import dynamically so `llama stack list-deps` does not fail due to missing dependencies
|
||||
from .nvidia import NVIDIAInferenceAdapter
|
||||
|
||||
if not isinstance(config, NVIDIAConfig):
|
||||
|
|
|
|||
|
|
@ -5,13 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -28,15 +21,6 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
|
||||
"""
|
||||
NVIDIA Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability(). It also
|
||||
must come before Inference to ensure that OpenAIMixin methods are available
|
||||
in the Inference interface.
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
|
||||
"""
|
||||
|
||||
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
|
|
@ -51,7 +35,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
|
||||
|
||||
if _is_nvidia_hosted(self.config):
|
||||
if not self.config.api_key:
|
||||
if not self.config.auth_credential:
|
||||
raise RuntimeError(
|
||||
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
|
||||
)
|
||||
|
|
@ -62,7 +46,13 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
|
||||
:return: The NVIDIA API key
|
||||
"""
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
|
||||
if self.config.auth_credential:
|
||||
return self.config.auth_credential.get_secret_value()
|
||||
|
||||
if not _is_nvidia_hosted(self.config):
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
return None
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
@ -71,54 +61,3 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""
|
||||
OpenAI-compatible embeddings for NVIDIA NIM.
|
||||
|
||||
Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
|
||||
We default this to "query" to ensure requests succeed when using the
|
||||
OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
|
||||
`task_type='document'`.
|
||||
"""
|
||||
extra_body: dict[str, object] = {"input_type": "query"}
|
||||
logger.warning(
|
||||
"NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
|
||||
"For passage embeddings, use the embeddings API with task_type='document'."
|
||||
)
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=await self._get_provider_model_id(model),
|
||||
input=input,
|
||||
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
|
||||
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
|
||||
user=user if user is not None else NOT_GIVEN,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=response.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,12 +6,16 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaImplConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class OllamaInferenceAdapter(OpenAIMixin):
|
|||
return self._clients[loop]
|
||||
|
||||
def get_api_key(self):
|
||||
return "NO_KEY"
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.config.url.rstrip("/") + "/v1"
|
||||
|
|
|
|||
|
|
@ -21,10 +21,6 @@ class OpenAIProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://api.openai.com/v1",
|
||||
description="Base URL for OpenAI API",
|
||||
|
|
|
|||
|
|
@ -29,9 +29,6 @@ class OpenAIInferenceAdapter(OpenAIMixin):
|
|||
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
|
||||
}
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the OpenAI API base URL.
|
||||
|
|
|
|||
|
|
@ -13,15 +13,15 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
||||
from .config import PassthroughImplConfig
|
||||
|
||||
|
|
@ -70,120 +70,37 @@ class PassthroughInferenceAdapter(Inference):
|
|||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
client = self._get_client()
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
params = params.model_copy()
|
||||
params.model = model_obj.provider_resource_id
|
||||
|
||||
return await client.inference.openai_completion(**params)
|
||||
request_params = params.model_dump(exclude_none=True)
|
||||
|
||||
return await client.inference.openai_completion(**request_params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
client = self._get_client()
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
params = params.model_copy()
|
||||
params.model = model_obj.provider_resource_id
|
||||
|
||||
return await client.inference.openai_chat_completion(**params)
|
||||
request_params = params.model_dump(exclude_none=True)
|
||||
|
||||
return await client.inference.openai_chat_completion(**request_params)
|
||||
|
||||
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
|
||||
json_params = {}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
|
@ -25,8 +25,9 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
|
|||
default=None,
|
||||
description="The URL for the Runpod model serving endpoint",
|
||||
)
|
||||
api_token: str | None = Field(
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -25,66 +26,18 @@ class RunpodInferenceAdapter(OpenAIMixin):
|
|||
config: RunpodImplConfig
|
||||
provider_data_api_key_field: str = "runpod_api_token"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""Get API key for OpenAI client."""
|
||||
return self.config.api_token
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get base URL for OpenAI client."""
|
||||
return self.config.url
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
):
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Override to add RunPod-specific stream_options requirement."""
|
||||
if stream and not stream_options:
|
||||
stream_options = {"include_usage": True}
|
||||
params = params.model_copy()
|
||||
|
||||
return await super().openai_chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.stream and not params.stream_options:
|
||||
params.stream_options = {"include_usage": True}
|
||||
|
||||
return await super().openai_chat_completion(params)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
|
@ -25,10 +25,6 @@ class SambaNovaImplConfig(RemoteInferenceProviderConfig):
|
|||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The SambaNova cloud API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -19,9 +19,6 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
|
|||
SambaNova Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
@json_schema_type
|
||||
class TGIImplConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
url: str = Field(
|
||||
description="The URL for the TGI serving endpoint",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,10 @@ from collections.abc import Iterable
|
|||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from pydantic import SecretStr
|
||||
|
||||
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -30,7 +33,7 @@ class _HfAdapter(OpenAIMixin):
|
|||
overwrite_completion_id = True # TGI always returns id=""
|
||||
|
||||
def get_api_key(self):
|
||||
return self.api_key.get_secret_value()
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.url
|
||||
|
|
@ -40,11 +43,7 @@ class _HfAdapter(OpenAIMixin):
|
|||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
|
@ -18,10 +18,6 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
|
|||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The Together AI API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from together import AsyncTogether
|
|||
from together.constants import BASE_URL
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
|
||||
|
|
@@ -39,15 +40,12 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):

    provider_data_api_key_field: str = "together_api_key"

    def get_api_key(self):
        return self.config.api_key.get_secret_value() if self.config.api_key else None

    def get_base_url(self):
        return BASE_URL

    def _get_client(self) -> AsyncTogether:
        together_api_key = None
        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
        config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
        if config_api_key:
            together_api_key = config_api_key
        else:

@@ -65,11 +63,7 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
        params: OpenAIEmbeddingsRequestWithExtraBody,
    ) -> OpenAIEmbeddingsResponse:
        """
        Together's OpenAI-compatible embeddings endpoint is not compatible with

@@ -81,25 +75,27 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
        - does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
        """
        # Together support ticket #13332 -> will not fix
        if user is not None:
        if params.user is not None:
            raise ValueError("Together's embeddings endpoint does not support user param.")
        # Together support ticket #13333 -> escalated
        if dimensions is not None:
        if params.dimensions is not None:
            raise ValueError("Together's embeddings endpoint does not support dimensions param.")

        response = await self.client.embeddings.create(
            model=await self._get_provider_model_id(model),
            input=input,
            encoding_format=encoding_format,
            model=await self._get_provider_model_id(params.model),
            input=params.input,
            encoding_format=params.encoding_format,
        )

        response.model = model  # return the user the same model id they provided, avoid exposing the provider model id
        response.model = (
            params.model
        )  # return the user the same model id they provided, avoid exposing the provider model id

        # Together support ticket #13330 -> escalated
        # - togethercomputer/m2-bert-80M-32k-retrieval *does not* return usage information
        if not hasattr(response, "usage") or response.usage is None:
            logger.warning(
                f"Together's embedding endpoint for {model} did not return usage information, substituting -1s."
                f"Together's embedding endpoint for {params.model} did not return usage information, substituting -1s."
            )
            response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
@@ -6,7 +6,7 @@

from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type

@@ -25,6 +25,8 @@ class VertexAIProviderDataValidator(BaseModel):

@json_schema_type
class VertexAIConfig(RemoteInferenceProviderConfig):
    auth_credential: SecretStr | None = Field(default=None, exclude=True)

    project: str = Field(
        description="Google Cloud project ID for Vertex AI",
    )
@@ -6,7 +6,7 @@

from pathlib import Path

from pydantic import Field, field_validator
from pydantic import Field, SecretStr, field_validator

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type

@@ -22,8 +22,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
        default=4096,
        description="Maximum number of tokens to generate.",
    )
    api_token: str | None = Field(
        default="fake",
    auth_credential: SecretStr | None = Field(
        default=None,
        alias="api_token",
        description="The API token",
    )
    tls_verify: bool | str = Field(
@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any
from urllib.parse import urljoin

import httpx

@@ -15,8 +14,7 @@ from pydantic import ConfigDict

from llama_stack.apis.inference import (
    OpenAIChatCompletion,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    OpenAIChatCompletionRequestWithExtraBody,
    ToolChoice,
)
from llama_stack.log import get_logger

@@ -38,8 +36,10 @@ class VLLMInferenceAdapter(OpenAIMixin):

    provider_data_api_key_field: str = "vllm_api_token"

    def get_api_key(self) -> str:
        return self.config.api_token or ""
    def get_api_key(self) -> str | None:
        if self.config.auth_credential:
            return self.config.auth_credential.get_secret_value()
        return "NO KEY REQUIRED"

    def get_base_url(self) -> str:
        """Get the base URL from config."""

@@ -77,63 +77,35 @@ class VLLMInferenceAdapter(OpenAIMixin):
    def get_extra_client_params(self):
        return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}

    async def check_model_availability(self, model: str) -> bool:
        """
        Skip the check when running without authentication.
        """
        if not self.config.auth_credential:
            model_ids = []
            async for m in self.client.models.list():
                if m.id == model:  # Found exact match
                    return True
                model_ids.append(m.id)
            raise ValueError(f"Model '{model}' not found. Available models: {model_ids}")
        log.warning(f"Not checking model availability for {model} as API token may trigger OAuth workflow")
        return True

    async def openai_chat_completion(
        self,
        model: str,
        messages: list[OpenAIMessageParam],
        frequency_penalty: float | None = None,
        function_call: str | dict[str, Any] | None = None,
        functions: list[dict[str, Any]] | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        parallel_tool_calls: bool | None = None,
        presence_penalty: float | None = None,
        response_format: OpenAIResponseFormatParam | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        tools: list[dict[str, Any]] | None = None,
        top_logprobs: int | None = None,
        top_p: float | None = None,
        user: str | None = None,
        params: OpenAIChatCompletionRequestWithExtraBody,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        max_tokens = max_tokens or self.config.max_tokens
        params = params.model_copy()

        # Apply vLLM-specific defaults
        if params.max_tokens is None and self.config.max_tokens:
            params.max_tokens = self.config.max_tokens

        # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
        # References:
        # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
        # * https://github.com/vllm-project/vllm/pull/10000
        if not tools and tool_choice is not None:
            tool_choice = ToolChoice.none.value
        if not params.tools and params.tool_choice is not None:
            params.tool_choice = ToolChoice.none.value

        return await super().openai_chat_completion(
            model=model,
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
        )
        return await super().openai_chat_completion(params)
@@ -7,18 +7,18 @@
import os
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, SecretStr
from pydantic import BaseModel, Field

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type


class WatsonXProviderDataValidator(BaseModel):
    model_config = ConfigDict(
        from_attributes=True,
        extra="forbid",
    watsonx_project_id: str | None = Field(
        default=None,
        description="IBM WatsonX project ID",
    )
    watsonx_api_key: str | None
    watsonx_api_key: str | None = None


@json_schema_type

@@ -27,14 +27,6 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
        default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
        description="A base url for accessing the watsonx.ai",
    )
    # This seems like it should be required, but none of the other remote inference
    # providers require it, so this is optional here too for consistency.
    # The OpenAIConfig uses default=None instead, so this is following that precedent.
    api_key: SecretStr | None = Field(
        default=None,
        description="The watsonx.ai API key",
    )
    # As above, this is optional here too for consistency.
    project_id: str | None = Field(
        default=None,
        description="The watsonx.ai project ID",
@@ -4,42 +4,259 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator
from typing import Any

import litellm
import requests

from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIChatCompletionUsage,
    OpenAICompletion,
    OpenAICompletionRequestWithExtraBody,
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.core.telemetry.tracing import get_current_span
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params

logger = get_logger(name=__name__, category="providers::remote::watsonx")


class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
    _model_cache: dict[str, Model] = {}

    provider_data_api_key_field: str = "watsonx_api_key"

    def __init__(self, config: WatsonXConfig):
        self.available_models = None
        self.config = config
        api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
        LiteLLMOpenAIMixin.__init__(
            self,
            litellm_provider_name="watsonx",
            api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
            api_key_from_config=api_key,
            provider_data_api_key_field="watsonx_api_key",
            openai_compat_api_base=self.get_base_url(),
        )

    async def openai_chat_completion(
        self,
        params: OpenAIChatCompletionRequestWithExtraBody,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        """
        Override parent method to add timeout and inject usage object when missing.
        This works around a LiteLLM defect where usage block is sometimes dropped.
        """

        # Add usage tracking for streaming when telemetry is active
        stream_options = params.stream_options
        if params.stream and get_current_span() is not None:
            if stream_options is None:
                stream_options = {"include_usage": True}
            elif "include_usage" not in stream_options:
                stream_options = {**stream_options, "include_usage": True}

        model_obj = await self.model_store.get_model(params.model)

        request_params = await prepare_openai_completion_params(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            messages=params.messages,
            frequency_penalty=params.frequency_penalty,
            function_call=params.function_call,
            functions=params.functions,
            logit_bias=params.logit_bias,
            logprobs=params.logprobs,
            max_completion_tokens=params.max_completion_tokens,
            max_tokens=params.max_tokens,
            n=params.n,
            parallel_tool_calls=params.parallel_tool_calls,
            presence_penalty=params.presence_penalty,
            response_format=params.response_format,
            seed=params.seed,
            stop=params.stop,
            stream=params.stream,
            stream_options=stream_options,
            temperature=params.temperature,
            tool_choice=params.tool_choice,
            tools=params.tools,
            top_logprobs=params.top_logprobs,
            top_p=params.top_p,
            user=params.user,
            api_key=self.get_api_key(),
            api_base=self.api_base,
            # These are watsonx-specific parameters
            timeout=self.config.timeout,
            project_id=self.config.project_id,
        )

        result = await litellm.acompletion(**request_params)

        # If not streaming, check and inject usage if missing
        if not params.stream:
            # Use getattr to safely handle cases where usage attribute might not exist
            if getattr(result, "usage", None) is None:
                # Create usage object with zeros
                usage_obj = OpenAIChatCompletionUsage(
                    prompt_tokens=0,
                    completion_tokens=0,
                    total_tokens=0,
                )
                # Use model_copy to create a new response with the usage injected
                result = result.model_copy(update={"usage": usage_obj})
            return result

        # For streaming, wrap the iterator to normalize chunks
        return self._normalize_stream(result)

    def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk:
        """
        Normalize a chunk to ensure it has all expected attributes.
        This works around LiteLLM not always including all expected attributes.
        """
        # Ensure chunk has usage attribute with zeros if missing
        if not hasattr(chunk, "usage") or chunk.usage is None:
            usage_obj = OpenAIChatCompletionUsage(
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
            )
            chunk = chunk.model_copy(update={"usage": usage_obj})

        # Ensure all delta objects in choices have expected attributes
        if hasattr(chunk, "choices") and chunk.choices:
            normalized_choices = []
            for choice in chunk.choices:
                if hasattr(choice, "delta") and choice.delta:
                    delta = choice.delta
                    # Build update dict for missing attributes
                    delta_updates = {}
                    if not hasattr(delta, "refusal"):
                        delta_updates["refusal"] = None
                    if not hasattr(delta, "reasoning_content"):
                        delta_updates["reasoning_content"] = None

                    # If we need to update delta, create a new choice with updated delta
                    if delta_updates:
                        new_delta = delta.model_copy(update=delta_updates)
                        new_choice = choice.model_copy(update={"delta": new_delta})
                        normalized_choices.append(new_choice)
                    else:
                        normalized_choices.append(choice)
                else:
                    normalized_choices.append(choice)

            # If we modified any choices, create a new chunk with updated choices
            if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))):
                chunk = chunk.model_copy(update={"choices": normalized_choices})

        return chunk

    async def _normalize_stream(
        self, stream: AsyncIterator[OpenAIChatCompletionChunk]
    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
        """
        Normalize all chunks in the stream to ensure they have expected attributes.
        This works around LiteLLM sometimes not including expected attributes.
        """
        try:
            async for chunk in stream:
                # Normalize and yield each chunk immediately
                yield self._normalize_chunk(chunk)
        except Exception as e:
            logger.error(f"Error normalizing stream: {e}", exc_info=True)
            raise

    async def openai_completion(
        self,
        params: OpenAICompletionRequestWithExtraBody,
    ) -> OpenAICompletion:
        """
        Override parent method to add watsonx-specific parameters.
        """
        from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params

        model_obj = await self.model_store.get_model(params.model)

        request_params = await prepare_openai_completion_params(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            prompt=params.prompt,
            best_of=params.best_of,
            echo=params.echo,
            frequency_penalty=params.frequency_penalty,
            logit_bias=params.logit_bias,
            logprobs=params.logprobs,
            max_tokens=params.max_tokens,
            n=params.n,
            presence_penalty=params.presence_penalty,
            seed=params.seed,
            stop=params.stop,
            stream=params.stream,
            stream_options=params.stream_options,
            temperature=params.temperature,
            top_p=params.top_p,
            user=params.user,
            suffix=params.suffix,
            api_key=self.get_api_key(),
            api_base=self.api_base,
            # These are watsonx-specific parameters
            timeout=self.config.timeout,
            project_id=self.config.project_id,
        )
        return await litellm.atext_completion(**request_params)

    async def openai_embeddings(
        self,
        params: OpenAIEmbeddingsRequestWithExtraBody,
    ) -> OpenAIEmbeddingsResponse:
        """
        Override parent method to add watsonx-specific parameters.
        """
        model_obj = await self.model_store.get_model(params.model)

        # Convert input to list if it's a string
        input_list = [params.input] if isinstance(params.input, str) else params.input

        # Call litellm embedding function with watsonx-specific parameters
        response = litellm.embedding(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            input=input_list,
            api_key=self.get_api_key(),
            api_base=self.api_base,
            dimensions=params.dimensions,
            # These are watsonx-specific parameters
            timeout=self.config.timeout,
            project_id=self.config.project_id,
        )

        # Convert response to OpenAI format
        from llama_stack.apis.inference import OpenAIEmbeddingUsage
        from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response

        data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)

        usage = OpenAIEmbeddingUsage(
            prompt_tokens=response["usage"]["prompt_tokens"],
            total_tokens=response["usage"]["total_tokens"],
        )

        return OpenAIEmbeddingsResponse(
            data=data,
            model=model_obj.provider_resource_id,
            usage=usage,
        )
        self.available_models = None
        self.config = config

    def get_base_url(self) -> str:
        return self.config.url

    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
        # Get base parameters from parent
        params = await super()._get_params(request)

        # Add watsonx.ai specific parameters
        params["project_id"] = self.config.project_id
        params["time_limit"] = self.config.timeout
        return params

    # Copied from OpenAIMixin
    async def check_model_availability(self, model: str) -> bool:
        """
@@ -22,7 +22,7 @@ This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service
Build the NVIDIA environment:

```bash
llama stack build --distro nvidia --image-type venv
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```

### Basic Usage using the LlamaStack Python Client
@@ -7,7 +7,7 @@
import json
from typing import Any

from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
    RunShieldResponse,
    Safety,

@@ -56,7 +56,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
        pass

    async def run_shield(
        self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] = None
    ) -> RunShieldResponse:
        shield = await self.shield_store.get_shield(shield_id)
        if not shield:
@@ -19,7 +19,7 @@ This provider enables safety checks and guardrails for LLM interactions using NV
Build the NVIDIA environment:

```bash
llama stack build --distro nvidia --image-type venv
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```

### Basic Usage using the LlamaStack Python Client
@@ -8,12 +8,11 @@ from typing import Any

import requests

from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new

from .config import NVIDIASafetyConfig

@@ -44,7 +43,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
        pass

    async def run_shield(
        self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
    ) -> RunShieldResponse:
        """
        Run a safety shield check against the provided messages.

@@ -67,7 +66,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
        self.shield = NeMoGuardrails(self.config, shield.shield_id)
        return await self.shield.run(messages)

    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
        raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")


@@ -118,7 +117,7 @@ class NeMoGuardrails:
        response.raise_for_status()
        return response.json()

    async def run(self, messages: list[Message]) -> RunShieldResponse:
    async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
        """
        Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.

@@ -132,10 +131,9 @@ class NeMoGuardrails:
        Raises:
            requests.HTTPError: If the POST request fails.
        """
        request_messages = [await convert_message_to_openai_dict_new(message) for message in messages]
        request_data = {
            "model": self.model,
            "messages": request_messages,
            "messages": [{"role": message.role, "content": message.content} for message in messages],
            "temperature": self.temperature,
            "top_p": 1,
            "frequency_penalty": 0,
@@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Any

import litellm
import requests

from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
    RunShieldResponse,
    Safety,

@@ -21,7 +20,6 @@ from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new

from .config import SambaNovaSafetyConfig


@@ -72,7 +70,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
        pass

    async def run_shield(
        self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
    ) -> RunShieldResponse:
        shield = await self.shield_store.get_shield(shield_id)
        if not shield:

@@ -80,12 +78,8 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide

        shield_params = shield.params
        logger.debug(f"run_shield::{shield_params}::messages={messages}")
        content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
        logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")

        response = litellm.completion(
            model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
        )
        response = litellm.completion(model=shield.provider_resource_id, messages=messages, api_key=self._get_api_key())
        shield_message = response.choices[0].message.content

        if "unsafe" in shield_message.lower():
@@ -12,24 +12,16 @@ import chromadb
from numpy.typing import NDArray

from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    VectorIO,
)
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
    ChunkForDeletion,
    EmbeddingIndex,
    VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex

from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig


@@ -38,7 +30,7 @@ log = get_logger(name=__name__, category="vector_io::chroma")
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI

VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:chroma:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:chroma:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"

@@ -68,19 +60,13 @@ class ChromaIndex(EmbeddingIndex):

        ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
        await maybe_await(
            self.collection.add(
                documents=[chunk.model_dump_json() for chunk in chunks],
                embeddings=embeddings,
                ids=ids,
            )
            self.collection.add(documents=[chunk.model_dump_json() for chunk in chunks], embeddings=embeddings, ids=ids)
        )

    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        results = await maybe_await(
            self.collection.query(
                query_embeddings=[embedding.tolist()],
                n_results=k,
                include=["documents", "distances"],
                query_embeddings=[embedding.tolist()], n_results=k, include=["documents", "distances"]
            )
        )
        distances = results["distances"][0]

@@ -108,12 +94,7 @@ class ChromaIndex(EmbeddingIndex):
    async def delete(self):
        await maybe_await(self.client.delete_collection(self.collection.name))

    async def query_keyword(
        self,
        query_string: str,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
        raise NotImplementedError("Keyword search is not supported in Chroma")

    async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:

@@ -133,11 +114,11 @@ class ChromaIndex(EmbeddingIndex):
        raise NotImplementedError("Hybrid search is not supported in Chroma")


class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
    def __init__(
        self,
        config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
        inference_api: Api.inference,
        inference_api: Inference,
        files_api: Files | None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)

@@ -146,11 +127,11 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        self.inference_api = inference_api
        self.client = None
        self.cache = {}
        self.vector_db_store = None
        self.vector_store_table = None

    async def initialize(self) -> None:
        self.kvstore = await kvstore_impl(self.config.kvstore)
        self.vector_db_store = self.kvstore
        self.kvstore = await kvstore_impl(self.config.persistence)
        self.vector_store_table = self.kvstore

        if isinstance(self.config, RemoteChromaVectorIOConfig):
            log.info(f"Connecting to Chroma server at: {self.config.url}")

@@ -170,70 +151,58 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        # Clean up mixin resources (file batch tasks)
        await super().shutdown()

    async def register_vector_db(
        self,
        vector_db: VectorDB,
    ) -> None:
    async def register_vector_store(self, vector_store: VectorStore) -> None:
        collection = await maybe_await(
            self.client.get_or_create_collection(
                name=vector_db.identifier,
                metadata={"vector_db": vector_db.model_dump_json()},
                name=vector_store.identifier, metadata={"vector_store": vector_store.model_dump_json()}
            )
        )
        self.cache[vector_db.identifier] = VectorDBWithIndex(
            vector_db, ChromaIndex(self.client, collection), self.inference_api
        self.cache[vector_store.identifier] = VectorStoreWithIndex(
            vector_store, ChromaIndex(self.client, collection), self.inference_api
        )

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        if vector_db_id not in self.cache:
            log.warning(f"Vector DB {vector_db_id} not found")
    async def unregister_vector_store(self, vector_store_id: str) -> None:
        if vector_store_id not in self.cache:
            log.warning(f"Vector DB {vector_store_id} not found")
            return

        await self.cache[vector_db_id].index.delete()
        del self.cache[vector_db_id]
        await self.cache[vector_store_id].index.delete()
        del self.cache[vector_store_id]

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
        index = await self._get_and_cache_vector_store_index(vector_db_id)
        if index is None:
            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")

        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
    ) -> QueryChunksResponse:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        index = await self._get_and_cache_vector_store_index(vector_db_id)

        if index is None:
            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")

        return await index.query_chunks(query, params)

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]
    async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
        if vector_store_id in self.cache:
            return self.cache[vector_store_id]

        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
        if not vector_db:
            raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack")
        collection = await maybe_await(self.client.get_collection(vector_db_id))
        vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
        if not vector_store:
            raise ValueError(f"Vector DB {vector_store_id} not found in Llama Stack")
        collection = await maybe_await(self.client.get_collection(vector_store_id))
        if not collection:
            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
        index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
        self.cache[vector_db_id] = index
            raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
        index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
        self.cache[vector_store_id] = index
        return index

    async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Delete chunks from a Chroma vector store."""
        index = await self._get_and_cache_vector_db_index(store_id)
        index = await self._get_and_cache_vector_store_index(store_id)
        if not index:
            raise ValueError(f"Vector DB {store_id} not found")
@@ -8,21 +8,21 @@ from typing import Any

from pydantic import BaseModel, Field

from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class ChromaVectorIOConfig(BaseModel):
    url: str | None
    kvstore: KVStoreConfig = Field(description="Config for KV store backend")
    persistence: KVStoreReference = Field(description="Config for KV store backend")

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> dict[str, Any]:
        return {
            "url": url,
            "kvstore": SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="chroma_remote_registry.db",
            ),
            "persistence": KVStoreReference(
                backend="kv_default",
                namespace="vector_io::chroma_remote",
            ).model_dump(exclude_none=True),
        }
@@ -13,7 +13,6 @@ async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, Provide
    from .milvus import MilvusVectorIOAdapter

    assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"

    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
    await impl.initialize()
    return impl
@@ -8,7 +8,7 @@ from typing import Any

from pydantic import BaseModel, ConfigDict, Field

from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type


@@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
    uri: str = Field(description="The URI of the Milvus server")
    token: str | None = Field(description="The token of the Milvus server")
    consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
    kvstore: KVStoreConfig = Field(description="Config for KV store backend")
    persistence: KVStoreReference = Field(description="Config for KV store backend")

    # This configuration allows additional fields to be passed through to the underlying Milvus client.
    # See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.

@@ -28,8 +28,8 @@ class MilvusVectorIOConfig(BaseModel):
        return {
            "uri": "${env.MILVUS_ENDPOINT}",
            "token": "${env.MILVUS_TOKEN}",
            "kvstore": SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="milvus_remote_registry.db",
            ),
            "persistence": KVStoreReference(
                backend="kv_default",
                namespace="vector_io::milvus_remote",
            ).model_dump(exclude_none=True),
        }
@@ -12,16 +12,12 @@ from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker

from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    VectorIO,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore

@@ -30,7 +26,7 @@ from llama_stack.providers.utils.memory.vector_store import (
    RERANKER_TYPE_WEIGHTED,
    ChunkForDeletion,
    EmbeddingIndex,
    VectorDBWithIndex,
    VectorStoreWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name

@@ -39,7 +35,7 @@ from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = get_logger(name=__name__, category="vector_io::milvus")

VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:milvus:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:milvus:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:milvus:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:milvus:{VERSION}::"

@@ -73,46 +69,23 @@ class MilvusIndex(EmbeddingIndex):
            logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
            # Create schema for vector search
            schema = self.client.create_schema()
            schema.add_field(
                field_name="chunk_id",
                datatype=DataType.VARCHAR,
                is_primary=True,
                max_length=100,
            )
            schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
            schema.add_field(
                field_name="content",
                datatype=DataType.VARCHAR,
                max_length=65535,
                enable_analyzer=True,  # Enable text analysis for BM25
            )
            schema.add_field(
                field_name="vector",
                datatype=DataType.FLOAT_VECTOR,
                dim=len(embeddings[0]),
            )
            schema.add_field(
                field_name="chunk_content",
                datatype=DataType.JSON,
            )
            schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=len(embeddings[0]))
            schema.add_field(field_name="chunk_content", datatype=DataType.JSON)
            # Add sparse vector field for BM25 (required by the function)
            schema.add_field(
                field_name="sparse",
                datatype=DataType.SPARSE_FLOAT_VECTOR,
            )
            schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)

            # Create indexes
            index_params = self.client.prepare_index_params()
            index_params.add_index(
                field_name="vector",
                index_type="FLAT",
                metric_type="COSINE",
            )
            index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
            # Add index for sparse field (required by BM25 function)
            index_params.add_index(
                field_name="sparse",
                index_type="SPARSE_INVERTED_INDEX",
                metric_type="BM25",
            )
            index_params.add_index(field_name="sparse", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25")

            # Add BM25 function for full-text search
            bm25_function = Function(

@@ -143,11 +116,7 @@ class MilvusIndex(EmbeddingIndex):
            }
        )
        try:
            await asyncio.to_thread(
                self.client.insert,
                self.collection_name,
                data=data,
            )
            await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
        except Exception as e:
            logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
            raise e

@@ -166,12 +135,7 @@ class MilvusIndex(EmbeddingIndex):
        scores = [res["distance"] for res in search_res[0]]
        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def query_keyword(
        self,
        query_string: str,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
        """
        Perform BM25-based keyword search using Milvus's built-in full-text search.
        """

@@ -209,12 +173,7 @@ class MilvusIndex(EmbeddingIndex):
            # Fallback to simple text search
            return await self._fallback_keyword_search(query_string, k, score_threshold)

    async def _fallback_keyword_search(
        self,
        query_string: str,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
    async def _fallback_keyword_search(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
        """
        Fallback to simple text search when BM25 search is not available.
        """

@@ -302,7 +261,7 @@ class MilvusIndex(EmbeddingIndex):
            raise


class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
    def __init__(
        self,
        config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,

@@ -314,28 +273,28 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        self.cache = {}
        self.client = None
        self.inference_api = inference_api
        self.vector_db_store = None
        self.vector_store_table = None
        self.metadata_collection_name = "openai_vector_stores_metadata"

    async def initialize(self) -> None:
        self.kvstore = await kvstore_impl(self.config.kvstore)
        self.kvstore = await kvstore_impl(self.config.persistence)
        start_key = VECTOR_DBS_PREFIX
        end_key = f"{VECTOR_DBS_PREFIX}\xff"
        stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
        stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)

        for vector_db_data in stored_vector_dbs:
            vector_db = VectorDB.model_validate_json(vector_db_data)
            index = VectorDBWithIndex(
                vector_db,
        for vector_store_data in stored_vector_stores:
            vector_store = VectorStore.model_validate_json(vector_store_data)
            index = VectorStoreWithIndex(
                vector_store,
                index=MilvusIndex(
                    client=self.client,
                    collection_name=vector_db.identifier,
                    collection_name=vector_store.identifier,
                    consistency_level=self.config.consistency_level,
                    kvstore=self.kvstore,
                ),
                inference_api=self.inference_api,
            )
            self.cache[vector_db.identifier] = index
            self.cache[vector_store.identifier] = index
        if isinstance(self.config, RemoteMilvusVectorIOConfig):
            logger.info(f"Connecting to Milvus server at {self.config.uri}")
            self.client = MilvusClient(**self.config.model_dump(exclude_none=True))

@@ -352,72 +311,61 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        # Clean up mixin resources (file batch tasks)
        await super().shutdown()

    async def register_vector_db(
        self,
        vector_db: VectorDB,
    ) -> None:
    async def register_vector_store(self, vector_store: VectorStore) -> None:
        if isinstance(self.config, RemoteMilvusVectorIOConfig):
            consistency_level = self.config.consistency_level
        else:
            consistency_level = "Strong"
        index = VectorDBWithIndex(
            vector_db=vector_db,
            index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
        index = VectorStoreWithIndex(
            vector_store=vector_store,
            index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
            inference_api=self.inference_api,
        )

        self.cache[vector_db.identifier] = index
        self.cache[vector_store.identifier] = index

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]
    async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
        if vector_store_id in self.cache:
            return self.cache[vector_store_id]

        if self.vector_db_store is None:
            raise VectorStoreNotFoundError(vector_db_id)
        if self.vector_store_table is None:
            raise VectorStoreNotFoundError(vector_store_id)

        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
        if not vector_db:
            raise VectorStoreNotFoundError(vector_db_id)
        vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
        if not vector_store:
            raise VectorStoreNotFoundError(vector_store_id)

        index = VectorDBWithIndex(
            vector_db=vector_db,
            index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
        index = VectorStoreWithIndex(
            vector_store=vector_store,
            index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
            inference_api=self.inference_api,
        )
        self.cache[vector_db_id] = index
        self.cache[vector_store_id] = index
        return index

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        if vector_db_id in self.cache:
            await self.cache[vector_db_id].index.delete()
            del self.cache[vector_db_id]
    async def unregister_vector_store(self, vector_store_id: str) -> None:
        if vector_store_id in self.cache:
            await self.cache[vector_store_id].index.delete()
            del self.cache[vector_store_id]

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
        index = await self._get_and_cache_vector_store_index(vector_db_id)
        if not index:
            raise VectorStoreNotFoundError(vector_db_id)

        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
    ) -> QueryChunksResponse:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        index = await self._get_and_cache_vector_store_index(vector_db_id)
        if not index:
            raise VectorStoreNotFoundError(vector_db_id)
        return await index.query_chunks(query, params)

    async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Delete a chunk from a milvus vector store."""
        index = await self._get_and_cache_vector_store_index(store_id)
        if not index:
            raise VectorStoreNotFoundError(store_id)
Some files were not shown because too many files have changed in this diff.