Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)

Merge remote-tracking branch 'upstream/main' into feat/gunicorn-production-server

Commit b060f73e6d: 70 changed files with 46290 additions and 1133 deletions
@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

+from collections.abc import Sequence
from typing import Annotated, Any, Literal

from pydantic import BaseModel, Field, model_validator

@@ -202,7 +203,7 @@ class OpenAIResponseMessage(BaseModel):
scenarios.
"""

-content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
+content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
type: Literal["message"] = "message"

@@ -254,10 +255,10 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
"""

id: str
-queries: list[str]
+queries: Sequence[str]
status: str
type: Literal["file_search_call"] = "file_search_call"
-results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
+results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None

@json_schema_type

@@ -597,7 +598,7 @@ class OpenAIResponseObject(BaseModel):
id: str
model: str
object: Literal["response"] = "response"
-output: list[OpenAIResponseOutput]
+output: Sequence[OpenAIResponseOutput]
parallel_tool_calls: bool = False
previous_response_id: str | None = None
prompt: OpenAIResponsePrompt | None = None

@@ -607,7 +608,7 @@ class OpenAIResponseObject(BaseModel):
# before the field was added. New responses will have this set always.
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
top_p: float | None = None
-tools: list[OpenAIResponseTool] | None = None
+tools: Sequence[OpenAIResponseTool] | None = None
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
instructions: str | None = None

@@ -1315,7 +1316,7 @@ class ListOpenAIResponseInputItem(BaseModel):
:param object: Object type identifier, always "list"
"""

-data: list[OpenAIResponseInput]
+data: Sequence[OpenAIResponseInput]
object: Literal["list"] = "list"

@@ -1326,7 +1327,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
:param input: List of input items that led to this response
"""

-input: list[OpenAIResponseInput]
+input: Sequence[OpenAIResponseInput]

def to_response_object(self) -> OpenAIResponseObject:
"""Convert to OpenAIResponseObject by excluding input field."""

@@ -1344,7 +1345,7 @@ class ListOpenAIResponseObject(BaseModel):
:param object: Object type identifier, always "list"
"""

-data: list[OpenAIResponseObjectWithInput]
+data: Sequence[OpenAIResponseObjectWithInput]
has_more: bool
first_id: str
last_id: str
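Note (not part of the diff): the list -> Sequence change above matters mainly for static type checking. Sequence from collections.abc is covariant and read-only, so callers can pass tuples or lists of narrower element types without mypy complaints, while Pydantic still validates the elements. A minimal sketch with a made-up model:

    from collections.abc import Sequence
    from typing import Literal

    from pydantic import BaseModel


    class ExampleMessage(BaseModel):
        # mirrors the pattern used above, with plain str elements for brevity
        content: str | Sequence[str]
        role: Literal["user"] = "user"


    # a tuple is accepted; with a list[str] annotation mypy would reject it
    msg = ExampleMessage(content=("part one", "part two"))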
@@ -8,7 +8,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-import uuid
from typing import Annotated, Any, Literal, Protocol, runtime_checkable

from fastapi import Body

@@ -18,7 +17,6 @@ from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
-from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.strong_typing.schema import register_schema

@@ -61,38 +59,19 @@ class Chunk(BaseModel):
"""
A chunk of content that can be inserted into a vector database.
:param content: The content of the chunk, which can be interleaved text, images, or other types.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
+:param chunk_id: Unique identifier for the chunk. Must be provided explicitly.
:param metadata: Metadata associated with the chunk that will be used in the model context during inference.
-:param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference.
The `chunk_metadata` is required backend functionality.
"""

content: InterleavedContent
+chunk_id: str
metadata: dict[str, Any] = Field(default_factory=dict)
embedding: list[float] | None = None
-# The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id"
-stored_chunk_id: str | None = Field(default=None, alias="chunk_id")
chunk_metadata: ChunkMetadata | None = None

-model_config = {"populate_by_name": True}

-def model_post_init(self, __context):
-# Extract chunk_id from metadata if present
-if self.metadata and "chunk_id" in self.metadata:
-self.stored_chunk_id = self.metadata.pop("chunk_id")

-@property
-def chunk_id(self) -> str:
-"""Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set."""
-if self.stored_chunk_id:
-return self.stored_chunk_id

-if "document_id" in self.metadata:
-return generate_chunk_id(self.metadata["document_id"], str(self.content))

-return generate_chunk_id(str(uuid.uuid4()), str(self.content))

@property
def document_id(self) -> str | None:
"""Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence."""
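Note (not part of the diff): with this change a Chunk no longer derives its ID from metadata or a generated UUID; callers must pass chunk_id explicitly. A hedged construction sketch, assuming the usual llama_stack.apis.vector_io import path and made-up values:

    from llama_stack.apis.vector_io import Chunk

    chunk = Chunk(
        content="The quick brown fox jumps over the lazy dog.",
        chunk_id="doc-123-chunk-0",           # now required up front
        metadata={"document_id": "doc-123"},  # surfaced to the model at inference time
    )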
@@ -13,6 +13,8 @@ from llama_stack.core.datatypes import (
ModelWithOwner,
RegistryEntrySource,
)
+from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
+from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl, lookup_model

@@ -42,11 +44,90 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):

await self.update_registered_models(provider_id, models)

+async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
+"""
+Fetch models from providers that have credentials in the current request's provider_data.
+
+This allows users to see models available to them from providers that require
+per-request API keys (via X-LlamaStack-Provider-Data header).
+
+Returns models with fully qualified identifiers (provider_id/model_id) but does NOT
+cache them in the registry since they are user-specific.
+"""
+provider_data = PROVIDER_DATA_VAR.get()
+if not provider_data:
+return []
+
+dynamic_models = []
+
+for provider_id, provider in self.impls_by_provider_id.items():
+# Check if this provider supports provider_data
+if not isinstance(provider, NeedsRequestProviderData):
+continue
+
+# Check if provider has a validator (some providers like ollama don't need per-request credentials)
+spec = getattr(provider, "__provider_spec__", None)
+if not spec or not getattr(spec, "provider_data_validator", None):
+continue
+
+# Validate provider_data silently - we're speculatively checking all providers
+# so validation failures are expected when user didn't provide keys for this provider
+try:
+validator = instantiate_class_type(spec.provider_data_validator)
+validator(**provider_data)
+except Exception:
+# User didn't provide credentials for this provider - skip silently
+continue
+
+# Validation succeeded! User has credentials for this provider
+# Now try to list models
+try:
+models = await provider.list_models()
+if not models:
+continue
+
+# Ensure models have fully qualified identifiers with provider_id prefix
+for model in models:
+# Only add prefix if model identifier doesn't already have it
+if not model.identifier.startswith(f"{provider_id}/"):
+model.identifier = f"{provider_id}/{model.provider_resource_id}"
+
+dynamic_models.append(model)
+
+logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
+
+except Exception as e:
+logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
+continue
+
+return dynamic_models

async def list_models(self) -> ListModelsResponse:
-return ListModelsResponse(data=await self.get_all_with_type("model"))
+# Get models from registry
+registry_models = await self.get_all_with_type("model")
+
+# Get additional models available via provider_data (user-specific, not cached)
+dynamic_models = await self._get_dynamic_models_from_provider_data()
+
+# Combine, avoiding duplicates (registry takes precedence)
+registry_identifiers = {m.identifier for m in registry_models}
+unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
+
+return ListModelsResponse(data=registry_models + unique_dynamic_models)

async def openai_list_models(self) -> OpenAIListModelsResponse:
-models = await self.get_all_with_type("model")
+# Get models from registry
+registry_models = await self.get_all_with_type("model")
+
+# Get additional models available via provider_data (user-specific, not cached)
+dynamic_models = await self._get_dynamic_models_from_provider_data()
+
+# Combine, avoiding duplicates (registry takes precedence)
+registry_identifiers = {m.identifier for m in registry_models}
+unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
+
+all_models = registry_models + unique_dynamic_models
+
openai_models = [
OpenAIModel(
id=model.identifier,

@@ -54,7 +135,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
created=int(time.time()),
owned_by="llama_stack",
)
-for model in models
+for model in all_models
]
return OpenAIListModelsResponse(data=openai_models)
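Note (not part of the diff): the new _get_dynamic_models_from_provider_data path means a caller who sends per-request credentials in the X-LlamaStack-Provider-Data header will also see that provider's models from list_models. A rough sketch of such a request; the route, port, and credential key name are assumptions, while the header name and the data/identifier fields come from the code above:

    import json

    import requests

    provider_data = {"together_api_key": "sk-..."}  # hypothetical key name for one provider

    resp = requests.get(
        "http://localhost:8321/v1/models",  # assumed server address and route
        headers={"X-LlamaStack-Provider-Data": json.dumps(provider_data)},
    )
    print([m["identifier"] for m in resp.json()["data"]])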
@@ -14,6 +14,7 @@ from typing import Any
import yaml

from llama_stack.apis.agents import Agents
+from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO

@@ -63,6 +64,7 @@ class LlamaStack(
Providers,
Inference,
Agents,
+Batches,
Safety,
SyntheticDataGeneration,
Datasets,
@@ -11,6 +11,7 @@ import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any, cast

import httpx

@@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin):
)

def turn_to_messages(self, turn: Turn) -> list[Message]:
-messages = []
+messages: list[Message] = []

# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
tool_call_ids = set()
for step in turn.steps:
-if step.step_type == StepType.tool_execution.value:
+if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses:
tool_call_ids.add(response.call_id)

@@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin):
messages.append(msg)

for step in turn.steps:
-if step.step_type == StepType.inference.value:
+if step.step_type == StepType.inference.value and isinstance(step, InferenceStep):
messages.append(step.model_response)
-elif step.step_type == StepType.tool_execution.value:
+elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses:
messages.append(
ToolResponseMessage(

@@ -159,8 +160,8 @@ class ChatAgent(ShieldRunnerMixin):
content=response.content,
)
)
-elif step.step_type == StepType.shield_call.value:
-if step.violation:
+elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
+if step.violation and step.violation.user_message:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(

@@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin):
return await self.storage.create_session(name)

async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
-messages = []
+messages: list[Message] = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))

@@ -231,7 +232,9 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
messages = await self.get_messages_from_turns(turns)

if is_resume:
assert isinstance(request, AgentTurnResumeRequest)
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
]

@@ -252,42 +255,52 @@ class ChatAgent(ShieldRunnerMixin):
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
-now = datetime.now(UTC).isoformat()
+now_dt = datetime.now(UTC)
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=request.tool_responses,
-completed_at=now,
-started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
+completed_at=now_dt,
+started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
-step_type=StepType.tool_execution.value,
+step_type=StepType.tool_execution,
step_id=tool_execution_step.step_id,
step_details=tool_execution_step,
)
)
)
-input_messages = last_turn.input_messages
+# Cast needed due to list invariance - last_turn.input_messages is the right type
+input_messages = last_turn.input_messages  # type: ignore[assignment]

-turn_id = request.turn_id
+actual_turn_id = request.turn_id
start_time = last_turn.started_at
else:
assert isinstance(request, AgentTurnCreateRequest)
messages.extend(request.messages)
-start_time = datetime.now(UTC).isoformat()
-input_messages = request.messages
+start_time = datetime.now(UTC)
+# Cast needed due to list invariance - request.messages is the right type
+input_messages = request.messages  # type: ignore[assignment]
+# Use the generated turn_id from beginning of function
+actual_turn_id = turn_id if turn_id else str(uuid.uuid4())

output_message = None
+req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
+req_sampling = (
+self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
+)

async for chunk in self.run(
session_id=request.session_id,
-turn_id=turn_id,
+turn_id=actual_turn_id,
input_messages=messages,
-sampling_params=self.agent_config.sampling_params,
+sampling_params=req_sampling,
stream=request.stream,
-documents=request.documents if not is_resume else None,
+documents=req_documents,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
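Note (not part of the diff): the hunks above switch step and turn timestamps from isoformat() strings to timezone-aware datetime objects, which is what makes duration arithmetic on them straightforward. A tiny sketch of the difference:

    from datetime import UTC, datetime

    started_at = datetime.now(UTC)                 # new style: datetime object
    # started_at = datetime.now(UTC).isoformat()   # old style: ISO-8601 string
    elapsed = datetime.now(UTC) - started_at       # only possible with datetime objects
    print(elapsed.total_seconds())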
@@ -295,20 +308,23 @@ class ChatAgent(ShieldRunnerMixin):

assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
-if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
-steps.append(event.payload.step_details)
+if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr(
+event.payload, "step_details"
+):
+step_details = event.payload.step_details
+steps.append(step_details)

yield chunk

assert output_message is not None

turn = Turn(
-turn_id=turn_id,
+turn_id=actual_turn_id,
session_id=request.session_id,
-input_messages=input_messages,
+input_messages=input_messages,  # type: ignore[arg-type]
output_message=output_message,
started_at=start_time,
-completed_at=datetime.now(UTC).isoformat(),
+completed_at=datetime.now(UTC),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)

@@ -345,9 +361,9 @@ class ChatAgent(ShieldRunnerMixin):
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.

-if len(self.input_shields) > 0:
+if self.input_shields:
async for res in self.run_multiple_shields_wrapper(
-turn_id, input_messages, self.input_shields, "user-input"
+turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
):
if isinstance(res, bool):
return

@@ -374,9 +390,9 @@ class ChatAgent(ShieldRunnerMixin):
# for output shields run on the full input and output combination
messages = input_messages + [final_response]

-if len(self.output_shields) > 0:
+if self.output_shields:
async for res in self.run_multiple_shields_wrapper(
-turn_id, messages, self.output_shields, "assistant-output"
+turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return

@@ -388,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper(
self,
turn_id: str,
-messages: list[Message],
+messages: list[OpenAIMessageParam],
shields: list[str],
touchpoint: str,
) -> AsyncGenerator:

@@ -402,12 +418,12 @@ class ChatAgent(ShieldRunnerMixin):
return

step_id = str(uuid.uuid4())
-shield_call_start_time = datetime.now(UTC).isoformat()
+shield_call_start_time = datetime.now(UTC)
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
-step_type=StepType.shield_call.value,
+step_type=StepType.shield_call,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
)

@@ -419,14 +435,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
-step_type=StepType.shield_call.value,
+step_type=StepType.shield_call,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
started_at=shield_call_start_time,
-completed_at=datetime.now(UTC).isoformat(),
+completed_at=datetime.now(UTC),
),
)
)

@@ -443,14 +459,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
-step_type=StepType.shield_call.value,
+step_type=StepType.shield_call,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=None,
started_at=shield_call_start_time,
-completed_at=datetime.now(UTC).isoformat(),
+completed_at=datetime.now(UTC),
),
)
)
@@ -496,21 +512,22 @@ class ChatAgent(ShieldRunnerMixin):
else:
self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)

-output_attachments = []
+output_attachments: list[Attachment] = []

n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0

# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
-for tool in self.agent_config.client_tools:
-client_tools[tool.name] = tool
+if self.agent_config.client_tools:
+for tool in self.agent_config.client_tools:
+client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
-inference_start_time = datetime.now(UTC).isoformat()
+inference_start_time = datetime.now(UTC)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
-step_type=StepType.inference.value,
+step_type=StepType.inference,
step_id=step_id,
)
)

@@ -538,7 +555,7 @@ class ChatAgent(ShieldRunnerMixin):
else:
return value

-def _add_type(openai_msg: dict) -> OpenAIMessageParam:
+def _add_type(openai_msg: Any) -> OpenAIMessageParam:
# Serialize any nested Pydantic models to plain dicts
openai_msg = _serialize_nested(openai_msg)

@@ -588,7 +605,7 @@ class ChatAgent(ShieldRunnerMixin):
messages=openai_messages,
tools=openai_tools if openai_tools else None,
tool_choice=tool_choice,
-response_format=self.agent_config.response_format,
+response_format=self.agent_config.response_format,  # type: ignore[arg-type]
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,

@@ -598,7 +615,8 @@ class ChatAgent(ShieldRunnerMixin):

# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
-openai_stream, enable_incremental_tool_calls=True
+openai_stream,  # type: ignore[arg-type]
+enable_incremental_tool_calls=True,
)

async for chunk in response_stream:

@@ -620,7 +638,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
-step_type=StepType.inference.value,
+step_type=StepType.inference,
step_id=step_id,
delta=delta,
)

@@ -633,7 +651,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
-step_type=StepType.inference.value,
+step_type=StepType.inference,
step_id=step_id,
delta=delta,
)

@@ -651,7 +669,9 @@ class ChatAgent(ShieldRunnerMixin):
output_attr = json.dumps(
{
"content": content,
-"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
+"tool_calls": [
+json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall)
+],
}
)
span.set_attribute("output", output_attr)

@@ -667,16 +687,18 @@ class ChatAgent(ShieldRunnerMixin):
if tool_calls:
content = ""

+# Filter out string tool calls for CompletionMessage (only keep ToolCall objects)
+valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)]
message = CompletionMessage(
content=content,
stop_reason=stop_reason,
-tool_calls=tool_calls,
+tool_calls=valid_tool_calls if valid_tool_calls else None,
)

yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
-step_type=StepType.inference.value,
+step_type=StepType.inference,
step_id=step_id,
step_details=InferenceStep(
# somewhere deep, we are re-assigning message or closing over some

@@ -686,13 +708,14 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
model_response=copy.deepcopy(message),
started_at=inference_start_time,
-completed_at=datetime.now(UTC).isoformat(),
+completed_at=datetime.now(UTC),
),
)
)
)

-if n_iter >= self.agent_config.max_infer_iters:
+max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10
+if n_iter >= max_iters:
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point

@@ -705,14 +728,16 @@ class ChatAgent(ShieldRunnerMixin):
yield message
break

-if len(message.tool_calls) == 0:
+if not message.tool_calls or len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
-message.content += output_attachments
+# List invariance - attachments are compatible at runtime
+message.content += output_attachments  # type: ignore[arg-type]
else:
-message.content = [message.content] + output_attachments
+# List invariance - attachments are compatible at runtime
+message.content = [message.content] + output_attachments  # type: ignore[assignment]
yield message
else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
@@ -725,11 +750,12 @@ class ChatAgent(ShieldRunnerMixin):
non_client_tool_calls = []

# Separate client and non-client tool calls
-for tool_call in message.tool_calls:
-if tool_call.tool_name in client_tools:
-client_tool_calls.append(tool_call)
-else:
-non_client_tool_calls.append(tool_call)
+if message.tool_calls:
+for tool_call in message.tool_calls:
+if tool_call.tool_name in client_tools:
+client_tool_calls.append(tool_call)
+else:
+non_client_tool_calls.append(tool_call)

# Process non-client tool calls first
for tool_call in non_client_tool_calls:

@@ -737,7 +763,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
-step_type=StepType.tool_execution.value,
+step_type=StepType.tool_execution,
step_id=step_id,
)
)

@@ -746,7 +772,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
-step_type=StepType.tool_execution.value,
+step_type=StepType.tool_execution,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,

@@ -766,7 +792,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.telemetry_enabled
else {},
) as span:
-tool_execution_start_time = datetime.now(UTC).isoformat()
+tool_execution_start_time = datetime.now(UTC)
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,

@@ -796,14 +822,14 @@ class ChatAgent(ShieldRunnerMixin):
)
],
started_at=tool_execution_start_time,
-completed_at=datetime.now(UTC).isoformat(),
+completed_at=datetime.now(UTC),
)

# Yield the step completion event
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
-step_type=StepType.tool_execution.value,
+step_type=StepType.tool_execution,
step_id=step_id,
step_details=tool_execution_step,
)

@@ -833,7 +859,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
tool_calls=client_tool_calls,
tool_responses=[],
-started_at=datetime.now(UTC).isoformat(),
+started_at=datetime.now(UTC),
),
)

@@ -868,19 +894,20 @@ class ChatAgent(ShieldRunnerMixin):

toolgroup_to_args = toolgroup_to_args or {}

-tool_name_to_def = {}
+tool_name_to_def: dict[str, ToolDefinition] = {}
tool_name_to_args = {}

-for tool_def in self.agent_config.client_tools:
-if tool_name_to_def.get(tool_def.name, None):
-raise ValueError(f"Tool {tool_def.name} already exists")
+if self.agent_config.client_tools:
+for tool_def in self.agent_config.client_tools:
+if tool_name_to_def.get(tool_def.name, None):
+raise ValueError(f"Tool {tool_def.name} already exists")

-# Use input_schema from ToolDef directly
-tool_name_to_def[tool_def.name] = ToolDefinition(
-tool_name=tool_def.name,
-description=tool_def.description,
-input_schema=tool_def.input_schema,
-)
+# Use input_schema from ToolDef directly
+tool_name_to_def[tool_def.name] = ToolDefinition(
+tool_name=tool_def.name,
+description=tool_def.description,
+input_schema=tool_def.input_schema,
+)
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
@@ -908,15 +935,17 @@ class ChatAgent(ShieldRunnerMixin):
else:
identifier = None

-if tool_name_to_def.get(identifier, None):
-raise ValueError(f"Tool {identifier} already exists")
-if identifier:
-tool_name_to_def[identifier] = ToolDefinition(
-tool_name=identifier,
+# Convert BuiltinTool to string for dictionary key
+identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier
+if tool_name_to_def.get(identifier_str, None):
+raise ValueError(f"Tool {identifier_str} already exists")
+tool_name_to_def[identifier_str] = ToolDefinition(
+tool_name=identifier_str,
description=tool_def.description,
input_schema=tool_def.input_schema,
)
-tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
+tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {})

self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
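Note (not part of the diff): the tool lookup dict is keyed by strings, so a BuiltinTool enum member is normalized to its .value before being used as a key. A self-contained sketch of the idiom, with a stand-in enum instead of the real one:

    from enum import Enum


    class BuiltinTool(Enum):  # stand-in for the real enum
        brave_search = "brave_search"


    identifier: BuiltinTool | str = BuiltinTool.brave_search
    identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier
    assert identifier_str == "brave_search"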
@@ -966,14 +995,17 @@ class ChatAgent(ShieldRunnerMixin):
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e

-result = await self.tool_runtime_api.invoke_tool(
-tool_name=tool_name_str,
-kwargs={
-"session_id": session_id,
-# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
-**args,
-**self.tool_name_to_args.get(tool_name_str, {}),
-},
+result = cast(
+ToolInvocationResult,
+await self.tool_runtime_api.invoke_tool(
+tool_name=tool_name_str,
+kwargs={
+"session_id": session_id,
+# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
+**args,
+**self.tool_name_to_args.get(tool_name_str, {}),
+},
+),
)
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result

@@ -1017,7 +1049,7 @@ def _interpret_content_as_attachment(
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
-url=URL(uri="file://" + data["filepath"]),
+content=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"],
)
@@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
Document,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
+OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,

@@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
-created_at=agent_info.created_at,
+created_at=agent_info.created_at.isoformat(),
policy=self.policy,
telemetry_enabled=self.telemetry_enabled,
)

@@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
-toolgroups: list[AgentToolGroup] | None = None,
-documents: list[Document] | None = None,
stream: bool | None = False,
+documents: list[Document] | None = None,
+toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(

@@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
+if turn is None:
+raise ValueError(f"Turn {turn_id} not found in session {session_id}")
return turn

async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:

@@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):

async def get_agents_session(
self,
-agent_id: str,
session_id: str,
+agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)

session_info = await agent.storage.get_session_info(session_id)
+if session_info is None:
+raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id)
if turn_ids:
turns = [turn for turn in turns if turn.turn_id in turn_ids]

@@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
started_at=session_info.started_at,
)

-async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
+async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
agent = await self._get_agent_impl(agent_id)

# Delete turns first, then the session

@@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
-created_at=chat_agent.created_at,
+created_at=datetime.fromisoformat(chat_agent.created_at),
)
return agent

@@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
response_id: str,
) -> OpenAIResponseObject:
+assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.get_openai_response(response_id)

async def create_openai_response(

@@ -342,7 +348,8 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None,
) -> OpenAIResponseObject:
-return await self.openai_responses_impl.create_openai_response(
+assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
+result = await self.openai_responses_impl.create_openai_response(
input,
model,
prompt,

@@ -358,6 +365,7 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters,
guardrails,
)
+return result  # type: ignore[no-any-return]

async def list_openai_responses(
self,

@@ -366,6 +374,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
+assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)

async def list_openai_response_input_items(

@@ -377,9 +386,11 @@ class MetaReferenceAgentsImpl(Agents):
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
+assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_response_input_items(
response_id, after, before, include, limit, order
)

-async def delete_openai_response(self, response_id: str) -> None:
+async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
+assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.delete_openai_response(response_id)
@@ -6,12 +6,14 @@

import json
import uuid
+from dataclasses import dataclass
from datetime import UTC, datetime

from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
-from llama_stack.core.access_control.datatypes import AccessRule
+from llama_stack.core.access_control.conditions import User as ProtocolUser
+from llama_stack.core.access_control.datatypes import AccessRule, Action
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger

@@ -33,6 +35,15 @@ class AgentInfo(AgentConfig):
created_at: datetime


+@dataclass
+class SessionResource:
+"""Concrete implementation of ProtectedResource for session access control."""
+
+type: str
+identifier: str
+owner: ProtocolUser  # Use the protocol type for structural compatibility


class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id

@@ -53,8 +64,15 @@ class AgentPersistence:
turns=[],
identifier=name,  # should this be qualified in any way?
)
-if not is_action_allowed(self.policy, "create", session_info, user):
-raise AccessDeniedError("create", session_info, user)
+# Only perform access control if we have an authenticated user
+if user is not None and session_info.identifier is not None:
+resource = SessionResource(
+type=session_info.type,
+identifier=session_info.identifier,
+owner=user,
+)
+if not is_action_allowed(self.policy, Action.CREATE, resource, user):
+raise AccessDeniedError(Action.CREATE, resource, user)

await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",

@@ -62,7 +80,7 @@ class AgentPersistence:
)
return session_id

-async def get_session_info(self, session_id: str) -> AgentSessionInfo:
+async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}",
)

@@ -83,7 +101,22 @@ class AgentPersistence:
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True

-return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
+# Get current user - if None, skip access control (e.g., in tests)
+user = get_authenticated_user()
+if user is None:
+return True
+
+# Access control requires identifier and owner to be set
+if session_info.identifier is None or session_info.owner is None:
+return True
+
+# At this point, both identifier and owner are guaranteed to be non-None
+resource = SessionResource(
+type=session_info.type,
+identifier=session_info.identifier,
+owner=session_info.owner,
+)
+return is_action_allowed(self.policy, Action.READ, resource, user)

async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""
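Note (not part of the diff): session access checks now build an explicit SessionResource and pass an Action enum member instead of handing the raw session object and a string verb to is_action_allowed. A self-contained sketch of that call shape; the resource type "session", the identifier, the empty policy, and the local SessionResource stand-in are illustrative assumptions:

    from dataclasses import dataclass

    from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
    from llama_stack.core.access_control.datatypes import Action
    from llama_stack.core.request_headers import get_authenticated_user


    @dataclass
    class SessionResource:  # same shape as the dataclass added in this file
        type: str
        identifier: str
        owner: object


    policy: list = []  # stand-in for the configured list[AccessRule]

    user = get_authenticated_user()  # None outside an authenticated request
    if user is not None:
        resource = SessionResource(type="session", identifier="sess-123", owner=user)
        if not is_action_allowed(policy, Action.CREATE, resource, user):
            raise AccessDeniedError(Action.CREATE, resource, user)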
@@ -91,7 +91,8 @@ class OpenAIResponsesImpl:
input: str | list[OpenAIResponseInput],
previous_response: _OpenAIResponseObjectWithInputAndMessages,
):
-new_input_items = previous_response.input.copy()
+# Convert Sequence to list for mutation
+new_input_items = list(previous_response.input)
new_input_items.extend(previous_response.output)

if isinstance(input, str):

@@ -107,7 +108,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None,
previous_response_id: str | None,
conversation: str | None,
-) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
+) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam], ToolContext]:
"""Process input with optional previous response context.

Returns:

@@ -208,6 +209,9 @@ class OpenAIResponsesImpl:
messages: list[OpenAIMessageParam],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
+# Type input_items_data as the full OpenAIResponseInput union to avoid list invariance issues
+input_items_data: list[OpenAIResponseInput] = []

if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)

@@ -219,7 +223,6 @@ class OpenAIResponsesImpl:
input_items_data = [input_content_item]
else:
# we already have a list of messages
-input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing

@@ -251,7 +254,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
-guardrails: list[ResponseGuardrailSpec] | None = None,
+guardrails: list[str | ResponseGuardrailSpec] | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text

@@ -289,16 +292,19 @@ class OpenAIResponsesImpl:
failed_response = None

async for stream_chunk in stream_gen:
-if stream_chunk.type in {"response.completed", "response.incomplete"}:
-if final_response is not None:
-raise ValueError(
-"The response stream produced multiple terminal responses! "
-f"Earlier response from {final_event_type}"
-)
-final_response = stream_chunk.response
-final_event_type = stream_chunk.type
-elif stream_chunk.type == "response.failed":
-failed_response = stream_chunk.response
+match stream_chunk.type:
+case "response.completed" | "response.incomplete":
+if final_response is not None:
+raise ValueError(
+"The response stream produced multiple terminal responses! "
+f"Earlier response from {final_event_type}"
+)
+final_response = stream_chunk.response
+final_event_type = stream_chunk.type
+case "response.failed":
+failed_response = stream_chunk.response
+case _:
+pass  # Other event types don't have .response

if failed_response is not None:
error_message = (

@@ -326,6 +332,11 @@ class OpenAIResponsesImpl:
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
+# These should never be None when called from create_openai_response (which sets defaults)
+# but we assert here to help mypy understand the types
+assert text is not None, "text must not be None"
+assert max_infer_iters is not None, "max_infer_iters must not be None"

# Input preprocessing
all_input, messages, tool_context = await self._process_input_with_previous_response(
input, tools, previous_response_id, conversation

@@ -368,16 +379,19 @@ class OpenAIResponsesImpl:
final_response = None
failed_response = None

-output_items = []
+# Type as ConversationItem to avoid list invariance issues
+output_items: list[ConversationItem] = []
async for stream_chunk in orchestrator.create_response():
-if stream_chunk.type in {"response.completed", "response.incomplete"}:
-final_response = stream_chunk.response
-elif stream_chunk.type == "response.failed":
-failed_response = stream_chunk.response
-
-if stream_chunk.type == "response.output_item.done":
-item = stream_chunk.item
-output_items.append(item)
+match stream_chunk.type:
+case "response.completed" | "response.incomplete":
+final_response = stream_chunk.response
+case "response.failed":
+failed_response = stream_chunk.response
+case "response.output_item.done":
+item = stream_chunk.item
+output_items.append(item)
+case _:
+pass  # Other event types

# Store and sync before yielding terminal events
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event

@@ -410,7 +424,8 @@ class OpenAIResponsesImpl:
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
) -> None:
"""Sync content and response messages to the conversation."""
-conversation_items = []
+# Type as ConversationItem union to avoid list invariance issues
+conversation_items: list[ConversationItem] = []

if isinstance(input, str):
conversation_items.append(
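Note (not part of the diff): the stream-consumption branches above move from if/elif chains on stream_chunk.type to structural match statements with a catch-all case. A self-contained sketch of the same dispatch shape using stand-in chunk objects:

    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class FakeChunk:  # stand-in for the real stream chunk classes
        type: str
        response: Any = None
        item: Any = None


    final_response = failed_response = None
    output_items: list[Any] = []

    for chunk in [FakeChunk("response.output_item.done", item="x"), FakeChunk("response.completed", response="ok")]:
        match chunk.type:
            case "response.completed" | "response.incomplete":
                final_response = chunk.response
            case "response.failed":
                failed_response = chunk.response
            case "response.output_item.done":
                output_items.append(chunk.item)
            case _:
                pass  # other event types carry no response payload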
@@ -111,7 +111,7 @@ class StreamingResponseOrchestrator:
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor,  # Will be the tool execution logic from the main class
-instructions: str,
+instructions: str | None,
safety_api,
guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None,

@@ -128,7 +129,9 @@ class StreamingResponseOrchestrator:
self.prompt = prompt
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
-self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
+self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
+ctx.tool_context.previous_tools if ctx.tool_context else {}
+)
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations

@@ -229,7 +231,8 @@ class StreamingResponseOrchestrator:
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.ctx.model,
messages=messages,
-tools=self.ctx.chat_tools,
+# Pydantic models are dict-compatible but mypy treats them as distinct types
+tools=self.ctx.chat_tools,  # type: ignore[arg-type]
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,

@@ -272,7 +275,12 @@ class StreamingResponseOrchestrator:

# Handle choices with no tool calls
for choice in current_response.choices:
-if not (choice.message.tool_calls and self.ctx.response_tools):
+has_tool_calls = (
+isinstance(choice.message, OpenAIAssistantMessageParam)
+and choice.message.tool_calls
+and self.ctx.response_tools
+)
+if not has_tool_calls:
output_messages.append(
await convert_chat_choice_to_response_message(
choice,

@@ -722,7 +730,10 @@ class StreamingResponseOrchestrator:
)

# Accumulate arguments for final response (only for subsequent chunks)
-if not is_new_tool_call:
+if not is_new_tool_call and response_tool_call is not None:
+# Both should have functions since we're inside the tool_call.function check above
+assert response_tool_call.function is not None
+assert tool_call.function is not None
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments

@@ -747,10 +758,13 @@ class StreamingResponseOrchestrator:
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call = chat_response_tool_calls[tool_call_index]
# Ensure that arguments, if sent back to the inference provider, are not None
-tool_call.function.arguments = tool_call.function.arguments or "{}"
+if tool_call.function:
+tool_call.function.arguments = tool_call.function.arguments or "{}"
tool_call_item_id = tool_call_item_ids[tool_call_index]
-final_arguments = tool_call.function.arguments
-tool_call_name = chat_response_tool_calls[tool_call_index].function.name
+final_arguments: str = tool_call.function.arguments or "{}" if tool_call.function else "{}"
+func = chat_response_tool_calls[tool_call_index].function
+
+tool_call_name = func.name if func else ""

# Check if this is an MCP tool call
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
@@ -894,12 +908,11 @@ class StreamingResponseOrchestrator:

self.sequence_number += 1
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
-item = OpenAIResponseOutputMessageMCPCall(
+item: OpenAIResponseOutput = OpenAIResponseOutputMessageMCPCall(
arguments="",
name=tool_call.function.name,
id=matching_item_id,
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
status="in_progress",
)
elif tool_call.function.name == "web_search":
item = OpenAIResponseOutputMessageWebSearchToolCall(

@@ -1008,7 +1021,7 @@ class StreamingResponseOrchestrator:
description=tool.description,
input_schema=tool.input_schema,
)
-return convert_tooldef_to_openai_tool(tool_def)
+return convert_tooldef_to_openai_tool(tool_def)  # type: ignore[return-value] # Returns dict but ChatCompletionToolParam expects TypedDict

# Initialize chat_tools if not already set
if self.ctx.chat_tools is None:

@@ -1016,7 +1029,7 @@ class StreamingResponseOrchestrator:

for input_tool in tools:
if input_tool.type == "function":
-self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
+self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))  # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
# Need to access tool_groups_api from tool_executor

@@ -1055,8 +1068,8 @@ class StreamingResponseOrchestrator:
if isinstance(mcp_tool.allowed_tools, list):
always_allowed = mcp_tool.allowed_tools
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
-always_allowed = mcp_tool.allowed_tools.always
-never_allowed = mcp_tool.allowed_tools.never
+# AllowedToolsFilter only has tool_names field (not allowed/disallowed)
+always_allowed = mcp_tool.allowed_tools.tool_names

# Call list_mcp_tools
tool_defs = None

@@ -1088,7 +1101,7 @@ class StreamingResponseOrchestrator:
openai_tool = convert_tooldef_to_chat_tool(t)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
-self.ctx.chat_tools.append(openai_tool)
+self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict

# Add to MCP tool mapping
if t.name in self.mcp_tool_to_server:

@@ -1120,13 +1133,17 @@ class StreamingResponseOrchestrator:
self, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Handle all mcp tool lists from previous response that are still valid:
-for tool in self.ctx.tool_context.previous_tool_listings:
-async for evt in self._reuse_mcp_list_tools(tool, output_messages):
-yield evt
-# Process all remaining tools (including MCP tools) and emit streaming events
-if self.ctx.tool_context.tools_to_process:
-async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
-yield stream_event
+# tool_context can be None when no tools are provided in the response request
+if self.ctx.tool_context:
+for tool in self.ctx.tool_context.previous_tool_listings:
+async for evt in self._reuse_mcp_list_tools(tool, output_messages):
+yield evt
+# Process all remaining tools (including MCP tools) and emit streaming events
+if self.ctx.tool_context.tools_to_process:
+async for stream_event in self._process_new_tools(
+self.ctx.tool_context.tools_to_process, output_messages
+):
+yield stream_event

def _approval_required(self, tool_name: str) -> bool:
if tool_name not in self.mcp_tool_to_server:

@@ -1220,7 +1237,7 @@ class StreamingResponseOrchestrator:
openai_tool = convert_tooldef_to_openai_tool(tool_def)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
-self.ctx.chat_tools.append(openai_tool)
+self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict

mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
@ -7,6 +7,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
|
|
@ -22,6 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
|
@ -67,7 +69,7 @@ class ToolExecutor:
|
|||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
||||
tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}
|
||||
|
||||
if not function or not tool_call_id or not function.name:
|
||||
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||
|
|
@ -84,7 +86,16 @@ class ToolExecutor:
|
|||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
||||
|
||||
# Emit completion events for tool execution
|
||||
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
||||
has_error = bool(
|
||||
error_exc
|
||||
or (
|
||||
result
|
||||
and (
|
||||
((error_code := getattr(result, "error_code", None)) and error_code > 0)
|
||||
or getattr(result, "error_message", None)
|
||||
)
|
||||
)
|
||||
)
|
||||
async for event_result in self._emit_completion_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
||||
):
|
||||
|
|
@ -101,7 +112,9 @@ class ToolExecutor:
|
|||
sequence_number=sequence_number,
|
||||
final_output_message=output_message,
|
||||
final_input_message=input_message,
|
||||
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
|
||||
citation_files=(
|
||||
metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
|
||||
),
|
||||
)
|
||||
|
||||
async def _execute_knowledge_search_via_vector_store(
|
||||
|
|
@ -188,8 +201,9 @@ class ToolExecutor:
|
|||
|
||||
citation_files[file_id] = filename
|
||||
|
||||
# Cast to proper InterleavedContent type (list invariance)
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
content=content_items, # type: ignore[arg-type]
|
||||
metadata={
|
||||
"document_ids": [r.file_id for r in search_results],
|
||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||
|
|
@@ -209,51 +223,60 @@ class ToolExecutor:
|
|||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit progress events for tool execution start."""
|
||||
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||
progress_event = None
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if progress_event:
|
||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
||||
|
||||
# For web search, emit searching event
|
||||
if function_name == "web_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
# For file search, emit searching event
|
||||
if function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
|
|
@@ -261,7 +284,7 @@ class ToolExecutor:
tool_kwargs: dict,
ctx: ChatCompletionContext,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[Exception | None, any]:
) -> tuple[Exception | None, Any]:
"""Execute the tool and return error exception and result."""
error_exc = None
result = None
@@ -284,9 +307,13 @@ class ToolExecutor:
kwargs=tool_kwargs,
)
elif function_name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
response_file_search_tool = (
next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
)
if ctx.response_tools
else None
)
if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool
@ -322,35 +349,34 @@ class ToolExecutor:
|
|||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit completion or failure events for tool execution."""
|
||||
completion_event = None
|
||||
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
if has_error:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
|
||||
else:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||
mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||
web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||
file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if completion_event:
|
||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
||||
yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)
|
||||
|
||||
async def _build_result_messages(
|
||||
self,
|
||||
|
|
@ -360,21 +386,18 @@ class ToolExecutor:
|
|||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
error_exc: Exception | None,
|
||||
result: any,
|
||||
result: Any,
|
||||
has_error: bool,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[any, any]:
|
||||
) -> tuple[Any, Any]:
|
||||
"""Build output and input messages from tool execution results."""
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
# Build output message
|
||||
message: Any
|
||||
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
)
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=item_id,
|
||||
arguments=function.arguments,
|
||||
|
|
@ -383,10 +406,14 @@ class ToolExecutor:
|
|||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||
elif result and result.content:
|
||||
message.output = interleaved_content_as_str(result.content)
|
||||
elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
|
||||
result and getattr(result, "error_message", None)
|
||||
):
|
||||
ec = getattr(result, "error_code", "unknown")
|
||||
em = getattr(result, "error_message", "")
|
||||
message.error = f"Error (code {ec}): {em}"
|
||||
elif result and (content := getattr(result, "content", None)):
|
||||
message.output = interleaved_content_as_str(content)
|
||||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
|
|
@ -401,17 +428,17 @@ class ToolExecutor:
|
|||
queries=[tool_kwargs.get("query", "")],
|
||||
status="completed",
|
||||
)
|
||||
if result and "document_ids" in result.metadata:
|
||||
if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
|
||||
message.results = []
|
||||
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
||||
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
||||
score = result.metadata["scores"][i] if "scores" in result.metadata else None
|
||||
for i, doc_id in enumerate(metadata["document_ids"]):
|
||||
text = metadata["chunks"][i] if "chunks" in metadata else None
|
||||
score = metadata["scores"][i] if "scores" in metadata else None
|
||||
message.results.append(
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
||||
file_id=doc_id,
|
||||
filename=doc_id,
|
||||
text=text,
|
||||
score=score,
|
||||
text=text if text is not None else "",
|
||||
score=score if score is not None else 0.0,
|
||||
attributes={},
|
||||
)
|
||||
)
|
||||
|
|
@ -421,27 +448,32 @@ class ToolExecutor:
|
|||
raise ValueError(f"Unknown tool {function.name} called")
|
||||
|
||||
# Build input message
|
||||
input_message = None
|
||||
if result and result.content:
|
||||
if isinstance(result.content, str):
|
||||
content = result.content
|
||||
elif isinstance(result.content, list):
|
||||
content = []
|
||||
for item in result.content:
|
||||
input_message: OpenAIToolMessageParam | None = None
|
||||
if result and (result_content := getattr(result, "content", None)):
|
||||
# all the mypy contortions here are still unsatisfactory with random Any typing
|
||||
if isinstance(result_content, str):
|
||||
msg_content: str | list[Any] = result_content
|
||||
elif isinstance(result_content, list):
|
||||
content_list: list[Any] = []
|
||||
for item in result_content:
|
||||
part: Any
|
||||
if isinstance(item, TextContentItem):
|
||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||
elif isinstance(item, ImageContentItem):
|
||||
if item.image.data:
|
||||
url = f"data:image;base64,{item.image.data}"
|
||||
url_value = f"data:image;base64,{item.image.data}"
|
||||
else:
|
||||
url = item.image.url
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
||||
url_value = str(item.image.url) if item.image.url else ""
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||
content.append(part)
|
||||
content_list.append(part)
|
||||
msg_content = content_list
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
raise ValueError(f"Unknown result content type: {type(result_content)}")
|
||||
# OpenAIToolMessageParam accepts str | list[TextParam] but we may have images
|
||||
# This is runtime-safe as the API accepts it, but mypy complains
|
||||
input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type]
|
||||
else:
|
||||
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -100,17 +101,19 @@ class ToolContext(BaseModel):
|
|||
if isinstance(tool, OpenAIResponseToolMCP):
|
||||
previous_tools_by_label[tool.server_label] = tool
|
||||
# collect tool definitions which are the same in current and previous requests:
|
||||
tools_to_process = []
|
||||
tools_to_process: list[OpenAIResponseInputTool] = []
|
||||
matched: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
for tool in self.current_tools:
|
||||
# Mypy confuses OpenAIResponseInputTool (Input union) with OpenAIResponseTool (output union)
|
||||
# which differ only in MCP type (InputToolMCP vs ToolMCP). Code is correct.
|
||||
for tool in cast(list[OpenAIResponseInputTool], self.current_tools): # type: ignore[assignment]
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
|
||||
previous_tool = previous_tools_by_label[tool.server_label]
|
||||
if previous_tool.allowed_tools == tool.allowed_tools:
|
||||
matched[tool.server_label] = tool
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
tools_to_process.append(tool) # type: ignore[arg-type]
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
tools_to_process.append(tool) # type: ignore[arg-type]
|
||||
# tools that are not the same or were not previously defined need to be processed:
|
||||
self.tools_to_process = tools_to_process
|
||||
# for all matched definitions, get the mcp_list_tools objects from the previous output:
|
||||
|
|
@ -119,9 +122,11 @@ class ToolContext(BaseModel):
|
|||
]
|
||||
# reconstruct the tool to server mappings that can be reused:
|
||||
for listing in self.previous_tool_listings:
|
||||
# listing is OpenAIResponseOutputMessageMCPListTools which has tools: list[MCPListToolsTool]
|
||||
definition = matched[listing.server_label]
|
||||
for tool in listing.tools:
|
||||
self.previous_tools[tool.name] = definition
|
||||
for mcp_tool in listing.tools:
|
||||
# mcp_tool is MCPListToolsTool which has a name: str field
|
||||
self.previous_tools[mcp_tool.name] = definition
|
||||
|
||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||
if not self.current_tools:
|
||||
|
|
@ -139,6 +144,8 @@ class ToolContext(BaseModel):
|
|||
server_label=tool.server_label,
|
||||
allowed_tools=tool.allowed_tools,
|
||||
)
|
||||
# Exhaustive check - all tool types should be handled above
|
||||
raise AssertionError(f"Unexpected tool type: {type(tool)}")
|
||||
|
||||
return [convert_tool(tool) for tool in self.current_tools]
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
import asyncio
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
|
|
@ -71,14 +72,14 @@ async def convert_chat_choice_to_response_message(
|
|||
|
||||
return OpenAIResponseMessage(
|
||||
id=message_id or f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
async def convert_response_content_to_chat_content(
|
||||
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
||||
content: str | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent],
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
|
@ -88,7 +89,8 @@ async def convert_response_content_to_chat_content(
|
|||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
# Type with union to avoid list invariance issues
|
||||
converted_parts: list[OpenAIChatCompletionContentPartParam] = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
|
|
@ -158,9 +160,11 @@ async def convert_response_input_to_chat_messages(
|
|||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
# Output can be None, use empty string as fallback
|
||||
output_content = input_item.output if input_item.output is not None else ""
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
content=output_content,
|
||||
tool_call_id=input_item.id,
|
||||
)
|
||||
)
|
||||
|
|
@ -172,7 +176,8 @@ async def convert_response_input_to_chat_messages(
|
|||
):
|
||||
# these are handled by the responses impl itself and not pass through to chat completions
|
||||
pass
|
||||
else:
|
||||
elif isinstance(input_item, OpenAIResponseMessage):
|
||||
# Narrow type to OpenAIResponseMessage which has content and role attributes
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
|
|
@ -191,7 +196,8 @@ async def convert_response_input_to_chat_messages(
|
|||
last_user_content = getattr(last_user_msg, "content", None)
|
||||
if last_user_content == content:
|
||||
continue # Skip duplicate user message
|
||||
messages.append(message_type(content=content))
|
||||
# Dynamic message type call - different message types have different content expectations
|
||||
messages.append(message_type(content=content)) # type: ignore[call-arg,arg-type]
|
||||
if len(tool_call_results):
|
||||
# Check if unpaired function_call_outputs reference function_calls from previous messages
|
||||
if previous_messages:
|
||||
|
|
@ -237,8 +243,11 @@ async def convert_response_text_to_chat_response_format(
|
|||
if text.format["type"] == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
# Assert name exists for json_schema format
|
||||
assert text.format.get("name"), "json_schema format requires a name"
|
||||
schema_name: str = text.format["name"] # type: ignore[assignment]
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
json_schema=OpenAIJSONSchema(name=schema_name, schema=text.format["schema"])
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
|
@ -251,7 +260,7 @@ async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None
|
|||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
return role_to_type.get(role) # type: ignore[return-value] # Pydantic models use ModelMetaclass
|
||||
|
||||
|
||||
def _extract_citations_from_text(
|
||||
|
|
@ -320,7 +329,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
|
||||
# Look up shields to get their provider_resource_id (actual model ID)
|
||||
model_ids = []
|
||||
shields_list = await safety_api.routing_table.list_shields()
|
||||
# TODO: list_shields not in Safety interface but available at runtime via API routing
|
||||
shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined]
|
||||
|
||||
for guardrail_id in guardrail_ids:
|
||||
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
||||
|
|
@ -337,7 +347,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
for result in response.results:
|
||||
if result.flagged:
|
||||
message = result.user_message or "Content blocked by safety guardrails"
|
||||
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
|
||||
flagged_categories = (
|
||||
[cat for cat, flagged in result.categories.items() if flagged] if result.categories else []
|
||||
)
|
||||
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
|
||||
|
||||
if flagged_categories:
|
||||
|
|
@ -347,6 +359,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
|
||||
return message
|
||||
|
||||
# No violations found
|
||||
return None
|
||||
|
||||
|
||||
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
|
||||
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
|
||||
|
|
|
|||
|
|
@@ -6,7 +6,7 @@

import asyncio

from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
@@ -31,7 +31,7 @@ class ShieldRunnerMixin:
self.input_shields = input_shields
self.output_shields = output_shields

async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield(
@@ -33,4 +33,5 @@ class AnthropicInferenceAdapter(OpenAIMixin):
return "https://api.anthropic.com/v1"

async def list_provider_model_ids(self) -> Iterable[str]:
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
api_key = self._get_api_key_from_config_or_provider_data()
return [m.id async for m in AsyncAnthropic(api_key=api_key).models.list()]
@@ -33,10 +33,11 @@ class DatabricksInferenceAdapter(OpenAIMixin):

async def list_provider_model_ids(self) -> Iterable[str]:
# Filter out None values from endpoint names
api_token = self._get_api_key_from_config_or_provider_data()
return [
endpoint.name # type: ignore[misc]
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
host=self.config.url, token=api_token
).serving_endpoints.list() # TODO: this is not async
]
@ -128,7 +128,9 @@ class LiteLLMOpenAIMixin(
|
|||
return schema
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
input_dict = {}
|
||||
from typing import Any
|
||||
|
||||
input_dict: dict[str, Any] = {}
|
||||
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
|
||||
|
|
@ -139,30 +141,27 @@ class LiteLLMOpenAIMixin(
|
|||
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
|
||||
)
|
||||
|
||||
fmt = fmt.json_schema
|
||||
name = fmt["title"]
|
||||
del fmt["title"]
|
||||
fmt["additionalProperties"] = False
|
||||
# Convert to dict for manipulation
|
||||
fmt_dict = dict(fmt.json_schema)
|
||||
name = fmt_dict["title"]
|
||||
del fmt_dict["title"]
|
||||
fmt_dict["additionalProperties"] = False
|
||||
|
||||
# Apply additionalProperties: False recursively to all objects
|
||||
fmt = self._add_additional_properties_recursive(fmt)
|
||||
fmt_dict = self._add_additional_properties_recursive(fmt_dict)
|
||||
|
||||
input_dict["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": name,
|
||||
"schema": fmt,
|
||||
"schema": fmt_dict,
|
||||
"strict": self.json_schema_strict,
|
||||
},
|
||||
}
|
||||
if request.tools:
|
||||
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
|
||||
if request.tool_config.tool_choice:
|
||||
input_dict["tool_choice"] = (
|
||||
request.tool_config.tool_choice.value
|
||||
if isinstance(request.tool_config.tool_choice, ToolChoice)
|
||||
else request.tool_config.tool_choice
|
||||
)
|
||||
if request.tool_config and (tool_choice := request.tool_config.tool_choice):
|
||||
input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
|
|
@@ -176,10 +175,10 @@ class LiteLLMOpenAIMixin(
def get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self.api_key_from_config
if provider_data and key_field and (api_key := getattr(provider_data, key_field, None)):
return str(api_key) # type: ignore[no-any-return] # getattr returns Any, can't narrow without runtime type inspection

api_key = self.api_key_from_config
if not api_key:
raise ValueError(
"API key is not set. Please provide a valid API key in the "
@ -192,7 +191,13 @@ class LiteLLMOpenAIMixin(
|
|||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store is not initialized")
|
||||
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {params.model} has no provider_resource_id")
|
||||
provider_resource_id = model_obj.provider_resource_id
|
||||
|
||||
# Convert input to list if it's a string
|
||||
input_list = [params.input] if isinstance(params.input, str) else params.input
|
||||
|
|
@ -200,7 +205,7 @@ class LiteLLMOpenAIMixin(
|
|||
# Call litellm embedding function
|
||||
# litellm.drop_params = True
|
||||
response = litellm.embedding(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
model=self.get_litellm_model_name(provider_resource_id),
|
||||
input=input_list,
|
||||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
|
|
@ -217,7 +222,7 @@ class LiteLLMOpenAIMixin(
|
|||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=model_obj.provider_resource_id,
|
||||
model=provider_resource_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
|
@ -225,10 +230,16 @@ class LiteLLMOpenAIMixin(
|
|||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store is not initialized")
|
||||
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {params.model} has no provider_resource_id")
|
||||
provider_resource_id = model_obj.provider_resource_id
|
||||
|
||||
request_params = await prepare_openai_completion_params(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
model=self.get_litellm_model_name(provider_resource_id),
|
||||
prompt=params.prompt,
|
||||
best_of=params.best_of,
|
||||
echo=params.echo,
|
||||
|
|
@ -249,7 +260,8 @@ class LiteLLMOpenAIMixin(
|
|||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
)
|
||||
return await litellm.atext_completion(**request_params)
|
||||
# LiteLLM returns compatible type but mypy can't verify external library
|
||||
return await litellm.atext_completion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
@ -265,10 +277,16 @@ class LiteLLMOpenAIMixin(
|
|||
elif "include_usage" not in stream_options:
|
||||
stream_options = {**stream_options, "include_usage": True}
|
||||
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store is not initialized")
|
||||
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {params.model} has no provider_resource_id")
|
||||
provider_resource_id = model_obj.provider_resource_id
|
||||
|
||||
request_params = await prepare_openai_completion_params(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
model=self.get_litellm_model_name(provider_resource_id),
|
||||
messages=params.messages,
|
||||
frequency_penalty=params.frequency_penalty,
|
||||
function_call=params.function_call,
|
||||
|
|
@ -294,7 +312,8 @@ class LiteLLMOpenAIMixin(
|
|||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
)
|
||||
return await litellm.acompletion(**request_params)
|
||||
# LiteLLM returns compatible type but mypy can't verify external library
|
||||
return await litellm.acompletion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@@ -161,8 +161,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
if isinstance(params.strategy, GreedySamplingStrategy):
options["temperature"] = 0.0
elif isinstance(params.strategy, TopPSamplingStrategy):
options["temperature"] = params.strategy.temperature
options["top_p"] = params.strategy.top_p
if params.strategy.temperature is not None:
options["temperature"] = params.strategy.temperature
if params.strategy.top_p is not None:
options["top_p"] = params.strategy.top_p
elif isinstance(params.strategy, TopKSamplingStrategy):
options["top_k"] = params.strategy.top_k
else:
@@ -192,12 +194,12 @@ def get_sampling_options(params: SamplingParams | None) -> dict:

def text_from_choice(choice) -> str:
if hasattr(choice, "delta") and choice.delta:
return choice.delta.content
return choice.delta.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations

if hasattr(choice, "message"):
return choice.message.content
return choice.message.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations

return choice.text
return choice.text # type: ignore[no-any-return] # external OpenAI types lack precise annotations


def get_stop_reason(finish_reason: str) -> StopReason:
@@ -216,7 +218,7 @@ def convert_openai_completion_logprobs(
) -> list[TokenLogProbs] | None:
if not logprobs:
return None
if hasattr(logprobs, "top_logprobs"):
if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]

# Together supports logprobs with top_k=1 only. This means for each token position,
@@ -236,7 +238,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenA
if isinstance(logprobs, float):
# Adapt response from Together CompletionChoicesChunk
return [TokenLogProbs(logprobs_by_token={text: logprobs})]
if hasattr(logprobs, "top_logprobs"):
if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
return None
@@ -245,23 +247,24 @@ def process_completion_response(
response: OpenAICompatCompletionResponse,
) -> CompletionResponse:
choice = response.choices[0]
text = choice.text or ""
# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
if text.endswith("<|eot_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_turn,
content=choice.text[: -len("<|eot_id|>")],
content=text[: -len("<|eot_id|>")],
logprobs=convert_openai_completion_logprobs(choice.logprobs),
)
# drop suffix <eom_id> if present and return stop reason as end of message
if choice.text.endswith("<|eom_id|>"):
if text.endswith("<|eom_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_message,
content=choice.text[: -len("<|eom_id|>")],
content=text[: -len("<|eom_id|>")],
logprobs=convert_openai_completion_logprobs(choice.logprobs),
)
return CompletionResponse(
stop_reason=get_stop_reason(choice.finish_reason),
content=choice.text,
stop_reason=get_stop_reason(choice.finish_reason or "stop"),
content=text,
logprobs=convert_openai_completion_logprobs(choice.logprobs),
)

@@ -272,10 +275,10 @@ def process_chat_completion_response(
) -> ChatCompletionResponse:
choice = response.choices[0]
if choice.finish_reason == "tool_calls":
if not choice.message or not choice.message.tool_calls:
if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls: # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
raise ValueError("Tool calls are not present in the response")

tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
# If we couldn't parse a tool call, jsonify the tool calls and return them
return ChatCompletionResponse(
@ -287,9 +290,11 @@ def process_chat_completion_response(
|
|||
)
|
||||
else:
|
||||
# Otherwise, return tool calls as normal
|
||||
# Filter to only valid ToolCall objects
|
||||
valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)]
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
tool_calls=tool_calls,
|
||||
tool_calls=valid_tool_calls,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
# Content is not optional
|
||||
content="",
|
||||
|
|
@ -299,7 +304,7 @@ def process_chat_completion_response(
|
|||
|
||||
# TODO: This does not work well with tool calls for vLLM remote provider
|
||||
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
|
||||
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
|
||||
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop"))
|
||||
|
||||
# NOTE: If we do not set tools in chat-completion request, we should not
|
||||
# expect the ToolCall in the response. Instead, we should return the raw
|
||||
|
|
@ -324,8 +329,8 @@ def process_chat_completion_response(
|
|||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
content=raw_message.content, # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent]
|
||||
stop_reason=raw_message.stop_reason or StopReason.end_of_turn,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=None,
|
||||
|
|
@ -448,7 +453,7 @@ async def process_chat_completion_stream_response(
|
|||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = decode_assistant_message(buffer, stop_reason)
|
||||
message = decode_assistant_message(buffer, stop_reason or StopReason.end_of_turn)
|
||||
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
|
|
@ -463,7 +468,7 @@ async def process_chat_completion_stream_response(
|
|||
)
|
||||
)
|
||||
|
||||
request_tools = {t.tool_name: t for t in request.tools}
|
||||
request_tools = {t.tool_name: t for t in (request.tools or [])}
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.tool_name in request_tools:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
|
|
@ -525,7 +530,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
}
|
||||
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
result["tool_calls"] = []
|
||||
tool_calls_list = []
|
||||
for tc in message.tool_calls:
|
||||
# The tool.tool_name can be a str or a BuiltinTool enum. If
|
||||
# it's the latter, convert to a string.
|
||||
|
|
@ -533,7 +538,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
|
||||
result["tool_calls"].append(
|
||||
tool_calls_list.append(
|
||||
{
|
||||
"id": tc.call_id,
|
||||
"type": "function",
|
||||
|
|
@ -543,6 +548,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
},
|
||||
}
|
||||
)
|
||||
result["tool_calls"] = tool_calls_list # type: ignore[assignment] # dict allows Any value, stricter type expected
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -608,7 +614,7 @@ async def convert_message_to_openai_dict_new(
|
|||
),
|
||||
)
|
||||
elif isinstance(content_, list):
|
||||
return [await impl(item) for item in content_]
|
||||
return [await impl(item) for item in content_] # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content_)}")
|
||||
|
||||
|
|
@ -620,7 +626,7 @@ async def convert_message_to_openai_dict_new(
|
|||
else:
|
||||
return [ret]
|
||||
|
||||
out: OpenAIChatCompletionMessage = None
|
||||
out: OpenAIChatCompletionMessage
|
||||
if isinstance(message, UserMessage):
|
||||
out = OpenAIChatCompletionUserMessage(
|
||||
role="user",
|
||||
|
|
@ -636,7 +642,7 @@ async def convert_message_to_openai_dict_new(
|
|||
),
|
||||
type="function",
|
||||
)
|
||||
for tool in message.tool_calls
|
||||
for tool in (message.tool_calls or [])
|
||||
]
|
||||
params = {}
|
||||
if tool_calls:
|
||||
|
|
@ -644,18 +650,18 @@ async def convert_message_to_openai_dict_new(
|
|||
out = OpenAIChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=await _convert_message_content(message.content),
|
||||
**params,
|
||||
**params, # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field
|
||||
)
|
||||
elif isinstance(message, ToolResponseMessage):
|
||||
out = OpenAIChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=message.call_id,
|
||||
content=await _convert_message_content(message.content),
|
||||
content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
|
||||
)
|
||||
elif isinstance(message, SystemMessage):
|
||||
out = OpenAIChatCompletionSystemMessage(
|
||||
role="system",
|
||||
content=await _convert_message_content(message.content),
|
||||
content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||
|
|
@ -758,16 +764,16 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
function = out["function"]
|
||||
|
||||
if isinstance(tool.tool_name, BuiltinTool):
|
||||
function["name"] = tool.tool_name.value
|
||||
function["name"] = tool.tool_name.value # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
else:
|
||||
function["name"] = tool.tool_name
|
||||
function["name"] = tool.tool_name # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
|
||||
if tool.description:
|
||||
function["description"] = tool.description
|
||||
function["description"] = tool.description # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
|
||||
if tool.input_schema:
|
||||
# Pass through the entire JSON Schema as-is
|
||||
function["parameters"] = tool.input_schema
|
||||
function["parameters"] = tool.input_schema # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
|
||||
# NOTE: OpenAI does not support output_schema, so we drop it here
|
||||
# It's stored in LlamaStack for validation and other provider usage
|
||||
|
|
@ -815,15 +821,15 @@ def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None
|
|||
tool_config = ToolConfig()
|
||||
if tool_choice:
|
||||
try:
|
||||
tool_choice = ToolChoice(tool_choice)
|
||||
tool_choice = ToolChoice(tool_choice) # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception
|
||||
except ValueError:
|
||||
pass
|
||||
tool_config.tool_choice = tool_choice
|
||||
tool_config.tool_choice = tool_choice # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type
|
||||
return tool_config
|
||||
|
||||
|
||||
def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
|
||||
lls_tools = []
|
||||
lls_tools: list[ToolDefinition] = []
|
||||
if not tools:
|
||||
return lls_tools
|
||||
|
||||
|
|
@ -843,16 +849,16 @@ def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) ->
|
|||
|
||||
|
||||
def _convert_openai_request_response_format(
|
||||
response_format: OpenAIResponseFormatParam = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
):
|
||||
if not response_format:
|
||||
return None
|
||||
# response_format can be a dict or a pydantic model
|
||||
response_format = dict(response_format)
|
||||
if response_format.get("type", "") == "json_schema":
|
||||
response_format_dict = dict(response_format) # type: ignore[arg-type] # OpenAIResponseFormatParam union needs dict conversion
|
||||
if response_format_dict.get("type", "") == "json_schema":
|
||||
return JsonSchemaResponseFormat(
|
||||
type="json_schema",
|
||||
json_schema=response_format.get("json_schema", {}).get("schema", ""),
|
||||
type="json_schema", # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type
|
||||
json_schema=response_format_dict.get("json_schema", {}).get("schema", ""),
|
||||
)
|
||||
return None
|
||||
|
||||
|
|
@ -938,16 +944,15 @@ def _convert_openai_sampling_params(
|
|||
|
||||
# Map an explicit temperature of 0 to greedy sampling
|
||||
if temperature == 0:
|
||||
strategy = GreedySamplingStrategy()
|
||||
sampling_params.strategy = GreedySamplingStrategy()
|
||||
else:
|
||||
# OpenAI defaults to 1.0 for temperature and top_p if unset
|
||||
if temperature is None:
|
||||
temperature = 1.0
|
||||
if top_p is None:
|
||||
top_p = 1.0
|
||||
strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
|
||||
sampling_params.strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p) # type: ignore[assignment] # SamplingParams.strategy union accepts this type
|
||||
|
||||
sampling_params.strategy = strategy
|
||||
return sampling_params
|
||||
|
||||
|
||||
|
|
@ -957,23 +962,24 @@ def openai_messages_to_messages(
|
|||
"""
|
||||
Convert a list of OpenAIChatCompletionMessage into a list of Message.
|
||||
"""
|
||||
converted_messages = []
|
||||
converted_messages: list[Message] = []
|
||||
for message in messages:
|
||||
converted_message: Message
|
||||
if message.role == "system":
|
||||
converted_message = SystemMessage(content=openai_content_to_content(message.content))
|
||||
converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
elif message.role == "user":
|
||||
converted_message = UserMessage(content=openai_content_to_content(message.content))
|
||||
converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
elif message.role == "assistant":
|
||||
converted_message = CompletionMessage(
|
||||
content=openai_content_to_content(message.content),
|
||||
tool_calls=_convert_openai_tool_calls(message.tool_calls),
|
||||
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
elif message.role == "tool":
|
||||
converted_message = ToolResponseMessage(
|
||||
role="tool",
|
||||
call_id=message.tool_call_id,
|
||||
content=openai_content_to_content(message.content),
|
||||
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown role {message.role}")
|
||||
|
|
@ -990,9 +996,9 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten
|
|||
return [openai_content_to_content(c) for c in content]
|
||||
elif hasattr(content, "type"):
|
||||
if content.type == "text":
|
||||
return TextContentItem(type="text", text=content.text)
|
||||
return TextContentItem(type="text", text=content.text) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
|
||||
elif content.type == "image_url":
|
||||
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
|
||||
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {content.type}")
|
||||
else:
|
||||
|
|
@@ -1041,9 +1047,9 @@ def convert_openai_chat_completion_choice(
completion_message=CompletionMessage(
content=choice.message.content or "", # CompletionMessage content is not optional
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union
),
logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)),
logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection
)
@ -1070,7 +1076,7 @@ async def convert_openai_chat_completion_stream(
|
|||
choice = chunk.choices[0] # assuming only one choice per chunk
|
||||
|
||||
# we assume there's only one finish_reason in the stream
|
||||
stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
|
||||
stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason
|
||||
logprobs = getattr(choice, "logprobs", None)
|
||||
|
||||
# if there's a tool call, emit an event for each tool in the list
|
||||
|
|
@ -1083,7 +1089,7 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=TextDelta(text=choice.delta.content),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1101,10 +1107,10 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=_convert_openai_tool_calls([tool_call])[0],
|
||||
tool_call=_convert_openai_tool_calls([tool_call])[0], # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
@ -1125,12 +1131,15 @@ async def convert_openai_chat_completion_stream(
|
|||
if tool_call.function.name:
|
||||
buffer["name"] = tool_call.function.name
|
||||
delta = f"{buffer['name']}("
|
||||
buffer["content"] += delta
|
||||
if buffer["content"] is not None:
|
||||
buffer["content"] += delta
|
||||
|
||||
if tool_call.function.arguments:
|
||||
delta = tool_call.function.arguments
|
||||
buffer["arguments"] += delta
|
||||
buffer["content"] += delta
|
||||
if buffer["arguments"] is not None and delta:
|
||||
buffer["arguments"] += delta
|
||||
if buffer["content"] is not None and delta:
|
||||
buffer["content"] += delta
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
|
@ -1139,7 +1148,7 @@ async def convert_openai_chat_completion_stream(
|
|||
tool_call=delta,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
elif choice.delta.content:
|
||||
|
|
@ -1147,7 +1156,7 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1155,7 +1164,8 @@ async def convert_openai_chat_completion_stream(
|
|||
logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
|
||||
if buffer["name"]:
|
||||
delta = ")"
|
||||
buffer["content"] += delta
|
||||
if buffer["content"] is not None:
|
||||
buffer["content"] += delta
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
|
|
@ -1168,16 +1178,16 @@ async def convert_openai_chat_completion_stream(
|
|||
)
|
||||
|
||||
try:
|
||||
tool_call = ToolCall(
|
||||
call_id=buffer["call_id"],
|
||||
tool_name=buffer["name"],
|
||||
arguments=buffer["arguments"],
|
||||
parsed_tool_call = ToolCall(
|
||||
call_id=buffer["call_id"] or "",
|
||||
tool_name=buffer["name"] or "",
|
||||
arguments=buffer["arguments"] or "",
|
||||
)
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=tool_call,
|
||||
tool_call=parsed_tool_call, # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
|
|
@ -1189,7 +1199,7 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=buffer["content"],
|
||||
tool_call=buffer["content"], # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
|
|
@ -1250,7 +1260,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
messages = openai_messages_to_messages(messages)
|
||||
messages = openai_messages_to_messages(messages) # type: ignore[assignment] # converted from OpenAI to LlamaStack message format
|
||||
response_format = _convert_openai_request_response_format(response_format)
|
||||
sampling_params = _convert_openai_sampling_params(
|
||||
max_tokens=max_tokens,
|
||||
|
|
@ -1259,15 +1269,15 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
)
|
||||
tool_config = _convert_openai_request_tool_config(tool_choice)
|
||||
|
||||
tools = _convert_openai_request_tools(tools)
|
||||
tools = _convert_openai_request_tools(tools) # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format
|
||||
if tool_config.tool_choice == ToolChoice.none:
|
||||
tools = []
|
||||
tools = [] # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type
|
||||
|
||||
outstanding_responses = []
|
||||
# "n" is the number of completions to generate per prompt
|
||||
n = n or 1
|
||||
for _i in range(0, n):
|
||||
response = self.chat_completion(
|
||||
response = self.chat_completion( # type: ignore[attr-defined] # mixin expects class to implement chat_completion
|
||||
model_id=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
|
|
@ -1279,7 +1289,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
outstanding_responses.append(response)
|
||||
|
||||
if stream:
|
||||
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
|
||||
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) # type: ignore[no-any-return] # mixin async generator return type too complex for mypy
|
||||
|
||||
return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
|
||||
self, model, outstanding_responses
|
||||
|
|
@@ -1295,14 +1305,16 @@ class OpenAIChatCompletionToLlamaStackMixin:
             response = await outstanding_response
             async for chunk in response:
                 event = chunk.event
-                finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
+                finish_reason = (
+                    _convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None
+                )
 
                 if isinstance(event.delta, TextDelta):
                     text_delta = event.delta.text
                     delta = OpenAIChoiceDelta(content=text_delta)
                     yield OpenAIChatCompletionChunk(
                         id=id,
-                        choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
+                        choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],  # type: ignore[arg-type]  # finish_reason Optional[str] incompatible with Literal union
                         created=int(time.time()),
                         model=model,
                         object="chat.completion.chunk",
@@ -1310,13 +1322,17 @@ class OpenAIChatCompletionToLlamaStackMixin:
                 elif isinstance(event.delta, ToolCallDelta):
                     if event.delta.parse_status == ToolCallParseStatus.succeeded:
                         tool_call = event.delta.tool_call
+                        if isinstance(tool_call, str):
+                            continue
 
                         # First chunk includes full structure
                         openai_tool_call = OpenAIChoiceDeltaToolCall(
                             index=0,
                             id=tool_call.call_id,
                             function=OpenAIChoiceDeltaToolCallFunction(
-                                name=tool_call.tool_name,
+                                name=tool_call.tool_name
+                                if isinstance(tool_call.tool_name, str)
+                                else tool_call.tool_name.value,  # type: ignore[arg-type]  # enum .value extraction on Union confuses mypy
                                 arguments="",
                             ),
                         )
@@ -1324,7 +1340,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                         yield OpenAIChatCompletionChunk(
                             id=id,
                             choices=[
-                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
+                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)  # type: ignore[arg-type]  # finish_reason Optional[str] incompatible with Literal union
                             ],
                             created=int(time.time()),
                             model=model,
@@ -1341,7 +1357,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                         yield OpenAIChatCompletionChunk(
                             id=id,
                             choices=[
-                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
+                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)  # type: ignore[arg-type]  # finish_reason Optional[str] incompatible with Literal union
                             ],
                             created=int(time.time()),
                             model=model,
@@ -1351,7 +1367,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
     async def _process_non_stream_response(
         self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
     ) -> OpenAIChatCompletion:
-        choices = []
+        choices: list[OpenAIChatCompletionChoice] = []
         for outstanding_response in outstanding_responses:
            response = await outstanding_response
            completion_message = response.completion_message
@@ -1360,14 +1376,14 @@ class OpenAIChatCompletionToLlamaStackMixin:
 
             choice = OpenAIChatCompletionChoice(
                 index=len(choices),
-                message=message,
+                message=message,  # type: ignore[arg-type]  # OpenAIChatCompletionMessage union incompatible with narrower Message type
                 finish_reason=finish_reason,
             )
-            choices.append(choice)
+            choices.append(choice)  # type: ignore[arg-type]  # OpenAIChatCompletionChoice type annotation mismatch
 
         return OpenAIChatCompletion(
             id=f"chatcmpl-{uuid.uuid4()}",
-            choices=choices,
+            choices=choices,  # type: ignore[arg-type]  # list[OpenAIChatCompletionChoice] union incompatible
             created=int(time.time()),
             model=model,
             object="chat.completion",
@@ -196,6 +196,7 @@ def make_overlapped_chunks(
         chunks.append(
             Chunk(
                 content=chunk,
+                chunk_id=chunk_id,
                 metadata=chunk_metadata,
                 chunk_metadata=backend_chunk_metadata,
             )
@@ -430,6 +430,32 @@ def _unwrap_generic_list(typ: type[list[T]]) -> type[T]:
     return list_type  # type: ignore[no-any-return]
 
 
+def is_generic_sequence(typ: object) -> bool:
+    "True if the specified type is a generic Sequence, i.e. `Sequence[T]`."
+    import collections.abc
+
+    typ = unwrap_annotated_type(typ)
+    return typing.get_origin(typ) is collections.abc.Sequence
+
+
+def unwrap_generic_sequence(typ: object) -> type:
+    """
+    Extracts the item type of a Sequence type.
+
+    :param typ: The Sequence type `Sequence[T]`.
+    :returns: The item type `T`.
+    """
+
+    return rewrap_annotated_type(_unwrap_generic_sequence, typ)  # type: ignore[arg-type]
+
+
+def _unwrap_generic_sequence(typ: object) -> type:
+    "Extracts the item type of a Sequence type (e.g. returns `T` for `Sequence[T]`)."
+
+    (sequence_type,) = typing.get_args(typ)  # unpack single tuple element
+    return sequence_type  # type: ignore[no-any-return]
+
+
 def is_generic_set(typ: object) -> TypeGuard[type[set]]:
     "True if the specified type is a generic set, i.e. `Set[T]`."
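For orientation only (this note is not part of the diff): the two helpers added above reduce to stdlib `typing` introspection. A minimal standalone sketch of the behaviour they rely on, namely `get_origin()` identifying the `Sequence` alias and `get_args()` yielding the single item type; it mirrors, but is not, the library code:

# Standalone illustration of the introspection used by is_generic_sequence /
# unwrap_generic_sequence; requires Python 3.9+ for Sequence[int] subscripting.
import collections.abc
import typing
from collections.abc import Sequence

seq_type = Sequence[int]

# get_origin() returns the runtime class behind the parameterized alias.
assert typing.get_origin(seq_type) is collections.abc.Sequence

# get_args() returns the type arguments; for Sequence[T] that is the 1-tuple (T,).
(item_type,) = typing.get_args(seq_type)
assert item_type is int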
@@ -18,10 +18,12 @@ from .inspection import (
     TypeLike,
     is_generic_dict,
     is_generic_list,
+    is_generic_sequence,
     is_type_optional,
     is_type_union,
     unwrap_generic_dict,
     unwrap_generic_list,
+    unwrap_generic_sequence,
     unwrap_optional_type,
     unwrap_union_types,
 )
@@ -155,24 +157,28 @@ def python_type_to_name(data_type: TypeLike, force: bool = False) -> str:
     if metadata is not None:
         # type is Annotated[T, ...]
         arg = typing.get_args(data_type)[0]
-        return python_type_to_name(arg)
+        return python_type_to_name(arg, force=force)
 
     if force:
         # generic types
         if is_type_optional(data_type, strict=True):
-            inner_name = python_type_to_name(unwrap_optional_type(data_type))
+            inner_name = python_type_to_name(unwrap_optional_type(data_type), force=True)
             return f"Optional__{inner_name}"
         elif is_generic_list(data_type):
-            item_name = python_type_to_name(unwrap_generic_list(data_type))
+            item_name = python_type_to_name(unwrap_generic_list(data_type), force=True)
             return f"List__{item_name}"
+        elif is_generic_sequence(data_type):
+            # Treat Sequence the same as List for schema generation purposes
+            item_name = python_type_to_name(unwrap_generic_sequence(data_type), force=True)
+            return f"List__{item_name}"
         elif is_generic_dict(data_type):
             key_type, value_type = unwrap_generic_dict(data_type)
-            key_name = python_type_to_name(key_type)
-            value_name = python_type_to_name(value_type)
+            key_name = python_type_to_name(key_type, force=True)
+            value_name = python_type_to_name(value_type, force=True)
             return f"Dict__{key_name}__{value_name}"
         elif is_type_union(data_type):
             member_types = unwrap_union_types(data_type)
-            member_names = "__".join(python_type_to_name(member_type) for member_type in member_types)
+            member_names = "__".join(python_type_to_name(member_type, force=True) for member_type in member_types)
             return f"Union__{member_names}"
 
     # named system or user-defined type
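Side note (not part of the diff): if the new branch behaves as written, `Sequence[T]` should yield the same synthesized name as `list[T]` when `force=True`. A hedged sketch of that naming convention; `sketch_type_name` below is illustrative and is not the library's `python_type_to_name`:

# Hypothetical re-creation of the List__ naming rule for list/Sequence item types.
import collections.abc
import typing


def sketch_type_name(typ: object) -> str:
    origin = typing.get_origin(typ)
    if origin in (list, collections.abc.Sequence):
        (item,) = typing.get_args(typ)  # single item type, as in the diff
        return f"List__{sketch_type_name(item)}"
    return getattr(typ, "__name__", str(typ))


print(sketch_type_name(list[str]))                      # List__str
print(sketch_type_name(collections.abc.Sequence[str]))  # List__str (new branch)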
@@ -111,7 +111,7 @@ def get_class_property_docstrings(
 def docstring_to_schema(data_type: type) -> Schema:
     short_description, long_description = get_class_docstrings(data_type)
     schema: Schema = {
-        "title": python_type_to_name(data_type),
+        "title": python_type_to_name(data_type, force=True),
     }
 
     description = "\n".join(filter(None, [short_description, long_description]))
@@ -417,6 +417,10 @@ class JsonSchemaGenerator:
         if origin_type is list:
             (list_type,) = typing.get_args(typ)  # unpack single tuple element
             return {"type": "array", "items": self.type_to_schema(list_type)}
+        elif origin_type is collections.abc.Sequence:
+            # Treat Sequence the same as list for JSON schema (both are arrays)
+            (sequence_type,) = typing.get_args(typ)  # unpack single tuple element
+            return {"type": "array", "items": self.type_to_schema(sequence_type)}
         elif origin_type is dict:
             key_type, value_type = typing.get_args(typ)
             if not (key_type is str or key_type is int or is_type_enum(key_type)):
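Side note (not part of the diff): assuming the new generator branch mirrors the list case above, a `Sequence[str]` field should serialize to a plain JSON array schema. A minimal, self-contained sketch of that mapping; `sketch_array_schema` is illustrative only and is not the `JsonSchemaGenerator` API:

# Illustrative mapping only; handles just str/int item types to stay short.
import collections.abc
import typing

_PRIMITIVE_SCHEMAS = {str: {"type": "string"}, int: {"type": "integer"}}


def sketch_array_schema(typ: object) -> dict:
    origin = typing.get_origin(typ)
    if origin in (list, collections.abc.Sequence):
        (item_type,) = typing.get_args(typ)
        return {"type": "array", "items": _PRIMITIVE_SCHEMAS[item_type]}
    raise TypeError(f"not a list/Sequence type: {typ!r}")


assert sketch_array_schema(collections.abc.Sequence[str]) == {"type": "array", "items": {"type": "string"}}
assert sketch_array_schema(list[int]) == {"type": "array", "items": {"type": "integer"}}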