mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-14 20:32:50 +00:00
feat: Add support for Conversatsions in Responses API
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
548ccff368
commit
1e59793288
18 changed files with 662 additions and 10 deletions
|
|
@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
|
|||
deps[Api.safety],
|
||||
deps[Api.tool_runtime],
|
||||
deps[Api.tool_groups],
|
||||
deps[Api.conversations],
|
||||
policy,
|
||||
Api.telemetry in deps,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from llama_stack.apis.agents import (
|
|||
)
|
||||
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
|
|
@ -63,6 +64,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
safety_api: Safety,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
conversations_api: Conversations,
|
||||
policy: list[AccessRule],
|
||||
telemetry_enabled: bool = False,
|
||||
):
|
||||
|
|
@ -72,6 +74,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.safety_api = safety_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.conversations_api = conversations_api
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
|
|
@ -88,6 +91,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tool_runtime_api=self.tool_runtime_api,
|
||||
responses_store=self.responses_store,
|
||||
vector_io_api=self.vector_io_api,
|
||||
conversations_api=self.conversations_api,
|
||||
)
|
||||
|
||||
async def create_agent(
|
||||
|
|
@ -325,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
|
|
@ -339,6 +344,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
model,
|
||||
instructions,
|
||||
previous_response_id,
|
||||
conversation,
|
||||
store,
|
||||
stream,
|
||||
temperature,
|
||||
|
|
|
|||
|
|
@ -24,6 +24,12 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
from llama_stack.apis.common.errors import (
|
||||
ConversationNotFoundError,
|
||||
InvalidConversationIdError,
|
||||
)
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.conversations.conversations import ConversationItem
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIMessageParam,
|
||||
|
|
@ -61,12 +67,14 @@ class OpenAIResponsesImpl:
|
|||
tool_runtime_api: ToolRuntime,
|
||||
responses_store: ResponsesStore,
|
||||
vector_io_api: VectorIO, # VectorIO
|
||||
conversations_api: Conversations,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.responses_store = responses_store
|
||||
self.vector_io_api = vector_io_api
|
||||
self.conversations_api = conversations_api
|
||||
self.tool_executor = ToolExecutor(
|
||||
tool_groups_api=tool_groups_api,
|
||||
tool_runtime_api=tool_runtime_api,
|
||||
|
|
@ -205,6 +213,7 @@ class OpenAIResponsesImpl:
|
|||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
|
|
@ -221,11 +230,22 @@ class OpenAIResponsesImpl:
|
|||
if shields is not None:
|
||||
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
|
||||
|
||||
if conversation is not None:
|
||||
if not conversation.startswith("conv_"):
|
||||
raise InvalidConversationIdError(conversation)
|
||||
|
||||
conversation_exists = await self._check_conversation_exists(conversation)
|
||||
if not conversation_exists:
|
||||
raise ConversationNotFoundError(conversation)
|
||||
|
||||
input = await self._load_conversation_context(conversation, input)
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
model=model,
|
||||
instructions=instructions,
|
||||
previous_response_id=previous_response_id,
|
||||
conversation=conversation,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
text=text,
|
||||
|
|
@ -270,6 +290,7 @@ class OpenAIResponsesImpl:
|
|||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
|
|
@ -296,7 +317,7 @@ class OpenAIResponsesImpl:
|
|||
)
|
||||
|
||||
# Create orchestrator and delegate streaming logic
|
||||
response_id = f"resp-{uuid.uuid4()}"
|
||||
response_id = f"resp_{uuid.uuid4()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
orchestrator = StreamingResponseOrchestrator(
|
||||
|
|
@ -327,5 +348,98 @@ class OpenAIResponsesImpl:
|
|||
messages=orchestrator.final_messages,
|
||||
)
|
||||
|
||||
if conversation and final_response:
|
||||
await self._sync_response_to_conversation(conversation, all_input, final_response)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
return await self.responses_store.delete_response_object(response_id)
|
||||
|
||||
async def _check_conversation_exists(self, conversation_id: str) -> bool:
|
||||
"""Check if a conversation exists."""
|
||||
try:
|
||||
await self.conversations_api.get_conversation(conversation_id)
|
||||
return True
|
||||
except ConversationNotFoundError:
|
||||
return False
|
||||
|
||||
async def _load_conversation_context(
|
||||
self, conversation_id: str, content: str | list[OpenAIResponseInput]
|
||||
) -> list[OpenAIResponseInput]:
|
||||
"""Load conversation history and merge with provided content."""
|
||||
try:
|
||||
conversation_items = await self.conversations_api.list(conversation_id, order="asc")
|
||||
|
||||
context_messages = []
|
||||
for item in conversation_items.data:
|
||||
if isinstance(item, OpenAIResponseMessage):
|
||||
if item.role == "user":
|
||||
context_messages.append(
|
||||
OpenAIResponseMessage(
|
||||
role="user", content=item.content, id=item.id if hasattr(item, "id") else None
|
||||
)
|
||||
)
|
||||
elif item.role == "assistant":
|
||||
context_messages.append(
|
||||
OpenAIResponseMessage(
|
||||
role="assistant", content=item.content, id=item.id if hasattr(item, "id") else None
|
||||
)
|
||||
)
|
||||
|
||||
# add new content to context
|
||||
if isinstance(content, str):
|
||||
context_messages.append(OpenAIResponseMessage(role="user", content=content))
|
||||
elif isinstance(content, list):
|
||||
context_messages.extend(content)
|
||||
|
||||
return context_messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load conversation context for {conversation_id}: {e}")
|
||||
if isinstance(content, str):
|
||||
return [OpenAIResponseMessage(role="user", content=content)]
|
||||
return content
|
||||
|
||||
async def _sync_response_to_conversation(
|
||||
self, conversation_id: str, content: str | list[OpenAIResponseInput], response: OpenAIResponseObject
|
||||
) -> None:
|
||||
"""Sync content and response messages to the conversation."""
|
||||
try:
|
||||
conversation_items = []
|
||||
|
||||
# add user content message(s)
|
||||
if isinstance(content, str):
|
||||
conversation_items.append(
|
||||
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, OpenAIResponseMessage) and item.role == "user":
|
||||
if isinstance(item.content, str):
|
||||
conversation_items.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": item.content}],
|
||||
}
|
||||
)
|
||||
elif isinstance(item.content, list):
|
||||
conversation_items.append({"type": "message", "role": "user", "content": item.content})
|
||||
|
||||
# add assistant response message
|
||||
for output_item in response.output:
|
||||
if isinstance(output_item, OpenAIResponseMessage) and output_item.role == "assistant":
|
||||
if hasattr(output_item, "content") and isinstance(output_item.content, list):
|
||||
conversation_items.append(
|
||||
{"type": "message", "role": "assistant", "content": output_item.content}
|
||||
)
|
||||
|
||||
if conversation_items:
|
||||
adapter = TypeAdapter(list[ConversationItem])
|
||||
validated_items = adapter.validate_python(conversation_items)
|
||||
await self.conversations_api.add_items(conversation_id, validated_items)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync response {response.id} to conversation {conversation_id}: {e}")
|
||||
# don't fail response creation if conversation sync fails
|
||||
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue