Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-06 04:34:57 +00:00)
fully convert the input message types to the openai message types
Previously the message types were passed around as TypedDicts and converted to plain dicts; the input messages now flow through as typed OpenAIMessageParam models.
This commit is contained in:
  parent 0327ef3daf
  commit 0399f89c1f
32 changed files with 12183 additions and 4 deletions
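
In practice the change swaps untyped dict payloads for the Pydantic message models exported by llama_stack.apis.inference, so a malformed message fails at construction time instead of when the provider rejects the request. A minimal before/after sketch (the message contents are illustrative):

from llama_stack.apis.inference import (
    OpenAIMessageParam,
    OpenAISystemMessageParam,
    OpenAIUserMessageParam,
)

# Before: plain dicts -- a typo in a key or role went unnoticed until
# request time.
legacy_messages: list[dict] = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# After: typed Pydantic models, validated as soon as they are built.
typed_messages: list[OpenAIMessageParam] = [
    OpenAISystemMessageParam(role="system", content="You are a helpful assistant."),
    OpenAIUserMessageParam(role="user", content="Hello!"),
]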
@@ -50,6 +50,12 @@ from llama_stack.apis.inference import (
     CompletionMessage,
     Inference,
     Message,
+    OpenAIAssistantMessageParam,
+    OpenAIDeveloperMessageParam,
+    OpenAIMessageParam,
+    OpenAISystemMessageParam,
+    OpenAIToolMessageParam,
+    OpenAIUserMessageParam,
     SamplingParams,
     StopReason,
     SystemMessage,
@@ -515,10 +521,42 @@ class ChatAgent(ShieldRunnerMixin):
         async with tracing.span("inference") as span:
             if self.agent_config.name:
                 span.set_attribute("agent_name", self.agent_config.name)
+
+            def _serialize_nested(value):
+                """Recursively serialize nested Pydantic models to dicts."""
+                from pydantic import BaseModel
+
+                if isinstance(value, BaseModel):
+                    return value.model_dump(mode="json")
+                elif isinstance(value, dict):
+                    return {k: _serialize_nested(v) for k, v in value.items()}
+                elif isinstance(value, list):
+                    return [_serialize_nested(item) for item in value]
+                else:
+                    return value
+
+            def _add_type(openai_msg: dict) -> OpenAIMessageParam:
+                # Serialize any nested Pydantic models to plain dicts
+                openai_msg = _serialize_nested(openai_msg)
+
+                role = openai_msg.get("role")
+                if role == "user":
+                    return OpenAIUserMessageParam(**openai_msg)
+                elif role == "system":
+                    return OpenAISystemMessageParam(**openai_msg)
+                elif role == "assistant":
+                    return OpenAIAssistantMessageParam(**openai_msg)
+                elif role == "tool":
+                    return OpenAIToolMessageParam(**openai_msg)
+                elif role == "developer":
+                    return OpenAIDeveloperMessageParam(**openai_msg)
+                else:
+                    raise ValueError(f"Unknown message role: {role}")
+
             # Convert messages to OpenAI format
-            openai_messages: list[dict] = []
-            for message in input_messages:
-                openai_messages.append(await convert_message_to_openai_dict_new(message))
+            openai_messages: list[OpenAIMessageParam] = [
+                _add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages
+            ]
 
             # Convert tool definitions to OpenAI format
             openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
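
Why the _serialize_nested pass before unpacking: convert_message_to_openai_dict_new can leave Pydantic content-part models nested inside the returned dict, and splatting raw models into a different Pydantic class can trip validation. A self-contained sketch of the behavior (TextPart is a hypothetical stand-in for the real content-part types):

from pydantic import BaseModel

def serialize_nested(value):
    """Recursively turn nested Pydantic models into plain JSON-ready values."""
    if isinstance(value, BaseModel):
        return value.model_dump(mode="json")
    if isinstance(value, dict):
        return {k: serialize_nested(v) for k, v in value.items()}
    if isinstance(value, list):
        return [serialize_nested(item) for item in value]
    return value

class TextPart(BaseModel):  # hypothetical stand-in for a content-part model
    type: str = "text"
    text: str

payload = {"role": "user", "content": [TextPart(text="hi")]}
assert serialize_nested(payload) == {
    "role": "user",
    "content": [{"type": "text", "text": "hi"}],
}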
@@ -528,7 +566,13 @@ class ChatAgent(ShieldRunnerMixin):
             tool_choice = None
             if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
                 tc = self.agent_config.tool_config.tool_choice
-                tool_choice = tc.value if hasattr(tc, "value") else str(tc)
+                tool_choice_str = tc.value if hasattr(tc, "value") else str(tc)
+                # Convert tool_choice to OpenAI format
+                if tool_choice_str in ("auto", "none", "required"):
+                    tool_choice = tool_choice_str
+                else:
+                    # It's a specific tool name, wrap it in the proper format
+                    tool_choice = {"type": "function", "function": {"name": tool_choice_str}}
 
             # Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
             temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
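
The new branch mirrors the OpenAI chat-completions contract for tool_choice: the keyword modes pass through as strings, while a bare tool name must be wrapped in a function selector. A standalone sketch of the mapping (get_weather is an illustrative tool name):

def to_openai_tool_choice(tool_choice_str: str) -> str | dict:
    # The three keyword modes are forwarded unchanged.
    if tool_choice_str in ("auto", "none", "required"):
        return tool_choice_str
    # Anything else is treated as a specific tool name.
    return {"type": "function", "function": {"name": tool_choice_str}}

assert to_openai_tool_choice("auto") == "auto"
assert to_openai_tool_choice("get_weather") == {
    "type": "function",
    "function": {"name": "get_weather"},
}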