Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-10 05:24:39 +00:00
feat(tests): make inference_recorder into api_recorder (include tool_invoke) (#3403)
Renames `inference_recorder.py` to `api_recorder.py` and extends it to support recording/replaying tool invocations in addition to inference calls. This lets us record web-search and similar tool calls and then replay those recordings for `tests/integration/responses`.

## Test Plan

```
export OPENAI_API_KEY=...
export TAVILY_SEARCH_API_KEY=...

./scripts/integration-tests.sh --stack-config ci-tests \
  --suite responses --inference-mode record-if-missing
```
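The record-if-missing flow that the test plan relies on can be pictured with a small, self-contained sketch. None of these names come from the repo: `record_or_replay`, the recordings directory, and the environment switch are illustrative assumptions, not the actual `api_recorder` implementation.

```
# Illustrative sketch only -- the real api_recorder lives in the llama-stack
# test utilities; function, file, and env-var names here are assumptions.
import hashlib
import json
import os
from pathlib import Path
from typing import Any, Callable

RECORDINGS_DIR = Path("recordings")  # hypothetical storage location


def _request_key(tool_name: str, kwargs: dict[str, Any]) -> str:
    """Build a stable hash for a tool invocation so replays can find it."""
    payload = json.dumps({"tool": tool_name, "kwargs": kwargs}, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


def record_or_replay(tool_name: str, kwargs: dict[str, Any], call: Callable[[], Any]) -> Any:
    """Replay a stored result if one exists; otherwise call the tool and,
    in 'record-if-missing' mode, persist the result for future runs."""
    mode = os.environ.get("INFERENCE_MODE", "replay")  # assumed switch
    path = RECORDINGS_DIR / f"{_request_key(tool_name, kwargs)}.json"

    if path.exists():
        return json.loads(path.read_text())

    if mode not in ("record", "record-if-missing"):
        raise RuntimeError(f"No recording for {tool_name} and recording is disabled")

    result = call()  # real network call, e.g. a web search
    RECORDINGS_DIR.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(result))
    return result
```

Under these assumptions, a test harness would wrap each web-search invocation in `record_or_replay("web_search", {"query": ...}, lambda: ...)` so that CI runs replay deterministic results while a developer with real API keys can refresh the recordings.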
Parent: 26fd5dbd34
Commit: f50ce11a3b
284 changed files with 296,191 additions and 631 deletions
@@ -108,7 +108,7 @@ class OpenAIResponsesImpl:
                 # Use stored messages directly and convert only new input
                 message_adapter = TypeAdapter(list[OpenAIMessageParam])
                 messages = message_adapter.validate_python(previous_response.messages)
-                new_messages = await convert_response_input_to_chat_messages(input)
+                new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
                 messages.extend(new_messages)
             else:
                 # Backward compatibility: reconstruct from inputs
@@ -103,9 +103,13 @@ async def convert_response_content_to_chat_content(

 async def convert_response_input_to_chat_messages(
     input: str | list[OpenAIResponseInput],
+    previous_messages: list[OpenAIMessageParam] | None = None,
 ) -> list[OpenAIMessageParam]:
     """
     Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
+
+    :param input: The input to convert
+    :param previous_messages: Optional previous messages to check for function_call references
     """
     messages: list[OpenAIMessageParam] = []
     if isinstance(input, list):
@@ -169,16 +173,53 @@ async def convert_response_input_to_chat_messages(
                     raise ValueError(
                         f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
                     )
+                # Skip user messages that duplicate the last user message in previous_messages
+                # This handles cases where input includes context for function_call_outputs
+                if previous_messages and input_item.role == "user":
+                    last_user_msg = None
+                    for msg in reversed(previous_messages):
+                        if isinstance(msg, OpenAIUserMessageParam):
+                            last_user_msg = msg
+                            break
+                    if last_user_msg:
+                        last_user_content = getattr(last_user_msg, "content", None)
+                        if last_user_content == content:
+                            continue  # Skip duplicate user message
                 messages.append(message_type(content=content))
         if len(tool_call_results):
-            raise ValueError(
-                f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
-            )
+            # Check if unpaired function_call_outputs reference function_calls from previous messages
+            if previous_messages:
+                previous_call_ids = _extract_tool_call_ids(previous_messages)
+                for call_id in list(tool_call_results.keys()):
+                    if call_id in previous_call_ids:
+                        # Valid: this output references a call from previous messages
+                        # Add the tool message
+                        messages.append(tool_call_results[call_id])
+                        del tool_call_results[call_id]
+
+            # If still have unpaired outputs, error
+            if len(tool_call_results):
+                raise ValueError(
+                    f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
+                )
     else:
         messages.append(OpenAIUserMessageParam(content=input))
     return messages


+def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
+    """Extract all tool_call IDs from messages."""
+    call_ids = set()
+    for msg in messages:
+        if isinstance(msg, OpenAIAssistantMessageParam):
+            tool_calls = getattr(msg, "tool_calls", None)
+            if tool_calls:
+                for tool_call in tool_calls:
+                    # tool_call is a Pydantic model, use attribute access
+                    call_ids.add(tool_call.id)
+    return call_ids
+
+
 async def convert_response_text_to_chat_response_format(
     text: OpenAIResponseText,
 ) -> OpenAIResponseFormatParam:
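To see what the `previous_messages` pairing buys in practice, here is a small standalone sketch of the check. Plain dataclasses stand in for the actual OpenAI message and tool-call Pydantic models, so this is an illustration of the rule, not the repo's code: a `function_call_output` in a follow-up request is accepted when its `call_id` matches a tool call found in `previous_messages`, instead of triggering the "no corresponding function_call" error.

```
# Standalone illustration; the real code operates on OpenAIAssistantMessageParam
# and its tool_calls, not on these simplified dataclasses.
from dataclasses import dataclass, field


@dataclass
class ToolCall:
    id: str


@dataclass
class AssistantMessage:
    tool_calls: list[ToolCall] = field(default_factory=list)


def extract_tool_call_ids(messages: list) -> set[str]:
    """Collect every tool_call id issued by assistant messages."""
    call_ids: set[str] = set()
    for msg in messages:
        if isinstance(msg, AssistantMessage):
            call_ids.update(tc.id for tc in msg.tool_calls)
    return call_ids


# Previous turn: the model asked to run a tool with id "call_123".
previous_messages = [AssistantMessage(tool_calls=[ToolCall(id="call_123")])]

# Follow-up turn: the client sends only the function_call_output for that id.
unpaired_outputs = {"call_123": "tool result"}

valid_ids = extract_tool_call_ids(previous_messages)
for call_id in list(unpaired_outputs):
    if call_id in valid_ids:
        unpaired_outputs.pop(call_id)  # pairs with the earlier call, so keep it

assert not unpaired_outputs  # nothing left unpaired, so no ValueError is raised
```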
@@ -22,6 +22,7 @@ from llama_stack.apis.files import (
     OpenAIFilePurpose,
 )
 from llama_stack.core.datatypes import AccessRule
+from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.files.form_data import parse_expires_after
 from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
@@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files):

     def _generate_file_id(self) -> str:
         """Generate a unique file ID for OpenAI API."""
-        return f"file-{uuid.uuid4().hex}"
+        return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")

     def _get_file_path(self, file_id: str) -> Path:
         """Get the filesystem path for a file ID."""
@@ -95,7 +96,9 @@ class LocalfsFilesImpl(Files):
             raise RuntimeError("Files provider not initialized")

         if expires_after is not None:
-            raise NotImplementedError("File expiration is not supported by this provider")
+            logger.warning(
+                f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
+            )

         file_id = self._generate_file_id()
         file_path = self._get_file_path(file_id)
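The repeated substitution of `f"file-{uuid.uuid4().hex}"` with `generate_object_id("file", lambda: ...)` here, and of the analogous UUIDs in the S3 and vector-store providers below, routes object-ID creation through a single hook. The implementation of `generate_object_id` in `llama_stack.core.id_generation` is not part of this diff; the sketch below is an assumption about its shape, inferred only from the call sites: a per-kind override that a test harness or recorder could install to make IDs deterministic, falling back to the provided factory otherwise.

```
# Assumed shape of the id-generation hook; only the call signature
# generate_object_id(kind, factory) is taken from the diff, the rest is a guess.
from typing import Callable

_id_overrides: dict[str, Callable[[], str]] = {}


def set_id_override(kind: str, factory: Callable[[], str]) -> None:
    """Install a deterministic generator for one kind of object (e.g. in tests)."""
    _id_overrides[kind] = factory


def generate_object_id(kind: str, factory: Callable[[], str]) -> str:
    """Return an override-generated id if one is registered, else use the default factory."""
    override = _id_overrides.get(kind)
    return override() if override is not None else factory()


# Example: make file ids predictable while replaying recorded API interactions.
_counter = iter(range(1_000_000))
set_id_override("file", lambda: f"file-{next(_counter):06d}")
print(generate_object_id("file", lambda: "file-random"))  # -> file-000000
```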
@@ -23,6 +23,7 @@ from llama_stack.apis.files import (
     OpenAIFilePurpose,
 )
 from llama_stack.core.datatypes import AccessRule
+from llama_stack.core.id_generation import generate_object_id
 from llama_stack.providers.utils.files.form_data import parse_expires_after
 from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
 from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
@@ -198,7 +199,7 @@ class S3FilesImpl(Files):
         purpose: Annotated[OpenAIFilePurpose, Form()],
         expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
     ) -> OpenAIFileObject:
-        file_id = f"file-{uuid.uuid4().hex}"
+        file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")

         filename = getattr(file, "filename", None) or "uploaded_file"

@@ -40,6 +40,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreSearchResponse,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
@@ -352,7 +353,7 @@ class OpenAIVectorStoreMixin(ABC):
         """Creates a vector store."""
         created_at = int(time.time())
         # Derive the canonical vector_db_id (allow override, else generate)
-        vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
+        vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")

         if provider_id is None:
             raise ValueError("Provider ID is required")
@@ -986,7 +987,7 @@ class OpenAIVectorStoreMixin(ABC):
         chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()

         created_at = int(time.time())
-        batch_id = f"batch_{uuid.uuid4()}"
+        batch_id = generate_object_id("vector_store_file_batch", lambda: f"batch_{uuid.uuid4()}")
         # File batches expire after 7 days
         expires_at = created_at + (7 * 24 * 60 * 60)
