diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index fde38515b..209f5dbd8 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -68,6 +68,11 @@ from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     ToolCall,
 )
+from llama_stack.providers.utils.inference.openai_compat import (
+    convert_message_to_openai_dict_new,
+    convert_openai_chat_completion_stream,
+    convert_tooldef_to_openai_tool,
+)
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.telemetry import tracing
 
@@ -177,12 +182,12 @@ class ChatAgent(ShieldRunnerMixin):
         return messages
 
     async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
+        turn_id = str(uuid.uuid4())
         span = tracing.get_current_span()
         if span:
             span.set_attribute("session_id", request.session_id)
             span.set_attribute("agent_id", self.agent_id)
             span.set_attribute("request", request.model_dump_json())
-            turn_id = str(uuid.uuid4())
             span.set_attribute("turn_id", turn_id)
             if self.agent_config.name:
                 span.set_attribute("agent_name", self.agent_config.name)
@@ -505,26 +510,55 @@ class ChatAgent(ShieldRunnerMixin):
 
         tool_calls = []
         content = ""
-        stop_reason = None
+        stop_reason: StopReason | None = None
 
         async with tracing.span("inference") as span:
             if self.agent_config.name:
                 span.set_attribute("agent_name", self.agent_config.name)
-            async for chunk in await self.inference_api.chat_completion(
-                self.agent_config.model,
-                input_messages,
-                tools=self.tool_defs,
-                tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
+            # Convert messages to OpenAI format
+            openai_messages: list[dict] = []
+            for message in input_messages:
+                openai_messages.append(await convert_message_to_openai_dict_new(message))
+
+            # Convert tool definitions to OpenAI format
+            openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
+
+            # Extract tool_choice from tool_config for OpenAI compatibility
+            # Note: tool_choice can only be provided when tools are also provided
+            tool_choice = None
+            if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
+                tc = self.agent_config.tool_config.tool_choice
+                tool_choice = tc.value if hasattr(tc, "value") else str(tc)
+
+            # Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
+            temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
+            top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
+            max_tokens = getattr(sampling_params, "max_tokens", None)
+
+            # Use OpenAI chat completion
+            openai_stream = await self.inference_api.openai_chat_completion(
+                model=self.agent_config.model,
+                messages=openai_messages,
+                tools=openai_tools if openai_tools else None,
+                tool_choice=tool_choice,
                 response_format=self.agent_config.response_format,
+                temperature=temperature,
+                top_p=top_p,
+                max_tokens=max_tokens,
                 stream=True,
-                sampling_params=sampling_params,
-                tool_config=self.agent_config.tool_config,
-            ):
+            )
+
+            # Convert OpenAI stream back to Llama Stack format
+            response_stream = convert_openai_chat_completion_stream(
+                openai_stream, enable_incremental_tool_calls=True
+            )
+
+            async for chunk in response_stream:
                 event = chunk.event
                 if event.event_type == ChatCompletionResponseEventType.start:
                     continue
                 elif event.event_type == ChatCompletionResponseEventType.complete:
-                    stop_reason = StopReason.end_of_turn
+                    stop_reason = event.stop_reason or StopReason.end_of_turn
                     continue
 
                 delta = event.delta
@@ -533,7 +567,7 @@ class ChatAgent(ShieldRunnerMixin):
                     tool_calls.append(delta.tool_call)
                 elif delta.parse_status == ToolCallParseStatus.failed:
                     # If we cannot parse the tools, set the content to the unparsed raw text
-                    content = delta.tool_call
+                    content = str(delta.tool_call)
                 if stream:
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
@@ -560,9 +594,7 @@ class ChatAgent(ShieldRunnerMixin):
                 else:
                     raise ValueError(f"Unexpected delta type {type(delta)}")
 
-                if event.stop_reason is not None:
-                    stop_reason = event.stop_reason
-            span.set_attribute("stop_reason", stop_reason)
+            span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
             span.set_attribute(
                 "input",
                 json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
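For reviewers, here is a minimal standalone sketch of the inference path this patch wires into `ChatAgent`. The `run_inference` helper and its parameters are hypothetical stand-ins for the agent attributes used above (they are not part of the patch); the converters are the ones the patch imports, and the sketch assumes a llama-stack environment where they behave as used in the diff:

```python
# Illustrative sketch only: `run_inference` and its parameters are hypothetical
# stand-ins for ChatAgent attributes; the converters are those the patch imports.
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict_new,
    convert_openai_chat_completion_stream,
    convert_tooldef_to_openai_tool,
)


async def run_inference(inference_api, model, input_messages, tool_defs, tool_config, sampling_params):
    """Yield Llama Stack chunks by round-tripping through the OpenAI-compat API."""
    # Messages and tool definitions must be translated before the OpenAI call.
    openai_messages = [await convert_message_to_openai_dict_new(m) for m in input_messages]
    openai_tools = [convert_tooldef_to_openai_tool(t) for t in (tool_defs or [])]

    # tool_choice is only valid when tools are present, hence the guard.
    tool_choice = None
    if openai_tools and tool_config and tool_config.tool_choice:
        tc = tool_config.tool_choice
        tool_choice = tc.value if hasattr(tc, "value") else str(tc)

    # Sampling strategy fields map onto the flat OpenAI knobs.
    strategy = getattr(sampling_params, "strategy", None)
    openai_stream = await inference_api.openai_chat_completion(
        model=model,
        messages=openai_messages,
        tools=openai_tools or None,
        tool_choice=tool_choice,
        temperature=getattr(strategy, "temperature", None),
        top_p=getattr(strategy, "top_p", None),
        max_tokens=getattr(sampling_params, "max_tokens", None),
        stream=True,
    )

    # Converting back lets the existing event-processing loop stay unchanged.
    async for chunk in convert_openai_chat_completion_stream(
        openai_stream, enable_incremental_tool_calls=True
    ):
        yield chunk
```

Note how the patch confines every OpenAI-specific shape change to this conversion boundary, leaving the downstream event-processing loop untouched; that is also why `stop_reason` is now taken from the converted `complete` event, with `StopReason.end_of_turn` only as a fallback, instead of being hard-coded.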