mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: add list responses API (#2233)
# What does this PR do? This is not part of the official OpenAI API, but we'll use this for the logs UI. In order to support more filtering options, I'm adopting the newly introduced sql store in in place of the kv store. ## Test Plan Added integration/unit tests.
This commit is contained in:
parent
6463ee7633
commit
5844c2da68
47 changed files with 704 additions and 77 deletions
170
docs/_static/llama-stack-spec.html
vendored
170
docs/_static/llama-stack-spec.html
vendored
|
@ -518,6 +518,74 @@
|
|||
}
|
||||
},
|
||||
"/v1/openai/v1/responses": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A ListOpenAIResponseObject.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ListOpenAIResponseObject"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Agents"
|
||||
],
|
||||
"description": "List all OpenAI responses.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "after",
|
||||
"in": "query",
|
||||
"description": "The ID of the last response to return.",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"description": "The number of responses to return.",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "model",
|
||||
"in": "query",
|
||||
"description": "The model to filter responses by.",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "order",
|
||||
"in": "query",
|
||||
"description": "The order to sort responses by when sorted by created_at ('asc' or 'desc').",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Order"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
|
@ -10179,6 +10247,108 @@
|
|||
],
|
||||
"title": "ListModelsResponse"
|
||||
},
|
||||
"ListOpenAIResponseObject": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAIResponseObjectWithInput"
|
||||
}
|
||||
},
|
||||
"has_more": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"first_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"last_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"object": {
|
||||
"type": "string",
|
||||
"const": "list",
|
||||
"default": "list"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"data",
|
||||
"has_more",
|
||||
"first_id",
|
||||
"last_id",
|
||||
"object"
|
||||
],
|
||||
"title": "ListOpenAIResponseObject"
|
||||
},
|
||||
"OpenAIResponseObjectWithInput": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "integer"
|
||||
},
|
||||
"error": {
|
||||
"$ref": "#/components/schemas/OpenAIResponseError"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"object": {
|
||||
"type": "string",
|
||||
"const": "response",
|
||||
"default": "response"
|
||||
},
|
||||
"output": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAIResponseOutput"
|
||||
}
|
||||
},
|
||||
"parallel_tool_calls": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"previous_response_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"status": {
|
||||
"type": "string"
|
||||
},
|
||||
"temperature": {
|
||||
"type": "number"
|
||||
},
|
||||
"top_p": {
|
||||
"type": "number"
|
||||
},
|
||||
"truncation": {
|
||||
"type": "string"
|
||||
},
|
||||
"user": {
|
||||
"type": "string"
|
||||
},
|
||||
"input": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAIResponseInput"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"created_at",
|
||||
"id",
|
||||
"model",
|
||||
"object",
|
||||
"output",
|
||||
"parallel_tool_calls",
|
||||
"status",
|
||||
"input"
|
||||
],
|
||||
"title": "OpenAIResponseObjectWithInput"
|
||||
},
|
||||
"ListProvidersResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
121
docs/_static/llama-stack-spec.yaml
vendored
121
docs/_static/llama-stack-spec.yaml
vendored
|
@ -349,6 +349,53 @@ paths:
|
|||
$ref: '#/components/schemas/CreateAgentTurnRequest'
|
||||
required: true
|
||||
/v1/openai/v1/responses:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: A ListOpenAIResponseObject.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ListOpenAIResponseObject'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Agents
|
||||
description: List all OpenAI responses.
|
||||
parameters:
|
||||
- name: after
|
||||
in: query
|
||||
description: The ID of the last response to return.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
- name: limit
|
||||
in: query
|
||||
description: The number of responses to return.
|
||||
required: false
|
||||
schema:
|
||||
type: integer
|
||||
- name: model
|
||||
in: query
|
||||
description: The model to filter responses by.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
- name: order
|
||||
in: query
|
||||
description: >-
|
||||
The order to sort responses by when sorted by created_at ('asc' or 'desc').
|
||||
required: false
|
||||
schema:
|
||||
$ref: '#/components/schemas/Order'
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
|
@ -7106,6 +7153,80 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListModelsResponse
|
||||
ListOpenAIResponseObject:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIResponseObjectWithInput'
|
||||
has_more:
|
||||
type: boolean
|
||||
first_id:
|
||||
type: string
|
||||
last_id:
|
||||
type: string
|
||||
object:
|
||||
type: string
|
||||
const: list
|
||||
default: list
|
||||
additionalProperties: false
|
||||
required:
|
||||
- data
|
||||
- has_more
|
||||
- first_id
|
||||
- last_id
|
||||
- object
|
||||
title: ListOpenAIResponseObject
|
||||
OpenAIResponseObjectWithInput:
|
||||
type: object
|
||||
properties:
|
||||
created_at:
|
||||
type: integer
|
||||
error:
|
||||
$ref: '#/components/schemas/OpenAIResponseError'
|
||||
id:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
object:
|
||||
type: string
|
||||
const: response
|
||||
default: response
|
||||
output:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIResponseOutput'
|
||||
parallel_tool_calls:
|
||||
type: boolean
|
||||
default: false
|
||||
previous_response_id:
|
||||
type: string
|
||||
status:
|
||||
type: string
|
||||
temperature:
|
||||
type: number
|
||||
top_p:
|
||||
type: number
|
||||
truncation:
|
||||
type: string
|
||||
user:
|
||||
type: string
|
||||
input:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIResponseInput'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- created_at
|
||||
- id
|
||||
- model
|
||||
- object
|
||||
- output
|
||||
- parallel_tool_calls
|
||||
- status
|
||||
- input
|
||||
title: OpenAIResponseObjectWithInput
|
||||
ListProvidersResponse:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.common.responses import Order, PaginatedResponse
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
ResponseFormat,
|
||||
|
@ -31,6 +31,7 @@ from llama_stack.apis.tools import ToolDef
|
|||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
|
@ -611,3 +612,21 @@ class Agents(Protocol):
|
|||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="GET")
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
"""List all OpenAI responses.
|
||||
|
||||
:param after: The ID of the last response to return.
|
||||
:param limit: The number of responses to return.
|
||||
:param model: The model to filter responses by.
|
||||
:param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
|
||||
:returns: A ListOpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -219,3 +219,17 @@ register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
|||
class OpenAIResponseInputItemList(BaseModel):
|
||||
data: list[OpenAIResponseInput]
|
||||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectWithInput(OpenAIResponseObject):
|
||||
input: list[OpenAIResponseInput]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIResponseObject(BaseModel):
|
||||
data: list[OpenAIResponseObjectWithInput]
|
||||
has_more: bool
|
||||
first_id: str
|
||||
last_id: str
|
||||
object: Literal["list"] = "list"
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
@ -11,6 +12,11 @@ from pydantic import BaseModel
|
|||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class Order(Enum):
|
||||
asc = "asc"
|
||||
desc = "desc"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PaginatedResponse(BaseModel):
|
||||
"""A generic paginated response that follows a simple format.
|
||||
|
|
|
@ -19,6 +19,7 @@ from pydantic import BaseModel, Field, field_validator
|
|||
from typing_extensions import TypedDict
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
|
@ -833,11 +834,6 @@ class ListOpenAIChatCompletionResponse(BaseModel):
|
|||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
class Order(Enum):
|
||||
asc = "asc"
|
||||
desc = "desc"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class InferenceProvider(Protocol):
|
||||
|
|
|
@ -20,9 +20,11 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
Order,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
@ -39,6 +41,7 @@ from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
|||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
@ -66,15 +69,17 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.tool_groups_api = tool_groups_api
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.openai_responses_impl = None
|
||||
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
self.responses_store = ResponsesStore(self.config.responses_store)
|
||||
await self.responses_store.initialize()
|
||||
self.openai_responses_impl = OpenAIResponsesImpl(
|
||||
self.persistence_store,
|
||||
inference_api=self.inference_api,
|
||||
tool_groups_api=self.tool_groups_api,
|
||||
tool_runtime_api=self.tool_runtime_api,
|
||||
responses_store=self.responses_store,
|
||||
)
|
||||
|
||||
async def create_agent(
|
||||
|
@ -323,3 +328,12 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
return await self.openai_responses_impl.create_openai_response(
|
||||
input, model, instructions, previous_response_id, store, stream, temperature, tools
|
||||
)
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
|
||||
|
|
|
@ -10,10 +10,12 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.providers.utils.kvstore import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
|
||||
|
||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
||||
persistence_store: KVStoreConfig
|
||||
responses_store: SqlStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
|
@ -21,5 +23,9 @@ class MetaReferenceAgentsImplConfig(BaseModel):
|
|||
"persistence_store": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="agents_store.db",
|
||||
)
|
||||
),
|
||||
"responses_store": SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="responses_store.db",
|
||||
),
|
||||
}
|
||||
|
|
|
@ -12,7 +12,9 @@ from typing import Any, cast
|
|||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputItemList,
|
||||
|
@ -53,7 +55,7 @@ from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolR
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
|
@ -169,34 +171,27 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
|||
class OpenAIResponsesImpl:
|
||||
def __init__(
|
||||
self,
|
||||
persistence_store: KVStore,
|
||||
inference_api: Inference,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
responses_store: ResponsesStore,
|
||||
):
|
||||
self.persistence_store = persistence_store
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
|
||||
async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{id}"
|
||||
response_json = await self.persistence_store.get(key=key)
|
||||
if response_json is None:
|
||||
raise ValueError(f"OpenAI response with id '{id}' not found")
|
||||
return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)
|
||||
self.responses_store = responses_store
|
||||
|
||||
async def _prepend_previous_response(
|
||||
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
|
||||
):
|
||||
if previous_response_id:
|
||||
previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)
|
||||
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
|
||||
|
||||
# previous response input items
|
||||
new_input_items = previous_response_with_input.input_items.data
|
||||
new_input_items = previous_response_with_input.input
|
||||
|
||||
# previous response output items
|
||||
new_input_items.extend(previous_response_with_input.response.output)
|
||||
new_input_items.extend(previous_response_with_input.output)
|
||||
|
||||
# new input items from the current request
|
||||
if isinstance(input, str):
|
||||
|
@ -216,8 +211,17 @@ class OpenAIResponsesImpl:
|
|||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self._get_previous_response_with_input(response_id)
|
||||
return response_with_input.response
|
||||
response_with_input = await self.responses_store.get_response_object(response_id)
|
||||
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
return await self.responses_store.list_responses(after, limit, model, order)
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
|
@ -360,15 +364,9 @@ class OpenAIResponsesImpl:
|
|||
else:
|
||||
input_items_data.append(input_item)
|
||||
|
||||
input_items = OpenAIResponseInputItemList(data=input_items_data)
|
||||
prev_response = OpenAIResponsePreviousResponseWithInputItems(
|
||||
input_items=input_items,
|
||||
response=response,
|
||||
)
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
|
||||
await self.persistence_store.set(
|
||||
key=key,
|
||||
value=prev_response.model_dump_json(),
|
||||
await self.responses_store.store_response_object(
|
||||
response_object=response,
|
||||
input=input_items_data,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
98
llama_stack/providers/utils/responses/responses_store.py
Normal file
98
llama_stack/providers/utils/responses/responses_store.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.agents import (
|
||||
Order,
|
||||
)
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectWithInput,
|
||||
)
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
||||
|
||||
|
||||
class ResponsesStore:
|
||||
def __init__(self, sql_store_config: SqlStoreConfig):
|
||||
if not sql_store_config:
|
||||
sql_store_config = SqliteSqlStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
)
|
||||
self.sql_store = sqlstore_impl(sql_store_config)
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
await self.sql_store.create_table(
|
||||
"openai_responses",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"response_object": ColumnType.JSON,
|
||||
"model": ColumnType.STRING,
|
||||
},
|
||||
)
|
||||
|
||||
async def store_response_object(
|
||||
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||
) -> None:
|
||||
data = response_object.model_dump()
|
||||
data["input"] = [input_item.model_dump() for input_item in input]
|
||||
|
||||
await self.sql_store.insert(
|
||||
"openai_responses",
|
||||
{
|
||||
"id": data["id"],
|
||||
"created_at": data["created_at"],
|
||||
"model": data["model"],
|
||||
"response_object": data,
|
||||
},
|
||||
)
|
||||
|
||||
async def list_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
"""
|
||||
List responses from the database.
|
||||
|
||||
:param after: The ID of the last response to return.
|
||||
:param limit: The maximum number of responses to return.
|
||||
:param model: The model to filter by.
|
||||
:param order: The order to sort the responses by.
|
||||
"""
|
||||
# TODO: support after
|
||||
if after:
|
||||
raise NotImplementedError("After is not supported for SQLite")
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
rows = await self.sql_store.fetch_all(
|
||||
"openai_responses",
|
||||
where={"model": model} if model else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in rows]
|
||||
return ListOpenAIResponseObject(
|
||||
data=data,
|
||||
# TODO: implement has_more
|
||||
has_more=False,
|
||||
first_id=data[0].id if data else "",
|
||||
last_id=data[-1].id if data else "",
|
||||
)
|
||||
|
||||
async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
|
||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
|
||||
if not row:
|
||||
raise ValueError(f"Response with id {response_id} not found") from None
|
||||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
|
@ -35,6 +35,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,6 +41,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/responses_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -38,6 +38,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,6 +41,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -37,6 +37,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -46,6 +46,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,6 +41,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,6 +41,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -46,6 +46,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,6 +41,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -46,6 +46,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,6 +41,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -50,6 +50,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -56,6 +56,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -46,6 +46,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -46,6 +46,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,6 +41,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -40,6 +40,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -38,6 +38,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -64,6 +64,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -46,6 +46,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,6 +41,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -50,6 +50,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/responses_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -43,6 +43,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/responses_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -51,6 +51,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -72,6 +72,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,6 +41,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -40,6 +40,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -46,6 +46,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,6 +41,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -74,6 +74,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -45,6 +45,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -42,6 +42,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
97
tests/integration/agents/test_openai_responses.py
Normal file
97
tests/integration/agents/test_openai_responses.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_client(client_with_models):
|
||||
base_url = f"{client_with_models.base_url}/v1/openai/v1"
|
||||
return OpenAI(base_url=base_url, api_key="bar")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"stream",
|
||||
[
|
||||
True,
|
||||
False,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"tools",
|
||||
[
|
||||
[],
|
||||
[
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city to get the weather for"},
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
],
|
||||
)
|
||||
def test_responses_store(openai_client, client_with_models, text_model_id, stream, tools):
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
|
||||
|
||||
client = openai_client
|
||||
message = "What's the weather in Tokyo?" + (
|
||||
" YOU MUST USE THE get_weather function to get the weather." if tools else ""
|
||||
)
|
||||
response = client.responses.create(
|
||||
model=text_model_id,
|
||||
input=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": message,
|
||||
}
|
||||
],
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
)
|
||||
if stream:
|
||||
# accumulate the streamed content
|
||||
content = ""
|
||||
response_id = None
|
||||
for chunk in response:
|
||||
if response_id is None:
|
||||
response_id = chunk.response.id
|
||||
if not tools:
|
||||
if chunk.type == "response.completed":
|
||||
response_id = chunk.response.id
|
||||
content = chunk.response.output[0].content[0].text
|
||||
else:
|
||||
response_id = response.id
|
||||
if not tools:
|
||||
content = response.output[0].content[0].text
|
||||
|
||||
# list responses is not available in the SDK
|
||||
url = urljoin(str(client.base_url), "responses")
|
||||
response = requests.get(url, headers={"Authorization": f"Bearer {client.api_key}"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()["data"]
|
||||
assert response_id in [r["id"] for r in data]
|
||||
|
||||
# test retrieve response
|
||||
retrieved_response = client.responses.retrieve(response_id)
|
||||
assert retrieved_response.id == response_id
|
||||
assert retrieved_response.model == text_model_id
|
||||
if tools:
|
||||
assert retrieved_response.output[0].type == "function_call"
|
||||
else:
|
||||
assert retrieved_response.output[0].content[0].text == content
|
|
@ -43,6 +43,10 @@ def config(tmp_path):
|
|||
"type": "sqlite",
|
||||
"db_path": str(tmp_path / "test.db"),
|
||||
},
|
||||
responses_store={
|
||||
"type": "sqlite",
|
||||
"db_path": str(tmp_path / "test.db"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
|
@ -16,12 +16,11 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
)
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputItemList,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectWithInput,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
|
@ -33,19 +32,12 @@ from llama_stack.apis.inference.inference import (
|
|||
)
|
||||
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
||||
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
|
||||
OpenAIResponsePreviousResponseWithInputItems,
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kvstore():
|
||||
kvstore = AsyncMock(spec=KVStore)
|
||||
return kvstore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inference_api():
|
||||
inference_api = AsyncMock()
|
||||
|
@ -65,12 +57,18 @@ def mock_tool_runtime_api():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_responses_impl(mock_kvstore, mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api):
|
||||
def mock_responses_store():
|
||||
responses_store = AsyncMock(spec=ResponsesStore)
|
||||
return responses_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_responses_impl(mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store):
|
||||
return OpenAIResponsesImpl(
|
||||
persistence_store=mock_kvstore,
|
||||
inference_api=mock_inference_api,
|
||||
tool_groups_api=mock_tool_groups_api,
|
||||
tool_runtime_api=mock_tool_runtime_api,
|
||||
responses_store=mock_responses_store,
|
||||
)
|
||||
|
||||
|
||||
|
@ -100,7 +98,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
stream=False,
|
||||
temperature=0.1,
|
||||
)
|
||||
openai_responses_impl.persistence_store.set.assert_called_once()
|
||||
openai_responses_impl.responses_store.store_response_object.assert_called_once()
|
||||
assert result.model == model
|
||||
assert len(result.output) == 1
|
||||
assert isinstance(result.output[0], OpenAIResponseMessage)
|
||||
|
@ -167,7 +165,7 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
|
|||
kwargs={"query": "What is the capital of Ireland?"},
|
||||
)
|
||||
|
||||
openai_responses_impl.persistence_store.set.assert_called_once()
|
||||
openai_responses_impl.responses_store.store_response_object.assert_called_once()
|
||||
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
assert len(result.output) >= 1
|
||||
|
@ -292,8 +290,7 @@ async def test_prepend_previous_response_none(openai_responses_impl):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input")
|
||||
async def test_prepend_previous_response_basic(get_previous_response_with_input, openai_responses_impl):
|
||||
async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
|
||||
"""Test prepending a basic previous response to a new response."""
|
||||
|
||||
input_item_message = OpenAIResponseMessage(
|
||||
|
@ -301,25 +298,21 @@ async def test_prepend_previous_response_basic(get_previous_response_with_input,
|
|||
content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")],
|
||||
role="user",
|
||||
)
|
||||
input_items = OpenAIResponseInputItemList(data=[input_item_message])
|
||||
response_output_message = OpenAIResponseMessage(
|
||||
id="123",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text="fake_response")],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
response = OpenAIResponseObject(
|
||||
previous_response = OpenAIResponseObjectWithInput(
|
||||
created_at=1,
|
||||
id="resp_123",
|
||||
model="fake_model",
|
||||
output=[response_output_message],
|
||||
status="completed",
|
||||
input=[input_item_message],
|
||||
)
|
||||
previous_response = OpenAIResponsePreviousResponseWithInputItems(
|
||||
input_items=input_items,
|
||||
response=response,
|
||||
)
|
||||
get_previous_response_with_input.return_value = previous_response
|
||||
mock_responses_store.get_response_object.return_value = previous_response
|
||||
|
||||
input = await openai_responses_impl._prepend_previous_response("fake_input", "resp_123")
|
||||
|
||||
|
@ -336,16 +329,13 @@ async def test_prepend_previous_response_basic(get_previous_response_with_input,
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input")
|
||||
async def test_prepend_previous_response_web_search(get_previous_response_with_input, openai_responses_impl):
|
||||
async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store):
|
||||
"""Test prepending a web search previous response to a new response."""
|
||||
|
||||
input_item_message = OpenAIResponseMessage(
|
||||
id="123",
|
||||
content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")],
|
||||
role="user",
|
||||
)
|
||||
input_items = OpenAIResponseInputItemList(data=[input_item_message])
|
||||
output_web_search = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id="ws_123",
|
||||
status="completed",
|
||||
|
@ -356,18 +346,15 @@ async def test_prepend_previous_response_web_search(get_previous_response_with_i
|
|||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
response = OpenAIResponseObject(
|
||||
response = OpenAIResponseObjectWithInput(
|
||||
created_at=1,
|
||||
id="resp_123",
|
||||
model="fake_model",
|
||||
output=[output_web_search, output_message],
|
||||
status="completed",
|
||||
input=[input_item_message],
|
||||
)
|
||||
previous_response = OpenAIResponsePreviousResponseWithInputItems(
|
||||
input_items=input_items,
|
||||
response=response,
|
||||
)
|
||||
get_previous_response_with_input.return_value = previous_response
|
||||
mock_responses_store.get_response_object.return_value = response
|
||||
|
||||
input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
|
||||
input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
|
||||
|
@ -464,9 +451,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input")
|
||||
async def test_create_openai_response_with_instructions_and_previous_response(
|
||||
get_previous_response_with_input, openai_responses_impl, mock_inference_api
|
||||
openai_responses_impl, mock_responses_store, mock_inference_api
|
||||
):
|
||||
"""Test prepending both instructions and previous response."""
|
||||
|
||||
|
@ -475,25 +461,21 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
content="Name some towns in Ireland",
|
||||
role="user",
|
||||
)
|
||||
input_items = OpenAIResponseInputItemList(data=[input_item_message])
|
||||
response_output_message = OpenAIResponseMessage(
|
||||
id="123",
|
||||
content="Galway, Longford, Sligo",
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
response = OpenAIResponseObject(
|
||||
response = OpenAIResponseObjectWithInput(
|
||||
created_at=1,
|
||||
id="resp_123",
|
||||
model="fake_model",
|
||||
output=[response_output_message],
|
||||
status="completed",
|
||||
input=[input_item_message],
|
||||
)
|
||||
previous_response = OpenAIResponsePreviousResponseWithInputItems(
|
||||
input_items=input_items,
|
||||
response=response,
|
||||
)
|
||||
get_previous_response_with_input.return_value = previous_response
|
||||
mock_responses_store.get_response_object.return_value = response
|
||||
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
instructions = "You are a geography expert. Provide concise answers."
|
||||
|
@ -511,7 +493,7 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
sent_messages = call_args.kwargs["messages"]
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 4
|
||||
assert len(sent_messages) == 4, sent_messages
|
||||
assert sent_messages[0].role == "system"
|
||||
assert sent_messages[0].content == instructions
|
||||
|
||||
|
|
|
@ -63,6 +63,9 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/responses_store.db
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue