Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-03 17:29:01 +00:00
commit 70c088af3a (parent c149cf2e0f)

Stub in an initial OpenAI Responses API

Signed-off-by: Ben Browning <bbrownin@redhat.com>

18 changed files with 441 additions and 0 deletions
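
Editor's note (not part of the diff): the new endpoints are intended to be driven with the standard `openai` Python client pointed at a running Llama Stack server, which is how the integration tests added in this commit exercise them. A minimal sketch follows; the server URL and model id are placeholders.

# Sketch based on the integration tests in this commit; URL and model id are hypothetical.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",  # placeholder for a registered text model
    input="Which planet do humans live on?",
    stream=False,
)
print(response.output_text)

# Chain a follow-up turn through previous_response_id, which the stub supports.
follow_up = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input="Repeat your previous response in all caps.",
    previous_response_id=response.id,
)
print(follow_up.output_text)
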
@@ -18,6 +18,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
 | inference | `remote::vllm`, `inline::sentence-transformers` |
+| openai_responses | `inline::openai-responses` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |

@@ -24,6 +24,7 @@ class Api(Enum):
     eval = "eval"
     post_training = "post_training"
     tool_runtime = "tool_runtime"
+    openai_responses = "openai_responses"
     telemetry = "telemetry"

llama_stack/apis/openai_responses/__init__.py (new file, 7 lines)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .openai_responses import *  # noqa: F401 F403

llama_stack/apis/openai_responses/openai_responses.py (new file, 91 lines)

@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import AsyncIterator, List, Literal, Optional, Protocol, Union, runtime_checkable
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+from llama_stack.schema_utils import json_schema_type, webmethod
+
+
+@json_schema_type
+class OpenAIResponseError(BaseModel):
+    code: str
+    message: str
+
+
+@json_schema_type
+class OpenAIResponseOutputMessageContentOutputText(BaseModel):
+    text: str
+    type: Literal["output_text"] = "output_text"
+
+
+OpenAIResponseOutputMessageContent = Annotated[
+    Union[OpenAIResponseOutputMessageContentOutputText,],
+    Field(discriminator="type"),
+]
+
+
+@json_schema_type
+class OpenAIResponseOutputMessage(BaseModel):
+    id: str
+    content: List[OpenAIResponseOutputMessageContent]
+    role: Literal["assistant"] = "assistant"
+    status: str
+    type: Literal["message"] = "message"
+
+
+OpenAIResponseOutput = Annotated[
+    Union[OpenAIResponseOutputMessage,],
+    Field(discriminator="type"),
+]
+
+
+@json_schema_type
+class OpenAIResponseObject(BaseModel):
+    created_at: int
+    error: Optional[OpenAIResponseError] = None
+    id: str
+    model: str
+    object: Literal["response"] = "response"
+    output: List[OpenAIResponseOutput]
+    parallel_tool_calls: bool = False
+    previous_response_id: Optional[str] = None
+    status: str
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    truncation: Optional[str] = None
+    user: Optional[str] = None
+
+
+@json_schema_type
+class OpenAIResponseObjectStream(BaseModel):
+    response: OpenAIResponseObject
+    type: Literal["response.created"] = "response.created"
+
+
+@runtime_checkable
+class OpenAIResponses(Protocol):
+    """
+    OpenAI Responses API implementation.
+    """
+
+    @webmethod(route="/openai/v1/responses/{id}", method="GET")
+    async def get_openai_response(
+        self,
+        id: str,
+    ) -> OpenAIResponseObject: ...
+
+    @webmethod(route="/openai/v1/responses", method="POST")
+    async def create_openai_response(
+        self,
+        input: str,
+        model: str,
+        previous_response_id: Optional[str] = None,
+        store: Optional[bool] = True,
+        stream: Optional[bool] = False,
+    ) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]: ...

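Editor's sketch (not part of the diff): constructing and serializing one of the response objects defined above, to show the wire shape these Pydantic models produce. All identifier, timestamp, and model values are made up.

# Illustrative only; ids, timestamp, and model name are hypothetical.
from llama_stack.apis.openai_responses.openai_responses import (
    OpenAIResponseObject,
    OpenAIResponseOutputMessage,
    OpenAIResponseOutputMessageContentOutputText,
)

response = OpenAIResponseObject(
    created_at=1713200000,
    id="resp-1234",
    model="example-model",
    status="completed",
    output=[
        OpenAIResponseOutputMessage(
            id="msg_1234",
            status="completed",
            content=[OpenAIResponseOutputMessageContentOutputText(text="Earth")],
        )
    ],
)
print(response.model_dump_json(indent=2))
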
@@ -16,6 +16,7 @@ from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
+from llama_stack.apis.openai_responses.openai_responses import OpenAIResponses
 from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.providers import Providers as ProvidersAPI
 from llama_stack.apis.safety import Safety

@@ -80,6 +81,7 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.tool_groups: ToolGroups,
         Api.tool_runtime: ToolRuntime,
         Api.files: Files,
+        Api.openai_responses: OpenAIResponses,
     }

@@ -149,6 +149,8 @@ class CommonRoutingTableImpl(RoutingTable):
                 p.benchmark_store = self
             elif api == Api.tool_runtime:
                 p.tool_store = self
+            elif api == Api.openai_responses:
+                p.model_store = self

     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():

llama_stack/providers/inline/openai_responses/__init__.py (new file, 19 lines)

@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict
+
+from llama_stack.apis.datatypes import Api
+
+from .config import OpenAIResponsesImplConfig
+
+
+async def get_provider_impl(config: OpenAIResponsesImplConfig, deps: Dict[Api, Any]):
+    from .openai_responses import OpenAIResponsesImpl
+
+    impl = OpenAIResponsesImpl(config, deps[Api.models], deps[Api.inference])
+    await impl.initialize()
+    return impl

llama_stack/providers/inline/openai_responses/config.py (new file, 24 lines)

@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict
+
+from pydantic import BaseModel
+
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
+
+
+class OpenAIResponsesImplConfig(BaseModel):
+    kvstore: KVStoreConfig
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+        return {
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="openai_responses.db",
+            )
+        }

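Editor's sketch (not part of the diff): what `sample_run_config` expands to for a given distro directory. The expected shape is inferred from the `kvstore` blocks in the generated run.yaml hunks further down, so treat it as an assumption rather than a guarantee.

# Illustrative only; the distro directory value is hypothetical.
from llama_stack.providers.inline.openai_responses.config import OpenAIResponsesImplConfig

sample = OpenAIResponsesImplConfig.sample_run_config(
    __distro_dir__="~/.llama/distributions/remote-vllm",
)
# Expected to resemble the run.yaml config below, e.g.:
# {"kvstore": {"type": "sqlite", "namespace": None,
#              "db_path": "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/openai_responses.db"}}
print(sample)
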
llama_stack/providers/inline/openai_responses/openai_responses.py (new file, 126 lines)

@@ -0,0 +1,126 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import uuid
+from typing import AsyncIterator, List, Optional, cast
+
+from llama_stack.apis.inference.inference import (
+    Inference,
+    OpenAIAssistantMessageParam,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionContentPartTextParam,
+    OpenAIMessageParam,
+    OpenAIUserMessageParam,
+)
+from llama_stack.apis.models.models import Models, ModelType
+from llama_stack.apis.openai_responses import OpenAIResponses
+from llama_stack.apis.openai_responses.openai_responses import (
+    OpenAIResponseObject,
+    OpenAIResponseObjectStream,
+    OpenAIResponseOutputMessage,
+    OpenAIResponseOutputMessageContentOutputText,
+)
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.kvstore import kvstore_impl
+
+from .config import OpenAIResponsesImplConfig
+
+logger = get_logger(name=__name__, category="openai_responses")
+
+OPENAI_RESPONSES_PREFIX = "openai_responses:"
+
+
+class OpenAIResponsesImpl(OpenAIResponses):
+    def __init__(self, config: OpenAIResponsesImplConfig, models_api: Models, inference_api: Inference):
+        self.config = config
+        self.models_api = models_api
+        self.inference_api = inference_api
+
+    async def initialize(self) -> None:
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+
+    async def shutdown(self) -> None:
+        logger.debug("OpenAIResponsesImpl.shutdown")
+        pass
+
+    async def get_openai_response(
+        self,
+        id: str,
+    ) -> OpenAIResponseObject:
+        key = f"{OPENAI_RESPONSES_PREFIX}{id}"
+        response_json = await self.kvstore.get(key=key)
+        if response_json is None:
+            raise ValueError(f"OpenAI response with id '{id}' not found")
+        return OpenAIResponseObject.model_validate_json(response_json)
+
+    async def create_openai_response(
+        self,
+        input: str,
+        model: str,
+        previous_response_id: Optional[str] = None,
+        store: Optional[bool] = True,
+        stream: Optional[bool] = False,
+    ):
+        model_obj = await self.models_api.get_model(model)
+        if model_obj is None:
+            raise ValueError(f"Model '{model}' not found")
+        if model_obj.model_type == ModelType.embedding:
+            raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
+
+        messages: List[OpenAIMessageParam] = []
+        if previous_response_id:
+            previous_response = await self.get_openai_response(previous_response_id)
+            messages.append(OpenAIAssistantMessageParam(content=previous_response.output[0].content[0].text))
+
+        messages.append(OpenAIUserMessageParam(content=input))
+
+        chat_response = await self.inference_api.openai_chat_completion(
+            model=model_obj.identifier,
+            messages=messages,
+        )
+        # type cast to appease mypy
+        chat_response = cast(OpenAIChatCompletion, chat_response)
+
+        output_messages = []
+        for choice in chat_response.choices:
+            output_content = ""
+            if isinstance(choice.message.content, str):
+                output_content = choice.message.content
+            elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
+                output_content = choice.message.content.text
+            # TODO: handle image content
+            output_messages.append(
+                OpenAIResponseOutputMessage(
+                    id=f"msg_{uuid.uuid4()}",
+                    content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
+                    status="completed",
+                )
+            )
+        response = OpenAIResponseObject(
+            created_at=chat_response.created,
+            id=f"resp-{uuid.uuid4()}",
+            model=model_obj.identifier,
+            object="response",
+            status="completed",
+            output=output_messages,
+        )
+
+        if store:
+            # Store in kvstore
+            key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
+            await self.kvstore.set(
+                key=key,
+                value=response.model_dump_json(),
+            )
+
+        if stream:
+
+            async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
+                yield OpenAIResponseObjectStream(response=response)
+
+            return async_response()
+
+        return response

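Worth noting: streaming in this stub is not incremental. When stream=True, create_openai_response yields exactly one "response.created" event that wraps the already-complete response object. A consumption sketch (editor's illustration, assuming an initialized `impl` and a registered model id):

import asyncio

async def demo(impl, model_id: str) -> None:
    # stream=True returns an async iterator; the stub emits a single chunk.
    stream = await impl.create_openai_response(
        input="Which planet do humans live on?",
        model=model_id,
        stream=True,
    )
    async for chunk in stream:
        # chunk.type is "response.created" and carries the full response.
        print(chunk.type, chunk.response.output[0].content[0].text)

# asyncio.run(demo(impl, "example-model"))  # hypothetical wiring
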
llama_stack/providers/registry/openai_responses.py (new file, 25 lines)

@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.openai_responses,
+            provider_type="inline::openai-responses",
+            pip_packages=[],
+            module="llama_stack.providers.inline.openai_responses",
+            config_class="llama_stack.providers.inline.openai_responses.config.OpenAIResponsesImplConfig",
+            api_dependencies=[
+                Api.models,
+                Api.inference,
+            ],
+        ),
+    ]

@@ -31,4 +31,6 @@ distribution_spec:
     - inline::rag-runtime
     - remote::model-context-protocol
    - remote::wolfram-alpha
+    openai_responses:
+    - inline::openai-responses
 image_type: conda

@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- openai_responses
 - safety
 - scoring
 - telemetry

@@ -115,6 +116,14 @@ providers:
     provider_type: remote::wolfram-alpha
     config:
       api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
+  openai_responses:
+  - provider_id: openai-responses
+    provider_type: inline::openai-responses
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/openai_responses.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db

@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- openai_responses
 - safety
 - scoring
 - telemetry

@@ -108,6 +109,14 @@ providers:
     provider_type: remote::wolfram-alpha
     config:
       api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
+  openai_responses:
+  - provider_id: openai-responses
+    provider_type: inline::openai-responses
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/openai_responses.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db

@@ -39,6 +39,7 @@ def get_distribution_template() -> DistributionTemplate:
             "remote::model-context-protocol",
             "remote::wolfram-alpha",
         ],
+        "openai_responses": ["inline::openai-responses"],
     }
     name = "remote-vllm"
     inference_provider = Provider(

tests/integration/openai_responses/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

tests/integration/openai_responses/test_openai_responses.py (new file, 90 lines)

@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import pytest
+from openai import OpenAI
+
+from ..test_cases.test_case import TestCase
+
+
+@pytest.fixture
+def openai_client(client_with_models):
+    base_url = f"{client_with_models.base_url}/v1/openai/v1"
+    return OpenAI(base_url=base_url, api_key="bar")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "openai:responses:non_streaming_01",
+        "openai:responses:non_streaming_02",
+    ],
+)
+def test_openai_responses_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+    tc = TestCase(test_case)
+    question = tc["question"]
+    expected = tc["expected"]
+
+    response = openai_client.responses.create(
+        model=text_model_id,
+        input=question,
+        stream=False,
+    )
+    output_text = response.output_text.lower().strip()
+    assert len(output_text) > 0
+    assert expected.lower() in output_text
+
+    retrieved_response = openai_client.responses.retrieve(response_id=response.id)
+    assert retrieved_response.output_text == response.output_text
+
+    next_response = openai_client.responses.create(
+        model=text_model_id, input="Repeat your previous response in all caps.", previous_response_id=response.id
+    )
+    next_output_text = next_response.output_text.strip()
+    assert expected.upper() in next_output_text
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "openai:responses:streaming_01",
+        "openai:responses:streaming_02",
+    ],
+)
+def test_openai_responses_streaming(openai_client, client_with_models, text_model_id, test_case):
+    tc = TestCase(test_case)
+    question = tc["question"]
+    expected = tc["expected"]
+
+    response = openai_client.responses.create(
+        model=text_model_id,
+        input=question,
+        stream=True,
+        timeout=120,  # Increase timeout to 2 minutes for large conversation history
+    )
+    streamed_content = []
+    response_id = ""
+    for chunk in response:
+        response_id = chunk.response.id
+        streamed_content.append(chunk.response.output_text.strip())
+
+    assert len(streamed_content) > 0
+    assert expected.lower() in "".join(streamed_content).lower()
+
+    retrieved_response = openai_client.responses.retrieve(response_id=response_id)
+    assert retrieved_response.output_text == "".join(streamed_content)
+
+    next_response = openai_client.responses.create(
+        model=text_model_id,
+        input="Repeat your previous response in all caps.",
+        previous_response_id=response_id,
+        stream=True,
+    )
+    next_streamed_content = []
+    for chunk in next_response:
+        next_streamed_content.append(chunk.response.output_text.strip())
+    assert expected.upper() in "".join(next_streamed_content)

tests/integration/test_cases/openai/responses.json (new file, 26 lines)

@@ -0,0 +1,26 @@
+{
+    "non_streaming_01": {
+        "data": {
+            "question": "Which planet do humans live on?",
+            "expected": "Earth"
+        }
+    },
+    "non_streaming_02": {
+        "data": {
+            "question": "Which planet has rings around it with a name starting with letter S?",
+            "expected": "Saturn"
+        }
+    },
+    "streaming_01": {
+        "data": {
+            "question": "What's the name of the Sun in Latin?",
+            "expected": "Sol"
+        }
+    },
+    "streaming_02": {
+        "data": {
+            "question": "What is the name of the US capital?",
+            "expected": "Washington"
+        }
+    }
+}

@@ -12,6 +12,7 @@ class TestCase:
     _apis = [
         "inference/chat_completion",
         "inference/completion",
+        "openai/responses",
     ]
     _jsonblob = {}