Merge branch 'main' into pr1573

Xi Yan 2025-03-12 18:23:39 -07:00
commit 8942071b3b
10 changed files with 171 additions and 123 deletions

View file

@@ -370,7 +370,7 @@ class AgentTurnResumeRequest(BaseModel):
     agent_id: str
     session_id: str
     turn_id: str
-    tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]]
+    tool_responses: List[ToolResponse]
     stream: Optional[bool] = False
@@ -449,7 +449,7 @@ class Agents(Protocol):
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
+        tool_responses: List[ToolResponse],
         stream: Optional[bool] = False,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
         """Resume an agent turn with executed tool call responses.
@@ -460,7 +460,6 @@ class Agents(Protocol):
         :param session_id: The ID of the session to resume.
         :param turn_id: The ID of the turn to resume.
         :param tool_responses: The tool call responses to resume the turn with.
-            NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
         :param stream: Whether to stream the response.
         :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
         """

View file

@@ -96,6 +96,13 @@ class MetricEvent(EventCommon):
     unit: str


+@json_schema_type
+class MetricInResponse(BaseModel):
+    metric: str
+    value: Union[int, float]
+    unit: Optional[str] = None
+
+
 # This is a short term solution to allow the inference API to return metrics
 # The ideal way to do this is to have a way for all response types to include metrics
 # and all metric events logged to the telemetry API to be included with the response
@@ -117,7 +124,7 @@ class MetricEvent(EventCommon):
 class MetricResponseMixin(BaseModel):
-    metrics: Optional[List[MetricEvent]] = None
+    metrics: Optional[List[MetricInResponse]] = None


 @json_schema_type
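
MetricInResponse carries only what a response payload needs, without the telemetry context that MetricEvent inherits from EventCommon. A self-contained sketch of the model exactly as added above (the @json_schema_type decorator is omitted here for self-containedness, and the metric name is illustrative):

    from typing import Optional, Union
    from pydantic import BaseModel

    class MetricInResponse(BaseModel):
        metric: str
        value: Union[int, float]
        unit: Optional[str] = None

    m = MetricInResponse(metric="total_tokens", value=150)  # illustrative name/value
    print(m.model_dump())  # {'metric': 'total_tokens', 'value': 150, 'unit': None}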

View file

@@ -48,7 +48,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.telemetry import MetricEvent, Telemetry
+from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -206,12 +206,12 @@ class InferenceRouter(Inference):
         completion_tokens: int,
         total_tokens: int,
         model: Model,
-    ) -> List[MetricEvent]:
+    ) -> List[MetricInResponse]:
         metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return metrics
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def _count_tokens(
         self,
@@ -238,7 +238,6 @@ class InferenceRouter(Inference):
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
-            "core",
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
         if sampling_params is None:
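
The pattern here: the full MetricEvent still goes to the telemetry sink, while only the trimmed MetricInResponse travels back inline with the inference response. A standalone sketch of that split, using only names this diff itself confirms (log_event, metric.metric, metric.value, the telemetry import):

    # Sketch of the split: rich events to telemetry, trimmed metrics to the caller.
    from typing import List
    from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry

    async def log_and_trim(telemetry: Telemetry, metrics: List[MetricEvent]) -> List[MetricInResponse]:
        if telemetry:
            for metric in metrics:
                await telemetry.log_event(metric)  # full event, with telemetry context
        # only name and value are returned inline with the response
        return [MetricInResponse(metric=m.metric, value=m.value) for m in metrics]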

View file

@@ -19,7 +19,7 @@ def preserve_contexts_async_generator(
     and we need to preserve the context across the event loop boundary.
     """

-    async def wrapper():
+    async def wrapper() -> AsyncGenerator[T, None]:
         while True:
             try:
                 item = await gen.__anext__()
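
For context, preserve_contexts_async_generator re-applies captured ContextVar values each time the wrapped generator resumes, since resumption may happen under a different task context. A simplified, self-contained sketch of the pattern (not the library's exact code; the function name here is a stand-in):

    import contextvars
    from typing import AsyncGenerator, List, TypeVar

    T = TypeVar("T")

    def preserve_contexts(
        gen: AsyncGenerator[T, None],
        ctx_vars: List[contextvars.ContextVar],
    ) -> AsyncGenerator[T, None]:
        snapshot = {cv: cv.get() for cv in ctx_vars}  # capture at wrap time

        async def wrapper() -> AsyncGenerator[T, None]:
            while True:
                try:
                    for cv, value in snapshot.items():
                        cv.set(value)  # restore before each resume
                    yield await gen.__anext__()
                except StopAsyncIteration:
                    break

        return wrapper()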

View file

@@ -218,18 +218,10 @@ class ChatAgent(ShieldRunnerMixin):
         steps = []
         messages = await self.get_messages_from_turns(turns)
         if is_resume:
-            if isinstance(request.tool_responses[0], ToolResponseMessage):
-                tool_response_messages = request.tool_responses
-                tool_responses = [
-                    ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
-                    for x in request.tool_responses
-                ]
-            else:
-                tool_response_messages = [
-                    ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
-                    for x in request.tool_responses
-                ]
-                tool_responses = request.tool_responses
+            tool_response_messages = [
+                ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
+                for x in request.tool_responses
+            ]
             messages.extend(tool_response_messages)
             last_turn = turns[-1]
             last_turn_messages = self.turn_to_messages(last_turn)
@@ -252,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin):
                 step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
                 turn_id=request.turn_id,
                 tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
-                tool_responses=tool_responses,
+                tool_responses=request.tool_responses,
                 completed_at=now,
                 started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
             )
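
With a single input type, the isinstance branching disappears: resume always receives ToolResponse objects, converts them to ToolResponseMessage entries for the chat history, and records the originals on the tool-execution step. The two models share call_id/tool_name/content in this diff, so the mapping is one-to-one; a tiny stand-in sketch (the real models are pydantic types in the agents/inference APIs):

    from dataclasses import dataclass

    @dataclass
    class ToolResponse:          # stand-in for the API model
        call_id: str
        tool_name: str
        content: str

    @dataclass
    class ToolResponseMessage:   # stand-in for the API model
        call_id: str
        tool_name: str
        content: str

    responses = [ToolResponse("call-1", "get_weather", "72F")]
    history = [ToolResponseMessage(r.call_id, r.tool_name, r.content) for r in responses]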

View file

@@ -172,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
+        tool_responses: List[ToolResponse],
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
         request = AgentTurnResumeRequest(
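
The constructor call is truncated in the hunk above. Judging from the AgentTurnResumeRequest fields shown in the first file, it presumably forwards the arguments one-to-one; an assumed completion, for illustration only:

    # Assumed completion of the truncated call, based on AgentTurnResumeRequest's fields.
    request = AgentTurnResumeRequest(
        agent_id=agent_id,
        session_id=session_id,
        turn_id=turn_id,
        tool_responses=tool_responses,
        stream=stream,
    )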