mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
refactor(agents): migrate to OpenAI chat completions API
Replace chat_completion calls with openai_chat_completion to eliminate dependency on legacy inference APIs.
This commit is contained in:
parent e980436a2e
commit c6e980a993

1 changed file with 47 additions and 15 deletions
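In outline, the change replaces a single legacy `chat_completion` call with a convert → call → convert-back pipeline. A minimal sketch of that shape (the helper names and the `enable_incremental_tool_calls` flag come from the diff below; `run_inference` and its parameters are placeholders for illustration):

    from llama_stack.providers.utils.inference.openai_compat import (
        convert_message_to_openai_dict_new,
        convert_openai_chat_completion_stream,
        convert_tooldef_to_openai_tool,
    )

    async def run_inference(inference_api, model, input_messages, tool_defs):
        # 1. Convert Llama Stack messages and tool definitions to OpenAI wire format.
        openai_messages = [await convert_message_to_openai_dict_new(m) for m in input_messages]
        openai_tools = [convert_tooldef_to_openai_tool(t) for t in tool_defs or []]

        # 2. Call the OpenAI-compatible endpoint instead of the legacy chat_completion.
        openai_stream = await inference_api.openai_chat_completion(
            model=model,
            messages=openai_messages,
            tools=openai_tools or None,
            stream=True,
        )

        # 3. Convert the OpenAI stream back into Llama Stack chunks for the agent loop.
        return convert_openai_chat_completion_stream(
            openai_stream, enable_incremental_tool_calls=True
        )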
@@ -68,6 +68,11 @@ from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     ToolCall,
 )
+from llama_stack.providers.utils.inference.openai_compat import (
+    convert_message_to_openai_dict_new,
+    convert_openai_chat_completion_stream,
+    convert_tooldef_to_openai_tool,
+)
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.telemetry import tracing
 
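For reference, the newly imported converters target the standard OpenAI chat message shape. A hand-written example of that format (illustrative values; the converter's exact output may differ):

    # A plain user message and an assistant message carrying a tool call,
    # in OpenAI wire format (function arguments are a JSON-encoded string).
    user_message = {"role": "user", "content": "What's the weather in Paris?"}
    assistant_message = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    }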
@@ -177,12 +182,12 @@ class ChatAgent(ShieldRunnerMixin):
         return messages
 
     async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
+        turn_id = str(uuid.uuid4())
         span = tracing.get_current_span()
         if span:
             span.set_attribute("session_id", request.session_id)
             span.set_attribute("agent_id", self.agent_id)
             span.set_attribute("request", request.model_dump_json())
-            turn_id = str(uuid.uuid4())
             span.set_attribute("turn_id", turn_id)
             if self.agent_config.name:
                 span.set_attribute("agent_name", self.agent_config.name)
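Hoisting `turn_id` out of the `if span:` block guarantees the id is bound even when no tracing span is active; under the old placement, any use of `turn_id` outside that block would raise. A minimal illustration of the pattern (hypothetical functions, not code from the diff):

    import uuid

    def broken(span):
        if span:
            turn_id = str(uuid.uuid4())  # bound only when span is truthy
        return turn_id  # UnboundLocalError when span is None

    def fixed(span):
        turn_id = str(uuid.uuid4())  # bound unconditionally, as in the diff
        if span:
            span.set_attribute("turn_id", turn_id)
        return turn_id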
@@ -505,26 +510,55 @@ class ChatAgent(ShieldRunnerMixin):
 
         tool_calls = []
         content = ""
-        stop_reason = None
+        stop_reason: StopReason | None = None
 
         async with tracing.span("inference") as span:
             if self.agent_config.name:
                 span.set_attribute("agent_name", self.agent_config.name)
-            async for chunk in await self.inference_api.chat_completion(
-                self.agent_config.model,
-                input_messages,
-                tools=self.tool_defs,
-                tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
+            # Convert messages to OpenAI format
+            openai_messages: list[dict] = []
+            for message in input_messages:
+                openai_messages.append(await convert_message_to_openai_dict_new(message))
+
+            # Convert tool definitions to OpenAI format
+            openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
+
+            # Extract tool_choice from tool_config for OpenAI compatibility
+            # Note: tool_choice can only be provided when tools are also provided
+            tool_choice = None
+            if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
+                tc = self.agent_config.tool_config.tool_choice
+                tool_choice = tc.value if hasattr(tc, "value") else str(tc)
+
+            # Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
+            temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
+            top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
+            max_tokens = getattr(sampling_params, "max_tokens", None)
+
+            # Use OpenAI chat completion
+            openai_stream = await self.inference_api.openai_chat_completion(
+                model=self.agent_config.model,
+                messages=openai_messages,
+                tools=openai_tools if openai_tools else None,
+                tool_choice=tool_choice,
                 response_format=self.agent_config.response_format,
+                temperature=temperature,
+                top_p=top_p,
+                max_tokens=max_tokens,
                 stream=True,
-                sampling_params=sampling_params,
-                tool_config=self.agent_config.tool_config,
-            ):
+            )
+
+            # Convert OpenAI stream back to Llama Stack format
+            response_stream = convert_openai_chat_completion_stream(
+                openai_stream, enable_incremental_tool_calls=True
+            )
+
+            async for chunk in response_stream:
                 event = chunk.event
                 if event.event_type == ChatCompletionResponseEventType.start:
                     continue
                 elif event.event_type == ChatCompletionResponseEventType.complete:
-                    stop_reason = StopReason.end_of_turn
+                    stop_reason = event.stop_reason or StopReason.end_of_turn
                     continue
 
                 delta = event.delta
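For context, `convert_tooldef_to_openai_tool` targets the standard OpenAI function-calling schema, and `tool_choice` is the standard OpenAI parameter. A hand-written illustration of the shapes involved (not the converter's literal output):

    # OpenAI function-tool schema, as expected by the `tools` parameter.
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }

    # tool_choice accepts "auto", "none", "required", or a specific function;
    # the diff passes the enum's .value (a string) only when tools are present,
    # since tool_choice cannot be provided without tools.
    tool_choice = "auto"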
@@ -533,7 +567,7 @@ class ChatAgent(ShieldRunnerMixin):
                     tool_calls.append(delta.tool_call)
                 elif delta.parse_status == ToolCallParseStatus.failed:
                     # If we cannot parse the tools, set the content to the unparsed raw text
-                    content = delta.tool_call
+                    content = str(delta.tool_call)
                 if stream:
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
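A note on the one-line fix above: on `ToolCallParseStatus.failed`, `delta.tool_call` carries the unparsed raw text, which may not be a plain string; the coercion keeps `content` a `str` for downstream message assembly. A small sketch with a hypothetical delta object:

    # Hypothetical stand-in for a failed-parse delta; the real one comes from
    # the converted stream.
    class Delta:
        parse_status = "failed"
        tool_call = '{"name": "get_weath'  # truncated raw text from the model

    content = str(Delta.tool_call)  # safe even if tool_call were not a str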
@@ -560,9 +594,7 @@ class ChatAgent(ShieldRunnerMixin):
                 else:
                     raise ValueError(f"Unexpected delta type {type(delta)}")
 
-            if event.stop_reason is not None:
-                stop_reason = event.stop_reason
-            span.set_attribute("stop_reason", stop_reason)
+            span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
             span.set_attribute(
                 "input",
                 json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
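The simplified telemetry line leans on the `stop_reason: StopReason | None = None` default introduced earlier in this commit: the loop records the stop reason as events stream in, and the span attribute falls back to `end_of_turn` if no `complete` event set it. A short sketch of the fallback, assuming `StopReason` comes from `llama_stack.models.llama.datatypes` like the other imports in this file:

    from llama_stack.models.llama.datatypes import StopReason

    stop_reason: StopReason | None = None  # set when a `complete` event arrives

    # Old code re-derived the value from the last event after the loop; the
    # new code records whatever the loop captured, with a safe default.
    effective = stop_reason or StopReason.end_of_turn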