mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
refactor(agents): migrate to OpenAI chat completions API (#3323)
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Test Llama Stack Build / build-single-provider (push) Failing after 2s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 3s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 8s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 17s
Python Package Build Test / build (3.13) (push) Failing after 14s
Test Llama Stack Build / generate-matrix (push) Successful in 18s
Unit Tests / unit-tests (3.13) (push) Failing after 14s
Test Llama Stack Build / build (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 44s
Pre-commit / pre-commit (push) Successful in 1m16s
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Test Llama Stack Build / build-single-provider (push) Failing after 2s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 3s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 8s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 17s
Python Package Build Test / build (3.13) (push) Failing after 14s
Test Llama Stack Build / generate-matrix (push) Successful in 18s
Unit Tests / unit-tests (3.13) (push) Failing after 14s
Test Llama Stack Build / build (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 44s
Pre-commit / pre-commit (push) Successful in 1m16s
This commit is contained in:
parent
426dc54883
commit
7e48cc48bc
32 changed files with 12226 additions and 15 deletions
|
@ -50,6 +50,12 @@ from llama_stack.apis.inference import (
|
|||
CompletionMessage,
|
||||
Inference,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
|
@ -68,6 +74,11 @@ from llama_stack.models.llama.datatypes import (
|
|||
BuiltinTool,
|
||||
ToolCall,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict_new,
|
||||
convert_openai_chat_completion_stream,
|
||||
convert_tooldef_to_openai_tool,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
|
@ -177,12 +188,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
turn_id = str(uuid.uuid4())
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
turn_id = str(uuid.uuid4())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
@ -505,26 +516,93 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
tool_calls = []
|
||||
content = ""
|
||||
stop_reason = None
|
||||
stop_reason: StopReason | None = None
|
||||
|
||||
async with tracing.span("inference") as span:
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=self.tool_defs,
|
||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||
|
||||
def _serialize_nested(value):
|
||||
"""Recursively serialize nested Pydantic models to dicts."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(mode="json")
|
||||
elif isinstance(value, dict):
|
||||
return {k: _serialize_nested(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [_serialize_nested(item) for item in value]
|
||||
else:
|
||||
return value
|
||||
|
||||
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
|
||||
# Serialize any nested Pydantic models to plain dicts
|
||||
openai_msg = _serialize_nested(openai_msg)
|
||||
|
||||
role = openai_msg.get("role")
|
||||
if role == "user":
|
||||
return OpenAIUserMessageParam(**openai_msg)
|
||||
elif role == "system":
|
||||
return OpenAISystemMessageParam(**openai_msg)
|
||||
elif role == "assistant":
|
||||
return OpenAIAssistantMessageParam(**openai_msg)
|
||||
elif role == "tool":
|
||||
return OpenAIToolMessageParam(**openai_msg)
|
||||
elif role == "developer":
|
||||
return OpenAIDeveloperMessageParam(**openai_msg)
|
||||
else:
|
||||
raise ValueError(f"Unknown message role: {role}")
|
||||
|
||||
# Convert messages to OpenAI format
|
||||
openai_messages: list[OpenAIMessageParam] = [
|
||||
_add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages
|
||||
]
|
||||
|
||||
# Convert tool definitions to OpenAI format
|
||||
openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
|
||||
|
||||
# Extract tool_choice from tool_config for OpenAI compatibility
|
||||
# Note: tool_choice can only be provided when tools are also provided
|
||||
tool_choice = None
|
||||
if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
|
||||
tc = self.agent_config.tool_config.tool_choice
|
||||
tool_choice_str = tc.value if hasattr(tc, "value") else str(tc)
|
||||
# Convert tool_choice to OpenAI format
|
||||
if tool_choice_str in ("auto", "none", "required"):
|
||||
tool_choice = tool_choice_str
|
||||
else:
|
||||
# It's a specific tool name, wrap it in the proper format
|
||||
tool_choice = {"type": "function", "function": {"name": tool_choice_str}}
|
||||
|
||||
# Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
|
||||
temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
|
||||
top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
|
||||
max_tokens = getattr(sampling_params, "max_tokens", None)
|
||||
|
||||
# Use OpenAI chat completion
|
||||
openai_stream = await self.inference_api.openai_chat_completion(
|
||||
model=self.agent_config.model,
|
||||
messages=openai_messages,
|
||||
tools=openai_tools if openai_tools else None,
|
||||
tool_choice=tool_choice,
|
||||
response_format=self.agent_config.response_format,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
stream=True,
|
||||
sampling_params=sampling_params,
|
||||
tool_config=self.agent_config.tool_config,
|
||||
):
|
||||
)
|
||||
|
||||
# Convert OpenAI stream back to Llama Stack format
|
||||
response_stream = convert_openai_chat_completion_stream(
|
||||
openai_stream, enable_incremental_tool_calls=True
|
||||
)
|
||||
|
||||
async for chunk in response_stream:
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.start:
|
||||
continue
|
||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
stop_reason = event.stop_reason or StopReason.end_of_turn
|
||||
continue
|
||||
|
||||
delta = event.delta
|
||||
|
@ -533,7 +611,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_calls.append(delta.tool_call)
|
||||
elif delta.parse_status == ToolCallParseStatus.failed:
|
||||
# If we cannot parse the tools, set the content to the unparsed raw text
|
||||
content = delta.tool_call
|
||||
content = str(delta.tool_call)
|
||||
if stream:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -560,9 +638,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
raise ValueError(f"Unexpected delta type {type(delta)}")
|
||||
|
||||
if event.stop_reason is not None:
|
||||
stop_reason = event.stop_reason
|
||||
span.set_attribute("stop_reason", stop_reason)
|
||||
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
|
||||
span.set_attribute(
|
||||
"input",
|
||||
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue