feat: Add support for Conversations in Responses API
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
commit 1e59793288
parent 548ccff368
18 changed files with 662 additions and 10 deletions
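The new `conversation` parameter ties a response to a Conversations-API conversation: prior items are loaded as context before inference, and the resulting user and assistant messages are synced back afterwards. A minimal usage sketch with an OpenAI-compatible client pointed at a Llama Stack server (the base URL, API key, and model ID are placeholders, not part of this commit):

from openai import OpenAI

# Placeholders: point the client at your Llama Stack deployment.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# Create a conversation; its ID is prefixed with "conv_".
conversation = client.conversations.create(metadata={"topic": "demo"})

# Passing conversation= loads prior items as context and syncs the new
# user and assistant messages back into the conversation afterwards.
response = client.responses.create(
    model="my-model",  # placeholder model ID
    input=[{"role": "user", "content": "What are the 5 Ds of dodgeball?"}],
    conversation=conversation.id,
)
print(response.output_text)

# Both the user input and the assistant reply are now conversation items.
items = client.conversations.items.list(conversation.id)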
docs/static/deprecated-llama-stack-spec.html (vendored, 4 additions)
@@ -10083,6 +10083,10 @@
           "type": "string",
           "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."
         },
+        "conversation": {
+          "type": "string",
+          "description": "(Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation."
+        },
         "store": {
           "type": "boolean"
         },
docs/static/deprecated-llama-stack-spec.yaml (vendored, 6 additions)
@@ -7493,6 +7493,12 @@ components:
           (Optional) if specified, the new response will be a continuation of the
           previous response. This can be used to easily fork-off new responses from
           existing responses.
+      conversation:
+        type: string
+        description: >-
+          (Optional) The ID of a conversation to add the response to. Must begin
+          with 'conv_'. Input and output messages will be automatically added to
+          the conversation.
       store:
         type: boolean
       stream:
docs/static/llama-stack-spec.html (vendored, 4 additions)
@@ -8178,6 +8178,10 @@
           "type": "string",
           "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."
         },
+        "conversation": {
+          "type": "string",
+          "description": "(Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation."
+        },
         "store": {
           "type": "boolean"
         },
docs/static/llama-stack-spec.yaml (vendored, 6 additions)
@@ -6189,6 +6189,12 @@ components:
           (Optional) if specified, the new response will be a continuation of the
           previous response. This can be used to easily fork-off new responses from
           existing responses.
+      conversation:
+        type: string
+        description: >-
+          (Optional) The ID of a conversation to add the response to. Must begin
+          with 'conv_'. Input and output messages will be automatically added to
+          the conversation.
       store:
         type: boolean
       stream:
docs/static/stainless-llama-stack-spec.html (vendored, 4 additions)
@@ -10187,6 +10187,10 @@
           "type": "string",
           "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."
         },
+        "conversation": {
+          "type": "string",
+          "description": "(Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation."
+        },
         "store": {
           "type": "boolean"
         },
docs/static/stainless-llama-stack-spec.yaml (vendored, 6 additions)
@@ -7634,6 +7634,12 @@ components:
           (Optional) if specified, the new response will be a continuation of the
           previous response. This can be used to easily fork-off new responses from
           existing responses.
+      conversation:
+        type: string
+        description: >-
+          (Optional) The ID of a conversation to add the response to. Must begin
+          with 'conv_'. Input and output messages will be automatically added to
+          the conversation.
       store:
         type: boolean
       stream:
@@ -812,6 +812,7 @@ class Agents(Protocol):
         model: str,
         instructions: str | None = None,
         previous_response_id: str | None = None,
+        conversation: str | None = None,
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
@@ -831,6 +832,7 @@ class Agents(Protocol):
         :param input: Input message(s) to create the response.
         :param model: The underlying LLM used for completions.
         :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
+        :param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
         :param include: (Optional) Additional fields to include in the response.
         :param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
         :returns: An OpenAIResponseObject.
@@ -86,3 +86,18 @@ class TokenValidationError(ValueError):

     def __init__(self, message: str) -> None:
         super().__init__(message)
+
+
+class ConversationNotFoundError(ResourceNotFoundError):
+    """raised when Llama Stack cannot find a referenced conversation"""
+
+    def __init__(self, conversation_id: str) -> None:
+        super().__init__(conversation_id, "Conversation", "client.conversations.list()")
+
+
+class InvalidConversationIdError(ValueError):
+    """raised when a conversation ID has an invalid format"""
+
+    def __init__(self, conversation_id: str) -> None:
+        message = f"Invalid conversation ID '{conversation_id}'. Expected an ID that begins with 'conv_'."
+        super().__init__(message)
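For reference, a small sketch of how the new error classes behave; the message text comes from the definitions above, while the try/except usage is illustrative:

from llama_stack.apis.common.errors import (
    ConversationNotFoundError,
    InvalidConversationIdError,
)

# InvalidConversationIdError is a ValueError whose message names the offending ID.
try:
    raise InvalidConversationIdError("not-a-conv-id")
except ValueError as e:
    print(e)  # Invalid conversation ID 'not-a-conv-id'. Expected an ID that begins with 'conv_'.

# ConversationNotFoundError extends ResourceNotFoundError, so callers can catch
# either the specific class or the broader not-found error.
try:
    raise ConversationNotFoundError("conv_missing")
except ConversationNotFoundError as e:
    print(e)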
@@ -150,6 +150,7 @@ async def resolve_impls(
     provider_registry: ProviderRegistry,
     dist_registry: DistributionRegistry,
     policy: list[AccessRule],
+    internal_impls: dict[Api, Any] | None = None,
 ) -> dict[Api, Any]:
     """
     Resolves provider implementations by:
@@ -172,7 +173,7 @@ async def resolve_impls(

     sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

-    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
+    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy, internal_impls)


 def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@@ -280,9 +281,10 @@ async def instantiate_providers(
     dist_registry: DistributionRegistry,
     run_config: StackRunConfig,
     policy: list[AccessRule],
+    internal_impls: dict[Api, Any] | None = None,
 ) -> dict[Api, Any]:
     """Instantiates providers asynchronously while managing dependencies."""
-    impls: dict[Api, Any] = {}
+    impls: dict[Api, Any] = internal_impls.copy() if internal_impls else {}
     inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
     for api_str, provider in sorted_providers:
         # Skip providers that are not enabled
@@ -326,12 +326,17 @@ class Stack:

         dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
         policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
-        impls = await resolve_impls(
-            self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
-        )
-
         # Add internal implementations after all other providers are resolved
-        add_internal_implementations(impls, self.run_config)
+        internal_impls = {}
+        add_internal_implementations(internal_impls, self.run_config)
+
+        impls = await resolve_impls(
+            self.run_config,
+            self.provider_registry or get_provider_registry(self.run_config),
+            dist_registry,
+            policy,
+            internal_impls,
+        )

         if Api.prompts in impls:
             await impls[Api.prompts].initialize()
@@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
         deps[Api.safety],
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
+        deps[Api.conversations],
         policy,
         Api.telemetry in deps,
     )
@@ -30,6 +30,7 @@ from llama_stack.apis.agents import (
 )
 from llama_stack.apis.agents.openai_responses import OpenAIResponseText
 from llama_stack.apis.common.responses import PaginatedResponse
+from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.inference import (
     Inference,
     ToolConfig,
@@ -63,6 +64,7 @@ class MetaReferenceAgentsImpl(Agents):
         safety_api: Safety,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
+        conversations_api: Conversations,
         policy: list[AccessRule],
         telemetry_enabled: bool = False,
     ):
@@ -72,6 +74,7 @@ class MetaReferenceAgentsImpl(Agents):
         self.safety_api = safety_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
+        self.conversations_api = conversations_api
         self.telemetry_enabled = telemetry_enabled

         self.in_memory_store = InmemoryKVStoreImpl()
@@ -88,6 +91,7 @@ class MetaReferenceAgentsImpl(Agents):
             tool_runtime_api=self.tool_runtime_api,
             responses_store=self.responses_store,
             vector_io_api=self.vector_io_api,
+            conversations_api=self.conversations_api,
         )

     async def create_agent(
@@ -325,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
         model: str,
         instructions: str | None = None,
         previous_response_id: str | None = None,
+        conversation: str | None = None,
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
@@ -339,6 +344,7 @@ class MetaReferenceAgentsImpl(Agents):
             model,
             instructions,
             previous_response_id,
+            conversation,
             store,
             stream,
             temperature,
@@ -24,6 +24,12 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseText,
     OpenAIResponseTextFormat,
 )
+from llama_stack.apis.common.errors import (
+    ConversationNotFoundError,
+    InvalidConversationIdError,
+)
+from llama_stack.apis.conversations import Conversations
+from llama_stack.apis.conversations.conversations import ConversationItem
 from llama_stack.apis.inference import (
     Inference,
     OpenAIMessageParam,
@@ -61,12 +67,14 @@ class OpenAIResponsesImpl:
         tool_runtime_api: ToolRuntime,
         responses_store: ResponsesStore,
         vector_io_api: VectorIO,  # VectorIO
+        conversations_api: Conversations,
     ):
         self.inference_api = inference_api
         self.tool_groups_api = tool_groups_api
         self.tool_runtime_api = tool_runtime_api
         self.responses_store = responses_store
         self.vector_io_api = vector_io_api
+        self.conversations_api = conversations_api
         self.tool_executor = ToolExecutor(
             tool_groups_api=tool_groups_api,
             tool_runtime_api=tool_runtime_api,
@@ -205,6 +213,7 @@ class OpenAIResponsesImpl:
         model: str,
         instructions: str | None = None,
         previous_response_id: str | None = None,
+        conversation: str | None = None,
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
@@ -221,11 +230,22 @@ class OpenAIResponsesImpl:
         if shields is not None:
             raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")

+        if conversation is not None:
+            if not conversation.startswith("conv_"):
+                raise InvalidConversationIdError(conversation)
+
+            conversation_exists = await self._check_conversation_exists(conversation)
+            if not conversation_exists:
+                raise ConversationNotFoundError(conversation)
+
+            input = await self._load_conversation_context(conversation, input)
+
         stream_gen = self._create_streaming_response(
             input=input,
             model=model,
             instructions=instructions,
             previous_response_id=previous_response_id,
+            conversation=conversation,
             store=store,
             temperature=temperature,
             text=text,
@@ -270,6 +290,7 @@ class OpenAIResponsesImpl:
         model: str,
         instructions: str | None = None,
         previous_response_id: str | None = None,
+        conversation: str | None = None,
         store: bool | None = True,
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
@@ -296,7 +317,7 @@ class OpenAIResponsesImpl:
         )

         # Create orchestrator and delegate streaming logic
-        response_id = f"resp-{uuid.uuid4()}"
+        response_id = f"resp_{uuid.uuid4()}"
         created_at = int(time.time())

         orchestrator = StreamingResponseOrchestrator(
@@ -327,5 +348,98 @@ class OpenAIResponsesImpl:
                 messages=orchestrator.final_messages,
             )

+            if conversation and final_response:
+                await self._sync_response_to_conversation(conversation, all_input, final_response)
+
     async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
         return await self.responses_store.delete_response_object(response_id)
+
+    async def _check_conversation_exists(self, conversation_id: str) -> bool:
+        """Check if a conversation exists."""
+        try:
+            await self.conversations_api.get_conversation(conversation_id)
+            return True
+        except ConversationNotFoundError:
+            return False
+
+    async def _load_conversation_context(
+        self, conversation_id: str, content: str | list[OpenAIResponseInput]
+    ) -> list[OpenAIResponseInput]:
+        """Load conversation history and merge with provided content."""
+        try:
+            conversation_items = await self.conversations_api.list(conversation_id, order="asc")
+
+            context_messages = []
+            for item in conversation_items.data:
+                if isinstance(item, OpenAIResponseMessage):
+                    if item.role == "user":
+                        context_messages.append(
+                            OpenAIResponseMessage(
+                                role="user", content=item.content, id=item.id if hasattr(item, "id") else None
+                            )
+                        )
+                    elif item.role == "assistant":
+                        context_messages.append(
+                            OpenAIResponseMessage(
+                                role="assistant", content=item.content, id=item.id if hasattr(item, "id") else None
+                            )
+                        )
+
+            # add new content to context
+            if isinstance(content, str):
+                context_messages.append(OpenAIResponseMessage(role="user", content=content))
+            elif isinstance(content, list):
+                context_messages.extend(content)
+
+            return context_messages
+
+        except Exception as e:
+            logger.error(f"Failed to load conversation context for {conversation_id}: {e}")
+            if isinstance(content, str):
+                return [OpenAIResponseMessage(role="user", content=content)]
+            return content
+
+    async def _sync_response_to_conversation(
+        self, conversation_id: str, content: str | list[OpenAIResponseInput], response: OpenAIResponseObject
+    ) -> None:
+        """Sync content and response messages to the conversation."""
+        try:
+            conversation_items = []
+
+            # add user content message(s)
+            if isinstance(content, str):
+                conversation_items.append(
+                    {"type": "message", "role": "user", "content": [{"type": "input_text", "text": content}]}
+                )
+            elif isinstance(content, list):
+                for item in content:
+                    if isinstance(item, OpenAIResponseMessage) and item.role == "user":
+                        if isinstance(item.content, str):
+                            conversation_items.append(
+                                {
+                                    "type": "message",
+                                    "role": "user",
+                                    "content": [{"type": "input_text", "text": item.content}],
+                                }
+                            )
+                        elif isinstance(item.content, list):
+                            conversation_items.append({"type": "message", "role": "user", "content": item.content})
+
+            # add assistant response message
+            for output_item in response.output:
+                if isinstance(output_item, OpenAIResponseMessage) and output_item.role == "assistant":
+                    if hasattr(output_item, "content") and isinstance(output_item.content, list):
+                        conversation_items.append(
+                            {"type": "message", "role": "assistant", "content": output_item.content}
+                        )
+
+            if conversation_items:
+                adapter = TypeAdapter(list[ConversationItem])
+                validated_items = adapter.validate_python(conversation_items)
+                await self.conversations_api.add_items(conversation_id, validated_items)
+
+        except Exception as e:
+            logger.error(f"Failed to sync response {response.id} to conversation {conversation_id}: {e}")
+            # don't fail response creation if conversation sync fails
+
+        return None
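The sync helper above builds plain dicts and validates them against the ConversationItem union before calling add_items. A minimal sketch of the item shapes involved, assuming the Pydantic models accept these dict forms as they do in the unit tests later in this commit (text values are illustrative):

from pydantic import TypeAdapter

from llama_stack.apis.conversations.conversations import ConversationItem

# Mirrors _sync_response_to_conversation: user input is wrapped as input_text,
# assistant output keeps its output_text content parts.
items = [
    {"type": "message", "role": "user",
     "content": [{"type": "input_text", "text": "What are the 5 Ds of dodgeball?"}]},
    {"type": "message", "role": "assistant",
     "content": [{"type": "output_text", "text": "Dodge, duck, dip, dive, and dodge.", "annotations": []}]},
]
validated = TypeAdapter(list[ConversationItem]).validate_python(items)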
@@ -35,6 +35,7 @@ def available_providers() -> list[ProviderSpec]:
                 Api.vector_dbs,
                 Api.tool_runtime,
                 Api.tool_groups,
+                Api.conversations,
             ],
             optional_api_dependencies=[
                 Api.telemetry,
tests/integration/responses/test_conversation_responses.py (new file, 128 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest


@pytest.mark.integration
class TestConversationResponses:
    """Integration tests for the conversation parameter in responses API."""

    def test_conversation_basic_workflow(self, openai_client, text_model_id):
        """Test basic conversation workflow: create conversation, add response, verify sync."""
        conversation = openai_client.conversations.create(metadata={"topic": "test"})
        assert conversation.id.startswith("conv_")

        response = openai_client.responses.create(
            model=text_model_id,
            input=[{"role": "user", "content": "What are the 5 Ds of dodgeball?"}],
            conversation=conversation.id,
        )

        assert response.id.startswith("resp_")
        assert len(response.output_text.strip()) > 0

        # Verify conversation was synced
        conversation_items = openai_client.conversations.items.list(conversation.id)
        assert len(conversation_items.data) >= 2

        roles = [item.role for item in conversation_items.data if hasattr(item, "role")]
        assert "user" in roles and "assistant" in roles

    def test_conversation_multi_turn_and_streaming(self, openai_client, text_model_id):
        """Test multi-turn conversations and streaming responses."""
        conversation = openai_client.conversations.create()

        # First turn
        response1 = openai_client.responses.create(
            model=text_model_id,
            input=[{"role": "user", "content": "Say hello"}],
            conversation=conversation.id,
        )

        # Second turn with streaming
        response_stream = openai_client.responses.create(
            model=text_model_id,
            input=[{"role": "user", "content": "Say goodbye"}],
            conversation=conversation.id,
            stream=True,
        )

        final_response = None
        for chunk in response_stream:
            if chunk.type == "response.completed":
                final_response = chunk.response
                break

        assert response1.id != final_response.id
        assert len(response1.output_text.strip()) > 0
        assert len(final_response.output_text.strip()) > 0

        # Verify all turns are in conversation
        conversation_items = openai_client.conversations.items.list(conversation.id)
        assert len(conversation_items.data) >= 4  # 2 user + 2 assistant messages

    def test_conversation_context_loading(self, openai_client, text_model_id):
        """Test that conversation context is properly loaded for responses."""
        conversation = openai_client.conversations.create(
            items=[
                {"type": "message", "role": "user", "content": "My name is Alice"},
                {"type": "message", "role": "assistant", "content": "Hello Alice!"},
            ]
        )

        response = openai_client.responses.create(
            model=text_model_id,
            input=[{"role": "user", "content": "What's my name?"}],
            conversation=conversation.id,
        )

        assert "alice" in response.output_text.lower()

    def test_conversation_error_handling(self, openai_client, text_model_id):
        """Test error handling for invalid and nonexistent conversations."""
        # Invalid conversation ID format
        with pytest.raises(Exception) as exc_info:
            openai_client.responses.create(
                model=text_model_id,
                input=[{"role": "user", "content": "Hello"}],
                conversation="invalid_id",
            )
        assert any(word in str(exc_info.value).lower() for word in ["conv", "invalid", "bad"])

        # Nonexistent conversation ID
        with pytest.raises(Exception) as exc_info:
            openai_client.responses.create(
                model=text_model_id,
                input=[{"role": "user", "content": "Hello"}],
                conversation="conv_nonexistent123",
            )
        assert any(word in str(exc_info.value).lower() for word in ["not found", "404"])

    def test_conversation_backward_compatibility(self, openai_client, text_model_id):
        """Test that responses work without conversation parameter (backward compatibility)."""
        response = openai_client.responses.create(
            model=text_model_id, input=[{"role": "user", "content": "Hello world"}]
        )

        assert response.id.startswith("resp_")
        assert len(response.output_text.strip()) > 0

    def test_conversation_compat_client(self, compat_client, text_model_id):
        """Test conversation parameter works with compatibility client."""
        if not hasattr(compat_client, "conversations"):
            pytest.skip("compat_client does not support conversations API")

        conversation = compat_client.conversations.create()
        response = compat_client.responses.create(
            model=text_model_id, input="Tell me a joke", conversation=conversation.id
        )

        assert response is not None
        assert len(response.output_text.strip()) > 0

        conversation_items = compat_client.conversations.items.list(conversation.id)
        assert len(conversation_items.data) >= 2
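Assuming the repository's usual integration-test setup (a running stack plus the openai_client, compat_client, and text_model_id fixtures), these tests would be run with something like:

pytest -sv tests/integration/responses/test_conversation_responses.py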
@@ -15,6 +15,7 @@ from llama_stack.apis.agents import (
     AgentCreateResponse,
 )
 from llama_stack.apis.common.responses import PaginatedResponse
+from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
@@ -33,6 +34,7 @@ def mock_apis():
         "safety_api": AsyncMock(spec=Safety),
         "tool_runtime_api": AsyncMock(spec=ToolRuntime),
         "tool_groups_api": AsyncMock(spec=ToolGroups),
+        "conversations_api": AsyncMock(spec=Conversations),
     }


@@ -59,7 +61,8 @@ async def agents_impl(config, mock_apis):
         mock_apis["safety_api"],
         mock_apis["tool_runtime_api"],
         mock_apis["tool_groups_api"],
-        {},
+        mock_apis["conversations_api"],
+        [],
     )
     await impl.initialize()
     yield impl
New file (332 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import pytest

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseMessage,
    OpenAIResponseObject,
    OpenAIResponseObjectStreamResponseCompleted,
    OpenAIResponseOutputMessageContentOutputText,
)
from llama_stack.apis.common.errors import (
    ConversationNotFoundError,
    InvalidConversationIdError,
)
from llama_stack.apis.conversations.conversations import (
    Conversation,
    ConversationItemList,
)

# Import existing fixtures from the main responses test file
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]

from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
    OpenAIResponsesImpl,
)


@pytest.fixture
def responses_impl_with_conversations(
    mock_inference_api,
    mock_tool_groups_api,
    mock_tool_runtime_api,
    mock_responses_store,
    mock_vector_io_api,
    mock_conversations_api,
):
    """Create OpenAIResponsesImpl instance with conversations API."""
    return OpenAIResponsesImpl(
        inference_api=mock_inference_api,
        tool_groups_api=mock_tool_groups_api,
        tool_runtime_api=mock_tool_runtime_api,
        responses_store=mock_responses_store,
        vector_io_api=mock_vector_io_api,
        conversations_api=mock_conversations_api,
    )


class TestConversationValidation:
    """Test conversation ID validation logic."""

    async def test_conversation_existence_check_valid(self, responses_impl_with_conversations, mock_conversations_api):
        """Test conversation existence check for valid conversation."""
        conv_id = "conv_valid123"

        # Mock successful conversation retrieval
        mock_conversations_api.get_conversation.return_value = Conversation(
            id=conv_id, created_at=1234567890, metadata={}, object="conversation"
        )

        result = await responses_impl_with_conversations._check_conversation_exists(conv_id)

        assert result is True
        mock_conversations_api.get_conversation.assert_called_once_with(conv_id)

    async def test_conversation_existence_check_invalid(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test conversation existence check for non-existent conversation."""
        conv_id = "conv_nonexistent"

        # Mock conversation not found
        mock_conversations_api.get_conversation.side_effect = ConversationNotFoundError("conv_nonexistent")

        result = await responses_impl_with_conversations._check_conversation_exists(conv_id)

        assert result is False
        mock_conversations_api.get_conversation.assert_called_once_with(conv_id)


class TestConversationContextLoading:
    """Test conversation context loading functionality."""

    async def test_load_conversation_context_simple_input(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test loading conversation context with simple string input."""
        conv_id = "conv_test123"
        input_text = "Hello, how are you?"

        # mock items in chronological order (a consequence of order="asc")
        mock_conversation_items = ConversationItemList(
            data=[
                OpenAIResponseMessage(
                    id="msg_1",
                    content=[{"type": "input_text", "text": "Previous user message"}],
                    role="user",
                    status="completed",
                    type="message",
                ),
                OpenAIResponseMessage(
                    id="msg_2",
                    content=[{"type": "output_text", "text": "Previous assistant response"}],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
            ],
            first_id="msg_1",
            has_more=False,
            last_id="msg_2",
            object="list",
        )

        mock_conversations_api.list.return_value = mock_conversation_items

        result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)

        # should have conversation history + new input
        assert len(result) == 3
        assert isinstance(result[0], OpenAIResponseMessage)
        assert result[0].role == "user"
        assert isinstance(result[1], OpenAIResponseMessage)
        assert result[1].role == "assistant"
        assert isinstance(result[2], OpenAIResponseMessage)
        assert result[2].role == "user"
        assert result[2].content == input_text

    async def test_load_conversation_context_api_error(self, responses_impl_with_conversations, mock_conversations_api):
        """Test loading conversation context when API call fails."""
        conv_id = "conv_test123"
        input_text = "Hello"

        mock_conversations_api.list.side_effect = Exception("API Error")

        result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)

        assert len(result) == 1
        assert isinstance(result[0], OpenAIResponseMessage)
        assert result[0].role == "user"
        assert result[0].content == input_text

    async def test_load_conversation_context_with_list_input(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test loading conversation context with list input."""
        conv_id = "conv_test123"
        input_messages = [
            OpenAIResponseMessage(role="user", content="First message"),
            OpenAIResponseMessage(role="user", content="Second message"),
        ]

        mock_conversations_api.list.return_value = ConversationItemList(
            data=[], first_id=None, has_more=False, last_id=None, object="list"
        )

        result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_messages)

        assert len(result) == 2
        assert result == input_messages

    async def test_load_conversation_context_empty_conversation(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test loading context from empty conversation."""
        conv_id = "conv_empty"
        input_text = "Hello"

        mock_conversations_api.list.return_value = ConversationItemList(
            data=[], first_id=None, has_more=False, last_id=None, object="list"
        )

        result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)

        assert len(result) == 1
        assert result[0].role == "user"
        assert result[0].content == input_text


class TestMessageSyncing:
    """Test message syncing to conversations."""

    async def test_sync_response_to_conversation_simple(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test syncing simple response to conversation."""
        conv_id = "conv_test123"
        input_text = "What are the 5 Ds of dodgeball?"

        # mock response
        mock_response = OpenAIResponseObject(
            id="resp_123",
            created_at=1234567890,
            model="test-model",
            object="response",
            output=[
                OpenAIResponseMessage(
                    id="msg_response",
                    content=[
                        OpenAIResponseOutputMessageContentOutputText(
                            text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                )
            ],
            status="completed",
        )

        await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, mock_response)

        # should call add_items with user input and assistant response
        mock_conversations_api.add_items.assert_called_once()
        call_args = mock_conversations_api.add_items.call_args

        assert call_args[0][0] == conv_id  # conversation_id
        items = call_args[0][1]  # conversation_items

        assert len(items) == 2
        # User message
        assert items[0].type == "message"
        assert items[0].role == "user"
        assert items[0].content[0].type == "input_text"
        assert items[0].content[0].text == input_text

        # Assistant message
        assert items[1].type == "message"
        assert items[1].role == "assistant"

    async def test_sync_response_to_conversation_api_error(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test syncing when conversations API call fails."""
        conv_id = "conv_test123"

        mock_response = OpenAIResponseObject(
            id="resp_123", created_at=1234567890, model="test-model", object="response", output=[], status="completed"
        )

        # Mock API error
        mock_conversations_api.add_items.side_effect = Exception("API Error")

        # Should not raise exception (graceful failure)
        result = await responses_impl_with_conversations._sync_response_to_conversation(conv_id, "Hello", mock_response)
        assert result is None


class TestIntegrationWorkflow:
    """Integration tests for the full conversation workflow."""

    async def test_create_response_with_valid_conversation(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test creating a response with a valid conversation parameter."""
        mock_conversations_api.get_conversation.return_value = Conversation(
            id="conv_test123", created_at=1234567890, metadata={}, object="conversation"
        )

        mock_conversations_api.list.return_value = ConversationItemList(
            data=[], first_id=None, has_more=False, last_id=None, object="list"
        )

        async def mock_streaming_response(*args, **kwargs):
            mock_response = OpenAIResponseObject(
                id="resp_test123",
                created_at=1234567890,
                model="test-model",
                object="response",
                output=[
                    OpenAIResponseMessage(
                        id="msg_response",
                        content=[
                            OpenAIResponseOutputMessageContentOutputText(
                                text="Test response", type="output_text", annotations=[]
                            )
                        ],
                        role="assistant",
                        status="completed",
                        type="message",
                    )
                ],
                status="completed",
            )

            yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")

        responses_impl_with_conversations._create_streaming_response = mock_streaming_response

        input_text = "Hello, how are you?"
        conversation_id = "conv_test123"

        response = await responses_impl_with_conversations.create_openai_response(
            input=input_text, model="test-model", conversation=conversation_id, stream=False
        )

        assert response is not None
        assert response.id == "resp_test123"

        mock_conversations_api.get_conversation.assert_called_once_with(conversation_id)
        mock_conversations_api.list.assert_called_once_with(conversation_id, order="asc")

        # Note: conversation sync happens in the streaming response flow,
        # which is complex to mock fully in this unit test

    async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
        """Test creating a response with an invalid conversation ID."""
        with pytest.raises(InvalidConversationIdError) as exc_info:
            await responses_impl_with_conversations.create_openai_response(
                input="Hello", model="test-model", conversation="invalid_id", stream=False
            )

        assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)

    async def test_create_response_with_nonexistent_conversation(
        self, responses_impl_with_conversations, mock_conversations_api
    ):
        """Test creating a response with a non-existent conversation."""
        mock_conversations_api.get_conversation.side_effect = ConversationNotFoundError("conv_nonexistent")

        with pytest.raises(ConversationNotFoundError) as exc_info:
            await responses_impl_with_conversations.create_openai_response(
                input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
            )

        assert "not found" in str(exc_info.value)
@@ -83,9 +83,21 @@ def mock_vector_io_api():
     return vector_io_api


+@pytest.fixture
+def mock_conversations_api():
+    """Mock conversations API for testing."""
+    mock_api = AsyncMock()
+    return mock_api
+
+
 @pytest.fixture
 def openai_responses_impl(
-    mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
+    mock_inference_api,
+    mock_tool_groups_api,
+    mock_tool_runtime_api,
+    mock_responses_store,
+    mock_vector_io_api,
+    mock_conversations_api,
 ):
     return OpenAIResponsesImpl(
         inference_api=mock_inference_api,
@@ -93,6 +105,7 @@ def openai_responses_impl(
         tool_runtime_api=mock_tool_runtime_api,
         responses_store=mock_responses_store,
         vector_io_api=mock_vector_io_api,
+        conversations_api=mock_conversations_api,
     )