mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-21 03:12:24 +00:00
Merge branch 'main' into implement-search-for-PGVector
This commit is contained in:
commit
4c03cddf6f
176 changed files with 8344 additions and 734 deletions
91
tests/integration/batches/test_batches_idempotency.py
Normal file
91
tests/integration/batches/test_batches_idempotency.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Integration tests for batch idempotency functionality using the OpenAI client library.
|
||||
|
||||
This module tests the idempotency feature in the batches API using the OpenAI-compatible
|
||||
client interface. These tests verify that the idempotency key (idempotency_key) works correctly
|
||||
in a real client-server environment.
|
||||
|
||||
Test Categories:
|
||||
1. Successful Idempotency: Same key returns same batch with identical parameters
|
||||
- test_idempotent_batch_creation_successful: Verifies that requests with the same
|
||||
idempotency key return identical batches, even with different metadata order
|
||||
|
||||
2. Conflict Detection: Same key with conflicting parameters raises HTTP 409 errors
|
||||
- test_idempotency_conflict_with_different_params: Verifies that reusing an idempotency key
|
||||
with truly conflicting parameters (both file ID and metadata values) raises ConflictError
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from openai import ConflictError
|
||||
|
||||
|
||||
class TestBatchesIdempotencyIntegration:
|
||||
"""Integration tests for batch idempotency using OpenAI client."""
|
||||
|
||||
def test_idempotent_batch_creation_successful(self, openai_client):
|
||||
"""Test that identical requests with same idempotency key return the same batch."""
|
||||
batch1 = openai_client.batches.create(
|
||||
input_file_id="bogus-id",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={
|
||||
"test_type": "idempotency_success",
|
||||
"purpose": "integration_test",
|
||||
},
|
||||
extra_body={"idempotency_key": "test-idempotency-token-1"},
|
||||
)
|
||||
|
||||
# sleep to ensure different timestamps
|
||||
time.sleep(1)
|
||||
|
||||
batch2 = openai_client.batches.create(
|
||||
input_file_id="bogus-id",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={
|
||||
"purpose": "integration_test",
|
||||
"test_type": "idempotency_success",
|
||||
}, # Different order
|
||||
extra_body={"idempotency_key": "test-idempotency-token-1"},
|
||||
)
|
||||
|
||||
assert batch1.id == batch2.id
|
||||
assert batch1.input_file_id == batch2.input_file_id
|
||||
assert batch1.endpoint == batch2.endpoint
|
||||
assert batch1.completion_window == batch2.completion_window
|
||||
assert batch1.metadata == batch2.metadata
|
||||
assert batch1.created_at == batch2.created_at
|
||||
|
||||
def test_idempotency_conflict_with_different_params(self, openai_client):
|
||||
"""Test that using same idempotency key with different params raises conflict error."""
|
||||
batch1 = openai_client.batches.create(
|
||||
input_file_id="bogus-id-1",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"test_type": "conflict_test_1"},
|
||||
extra_body={"idempotency_key": "conflict-token"},
|
||||
)
|
||||
|
||||
with pytest.raises(ConflictError) as exc_info:
|
||||
openai_client.batches.create(
|
||||
input_file_id="bogus-id-2", # Different file ID
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"test_type": "conflict_test_2"}, # Different metadata
|
||||
extra_body={"idempotency_key": "conflict-token"}, # Same token
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 409
|
||||
assert "conflict" in str(exc_info.value).lower()
|
||||
|
||||
retrieved_batch = openai_client.batches.retrieve(batch1.id)
|
||||
assert retrieved_batch.id == batch1.id
|
||||
assert retrieved_batch.input_file_id == "bogus-id-1"
|
||||
|
|
@ -256,9 +256,6 @@ def instantiate_llama_stack_client(session):
|
|||
provider_data=get_provider_data(),
|
||||
skip_logger_removal=True,
|
||||
)
|
||||
if not client.initialize():
|
||||
raise RuntimeError("Initialization failed")
|
||||
|
||||
return client
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@
|
|||
#
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import BadRequestError
|
||||
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
||||
from llama_stack_client.types import EmbeddingsResponse
|
||||
from llama_stack_client.types.shared.interleaved_content import (
|
||||
ImageContentItem,
|
||||
|
|
@ -63,6 +63,9 @@ from llama_stack_client.types.shared.interleaved_content import (
|
|||
ImageContentItemImageURL,
|
||||
TextContentItem,
|
||||
)
|
||||
from openai import BadRequestError as OpenAIBadRequestError
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
DUMMY_STRING = "hello"
|
||||
DUMMY_STRING2 = "world"
|
||||
|
|
@ -203,7 +206,14 @@ def test_embedding_truncation_error(
|
|||
):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
with pytest.raises(BadRequestError):
|
||||
# Using LlamaStackClient from llama_stack_client will raise llama_stack_client.BadRequestError
|
||||
# While using LlamaStackAsLibraryClient from llama_stack.distribution.library_client will raise the error that the backend raises
|
||||
error_type = (
|
||||
OpenAIBadRequestError
|
||||
if isinstance(llama_stack_client, LlamaStackAsLibraryClient)
|
||||
else LlamaStackBadRequestError
|
||||
)
|
||||
with pytest.raises(error_type):
|
||||
llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id,
|
||||
contents=[DUMMY_LONG_TEXT],
|
||||
|
|
@ -283,7 +293,8 @@ def test_embedding_text_truncation_error(
|
|||
):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
with pytest.raises(BadRequestError):
|
||||
error_type = ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
|
||||
with pytest.raises(error_type):
|
||||
llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id,
|
||||
contents=[DUMMY_STRING],
|
||||
|
|
|
|||
|
|
@ -113,8 +113,6 @@ def openai_client(base_url, api_key, provider):
|
|||
raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
|
||||
config = parts[1]
|
||||
client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
|
||||
if not client.initialize():
|
||||
raise RuntimeError("Initialization failed")
|
||||
return client
|
||||
|
||||
return OpenAI(
|
||||
|
|
|
|||
|
|
@ -260,6 +260,94 @@ def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
|
|||
assert response.output[0].name == "get_weather"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", custom_tool_test_cases)
|
||||
def test_response_function_call_ordering_1(compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
tools=case.tools,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) == 1
|
||||
assert response.output[0].type == "function_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].name == "get_weather"
|
||||
inputs = []
|
||||
inputs.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": case.input,
|
||||
}
|
||||
)
|
||||
inputs.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"output": "It is raining.",
|
||||
"call_id": response.output[0].call_id,
|
||||
}
|
||||
)
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id, input=inputs, tools=case.tools, stream=False, previous_response_id=response.id
|
||||
)
|
||||
assert len(response.output) == 1
|
||||
|
||||
|
||||
def test_response_function_call_ordering_2(compat_client, text_model_id):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for a given location.",
|
||||
"parameters": {
|
||||
"additionalProperties": False,
|
||||
"properties": {
|
||||
"location": {
|
||||
"description": "City and country e.g. Bogotá, Colombia",
|
||||
"type": "string",
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
"type": "object",
|
||||
},
|
||||
}
|
||||
]
|
||||
inputs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Is the weather better in San Francisco or Los Angeles?",
|
||||
}
|
||||
]
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=inputs,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
for output in response.output:
|
||||
if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
|
||||
inputs.append(output)
|
||||
for output in response.output:
|
||||
if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
|
||||
weather = "It is raining."
|
||||
if "Los Angeles" in output.arguments:
|
||||
weather = "It is cloudy."
|
||||
inputs.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"output": weather,
|
||||
"call_id": output.call_id,
|
||||
}
|
||||
)
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=inputs,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) == 1
|
||||
assert "Los Angeles" in response.output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
|
||||
def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
|
||||
"""Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
|
||||
|
|
|
|||
Binary file not shown.
39
tests/integration/recordings/responses/390f0c7dac96.json
Normal file
39
tests/integration/recordings/responses/390f0c7dac96.json
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"raw": true,
|
||||
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTest metrics generation 1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"options": {
|
||||
"temperature": 0.0
|
||||
},
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b-instruct-fp16"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-11T15:51:18.170868Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 5240614083,
|
||||
"load_duration": 9823416,
|
||||
"prompt_eval_count": 21,
|
||||
"prompt_eval_duration": 21000000,
|
||||
"eval_count": 310,
|
||||
"eval_duration": 5209000000,
|
||||
"response": "This is the start of a test. I'll provide some sample data and you can try to generate metrics based on it.\n\n**Data:**\n\nLet's say we have a dataset of user interactions with an e-commerce website. The data includes:\n\n| User ID | Product Name | Purchase Date | Quantity | Price |\n| --- | --- | --- | --- | --- |\n| 1 | iPhone 13 | 2022-01-01 | 2 | 999.99 |\n| 1 | MacBook Air | 2022-01-05 | 1 | 1299.99 |\n| 2 | Samsung TV | 2022-01-10 | 3 | 899.99 |\n| 3 | iPhone 13 | 2022-01-15 | 1 | 999.99 |\n| 4 | MacBook Pro | 2022-01-20 | 2 | 1799.99 |\n\n**Task:**\n\nYour task is to generate the following metrics based on this data:\n\n1. Average order value (AOV)\n2. Conversion rate\n3. Average revenue per user (ARPU)\n4. Customer lifetime value (CLV)\n\nPlease provide your answers in a format like this:\n\n| Metric | Value |\n| --- | --- |\n| AOV | 1234.56 |\n| Conversion Rate | 0.25 |\n| ARPU | 1000.00 |\n| CLV | 5000.00 |\n\nGo ahead and generate the metrics!",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
||||
56
tests/integration/recordings/responses/4de6877d86fa.json
Normal file
56
tests/integration/recordings/responses/4de6877d86fa.json
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "OpenAI test 0"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-843",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I don't have any information about an \"OpenAI test 0\". It's possible that you may be referring to a specific experiment or task being performed by OpenAI, but without more context, I can only speculate.\n\nHowever, I can tell you that OpenAI is a research organization that has been involved in various projects and tests related to artificial intelligence. If you could provide more context or clarify what you're referring to, I may be able to help further.\n\nIf you're looking for general information about OpenAI, I can try to provide some background on the organization:\n\nOpenAI is a non-profit research organization that was founded in 2015 with the goal of developing and applying advanced artificial intelligence to benefit humanity. The organization has made significant contributions to the field of AI, including the development of the popular language model, ChatGPT.\n\nIf you could provide more context or clarify what you're looking for, I'll do my best to assist you.",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755891518,
|
||||
"model": "llama3.2:3b",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 194,
|
||||
"prompt_tokens": 30,
|
||||
"total_tokens": 224,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
||||
56
tests/integration/recordings/responses/5db0c44c83a4.json
Normal file
56
tests/integration/recordings/responses/5db0c44c83a4.json
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "OpenAI test 1"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-726",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I'm ready to help with the test. What language would you like to use? Would you like to have a conversation, ask questions, or take a specific type of task?",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755891519,
|
||||
"model": "llama3.2:3b",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 37,
|
||||
"prompt_tokens": 30,
|
||||
"total_tokens": 67,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
||||
56
tests/integration/recordings/responses/6cb0285a7638.json
Normal file
56
tests/integration/recordings/responses/6cb0285a7638.json
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "OpenAI test 4"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-581",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I'm ready to help. What would you like to test? We could try a variety of things, such as:\n\n1. Conversational dialogue\n2. Language understanding\n3. Common sense reasoning\n4. Joke or pun generation\n5. Trivia or knowledge-based questions\n6. Creative writing or storytelling\n7. Summarization or paraphrasing\n\nLet me know which area you'd like to test, or suggest something else that's on your mind!",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755891527,
|
||||
"model": "llama3.2:3b",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 96,
|
||||
"prompt_tokens": 30,
|
||||
"total_tokens": 126,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
||||
39
tests/integration/recordings/responses/7bcb0f86c91b.json
Normal file
39
tests/integration/recordings/responses/7bcb0f86c91b.json
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"raw": true,
|
||||
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTest metrics generation 0<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"options": {
|
||||
"temperature": 0.0
|
||||
},
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b-instruct-fp16"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-11T15:51:12.918723Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 8868987792,
|
||||
"load_duration": 2793275292,
|
||||
"prompt_eval_count": 21,
|
||||
"prompt_eval_duration": 250000000,
|
||||
"eval_count": 344,
|
||||
"eval_duration": 5823000000,
|
||||
"response": "Here are some common test metrics used to evaluate the performance of a system:\n\n1. **Accuracy**: The proportion of correct predictions or classifications out of total predictions made.\n2. **Precision**: The ratio of true positives (correctly predicted instances) to the sum of true positives and false positives (incorrectly predicted instances).\n3. **Recall**: The ratio of true positives to the sum of true positives and false negatives (missed instances).\n4. **F1-score**: The harmonic mean of precision and recall, providing a balanced measure of both.\n5. **Mean Squared Error (MSE)**: The average squared difference between predicted and actual values.\n6. **Mean Absolute Error (MAE)**: The average absolute difference between predicted and actual values.\n7. **Root Mean Squared Percentage Error (RMSPE)**: The square root of the mean of the squared percentage differences between predicted and actual values.\n8. **Coefficient of Determination (R-squared, R2)**: Measures how well a model fits the data, with higher values indicating better fit.\n9. **Mean Absolute Percentage Error (MAPE)**: The average absolute percentage difference between predicted and actual values.\n10. **Normalized Mean Squared Error (NMSE)**: Similar to MSE, but normalized by the mean of the actual values.\n\nThese metrics can be used for various types of data, including:\n\n* Regression problems (e.g., predicting continuous values)\n* Classification problems (e.g., predicting categorical labels)\n* Time series forecasting\n* Clustering and dimensionality reduction\n\nWhen choosing a metric, consider the specific problem you're trying to solve, the type of data, and the desired level of precision.",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
||||
56
tests/integration/recordings/responses/bf79a89cc37f.json
Normal file
56
tests/integration/recordings/responses/bf79a89cc37f.json
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "OpenAI test 3"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-48",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I'm happy to help, but it seems you want me to engage in a basic conversation as OpenAI's new chat model, right? I can do that!\n\nHere's my response:\n\nHello! How are you today? Is there something specific on your mind that you'd like to talk about or any particular topic you'd like to explore together?\n\nWhat is it that you're curious about?",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755891524,
|
||||
"model": "llama3.2:3b",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 80,
|
||||
"prompt_tokens": 30,
|
||||
"total_tokens": 110,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
||||
39
tests/integration/recordings/responses/c31a86ea6c58.json
Normal file
39
tests/integration/recordings/responses/c31a86ea6c58.json
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"raw": true,
|
||||
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTest metrics generation 0<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"options": {
|
||||
"temperature": 0.0
|
||||
},
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b",
|
||||
"created_at": "2025-08-11T15:56:06.703788Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 2722294000,
|
||||
"load_duration": 9736083,
|
||||
"prompt_eval_count": 21,
|
||||
"prompt_eval_duration": 113000000,
|
||||
"eval_count": 324,
|
||||
"eval_duration": 2598000000,
|
||||
"response": "Here are some test metrics that can be used to evaluate the performance of a system:\n\n1. **Accuracy**: The proportion of correct predictions made by the model.\n2. **Precision**: The ratio of true positives (correctly predicted instances) to total positive predictions.\n3. **Recall**: The ratio of true positives to the sum of true positives and false negatives (missed instances).\n4. **F1-score**: The harmonic mean of precision and recall, providing a balanced measure of both.\n5. **Mean Squared Error (MSE)**: The average squared difference between predicted and actual values.\n6. **Mean Absolute Error (MAE)**: The average absolute difference between predicted and actual values.\n7. **Root Mean Squared Percentage Error (RMSPE)**: A variation of MSE that expresses the error as a percentage.\n8. **Coefficient of Determination (R-squared, R2)**: Measures how well the model explains the variance in the data.\n9. **Mean Absolute Percentage Error (MAPE)**: The average absolute percentage difference between predicted and actual values.\n10. **Mean Squared Logarithmic Error (MSLE)**: A variation of MSE that is more suitable for skewed distributions.\n\nThese metrics can be used to evaluate different aspects of a system's performance, such as:\n\n* Classification models: accuracy, precision, recall, F1-score\n* Regression models: MSE, MAE, RMSPE, R2, MSLE\n* Time series forecasting: MAPE, RMSPE\n\nNote that the choice of metric depends on the specific problem and data.",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
||||
56
tests/integration/recordings/responses/dc8120cf0774.json
Normal file
56
tests/integration/recordings/responses/dc8120cf0774.json
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "OpenAI test 2"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-516",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I'm happy to help with your question or task. Please go ahead and ask me anything, and I'll do my best to assist you.\n\nNote: I'll be using the latest version of my knowledge cutoff, which is December 2023.\n\nAlso, please keep in mind that I'm a large language model, I can provide information on a broad range of topics, including science, history, technology, culture, and more. However, my ability to understand and respond to specific questions or requests may be limited by the data I've been trained on.",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755891522,
|
||||
"model": "llama3.2:3b",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 113,
|
||||
"prompt_tokens": 30,
|
||||
"total_tokens": 143,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
||||
39
tests/integration/recordings/responses/f6857bcea729.json
Normal file
39
tests/integration/recordings/responses/f6857bcea729.json
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"raw": true,
|
||||
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTest metrics generation 2<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"options": {
|
||||
"temperature": 0.0
|
||||
},
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b",
|
||||
"created_at": "2025-08-11T15:56:13.082679Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 2606245291,
|
||||
"load_duration": 9979708,
|
||||
"prompt_eval_count": 21,
|
||||
"prompt_eval_duration": 23000000,
|
||||
"eval_count": 321,
|
||||
"eval_duration": 2572000000,
|
||||
"response": "Here are some test metrics that can be used to evaluate the performance of a system:\n\n1. **Accuracy**: Measures how close the predicted values are to the actual values.\n2. **Precision**: Measures the proportion of true positives among all positive predictions made by the model.\n3. **Recall**: Measures the proportion of true positives among all actual positive instances.\n4. **F1-score**: The harmonic mean of precision and recall, providing a balanced measure of both.\n5. **Mean Squared Error (MSE)**: Measures the average squared difference between predicted and actual values.\n6. **Mean Absolute Error (MAE)**: Measures the average absolute difference between predicted and actual values.\n7. **Root Mean Squared Percentage Error (RMSPE)**: A variation of MSE that expresses errors as a percentage of the actual value.\n8. **Coefficient of Determination (R-squared, R2)**: Measures how well the model explains the variance in the data.\n9. **Mean Absolute Percentage Error (MAPE)**: Measures the average absolute percentage difference between predicted and actual values.\n10. **Mean Squared Logarithmic Error (MSLE)**: A variation of MSE that is more suitable for skewed distributions.\n\nThese metrics can be used to evaluate different aspects of a system's performance, such as:\n\n* Classification models: accuracy, precision, recall, F1-score\n* Regression models: MSE, MAE, RMSPE, R2\n* Time series forecasting: MAPE, MSLE\n\nNote that the choice of metric depends on the specific problem and data.",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
||||
39
tests/integration/recordings/responses/f80b99430f7e.json
Normal file
39
tests/integration/recordings/responses/f80b99430f7e.json
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"raw": true,
|
||||
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTest metrics generation 1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"options": {
|
||||
"temperature": 0.0
|
||||
},
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b",
|
||||
"created_at": "2025-08-11T15:56:10.465932Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 3745686709,
|
||||
"load_duration": 9734584,
|
||||
"prompt_eval_count": 21,
|
||||
"prompt_eval_duration": 23000000,
|
||||
"eval_count": 457,
|
||||
"eval_duration": 3712000000,
|
||||
"response": "Here are some test metrics that can be used to evaluate the performance of a system:\n\n**Primary Metrics**\n\n1. **Response Time**: The time it takes for the system to respond to a request.\n2. **Throughput**: The number of requests processed by the system per unit time (e.g., requests per second).\n3. **Error Rate**: The percentage of requests that result in an error.\n\n**Secondary Metrics**\n\n1. **Average Response Time**: The average response time for all requests.\n2. **Median Response Time**: The middle value of the response times, used to detect outliers.\n3. **99th Percentile Response Time**: The response time at which 99% of requests are completed within this time.\n4. **Request Latency**: The difference between the request arrival time and the response time.\n\n**User Experience Metrics**\n\n1. **User Satisfaction (USAT)**: Measured through surveys or feedback forms to gauge user satisfaction with the system's performance.\n2. **First Response Time**: The time it takes for a user to receive their first response from the system.\n3. **Time Spent in System**: The total amount of time a user spends interacting with the system.\n\n**System Resource Metrics**\n\n1. **CPU Utilization**: The percentage of CPU resources being used by the system.\n2. **Memory Usage**: The amount of memory being used by the system.\n3. **Disk I/O Wait Time**: The average time spent waiting for disk I/O operations to complete.\n\n**Security Metrics**\n\n1. **Authentication Success Rate**: The percentage of successful authentication attempts.\n2. **Authorization Success Rate**: The percentage of successful authorization attempts.\n3. **Error Rate (Security)**: The percentage of security-related errors.\n\n**Other Metrics**\n\n1. **Page Load Time**: The time it takes for a page to load.\n2. **Click-Through Rate (CTR)**: The percentage of users who click on a link or button after seeing an ad or notification.\n3. **Conversion Rate**: The percentage of users who complete a desired action (e.g., fill out a form, make a purchase).\n\nThese metrics can be used to evaluate the performance and effectiveness of various aspects of your system, from user experience to security and resource utilization.",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
||||
209
tests/integration/telemetry/test_telemetry_metrics.py
Normal file
209
tests/integration/telemetry/test_telemetry_metrics.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_telemetry_metrics_data(openai_client, client_with_models, text_model_id):
|
||||
"""Setup fixture that creates telemetry metrics data before tests run."""
|
||||
|
||||
# Skip OpenAI tests if running in library mode
|
||||
if not hasattr(client_with_models, "base_url"):
|
||||
pytest.skip("OpenAI client tests not supported with library client")
|
||||
|
||||
prompt_tokens = []
|
||||
completion_tokens = []
|
||||
total_tokens = []
|
||||
|
||||
# Create OpenAI completions to generate metrics using the proper OpenAI client
|
||||
for i in range(5):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=text_model_id,
|
||||
messages=[{"role": "user", "content": f"OpenAI test {i}"}],
|
||||
stream=False,
|
||||
)
|
||||
prompt_tokens.append(response.usage.prompt_tokens)
|
||||
completion_tokens.append(response.usage.completion_tokens)
|
||||
total_tokens.append(response.usage.total_tokens)
|
||||
|
||||
# Wait for metrics to be logged
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < 30:
|
||||
try:
|
||||
# Try to query metrics to see if they're available
|
||||
metrics_response = client_with_models.telemetry.query_metrics(
|
||||
metric_name="completion_tokens",
|
||||
start_time=int((datetime.now(UTC) - timedelta(minutes=5)).timestamp()),
|
||||
)
|
||||
if len(metrics_response[0].values) > 0:
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(1)
|
||||
|
||||
# Wait additional time to ensure all metrics are processed
|
||||
time.sleep(5)
|
||||
|
||||
# Return the token lists for use in tests
|
||||
return {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_prompt_tokens(client_with_models, text_model_id, setup_telemetry_metrics_data):
|
||||
"""Test that prompt_tokens metrics are queryable."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
response = client_with_models.telemetry.query_metrics(
|
||||
metric_name="prompt_tokens",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
|
||||
assert isinstance(response[0].values, list), "Should return a list of metric series"
|
||||
|
||||
assert response[0].metric == "prompt_tokens"
|
||||
|
||||
# Use the actual values from setup instead of hardcoded values
|
||||
expected_values = setup_telemetry_metrics_data["prompt_tokens"]
|
||||
assert response[0].values[-1].value in expected_values, (
|
||||
f"Expected one of {expected_values}, got {response[0].values[-1].value}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_completion_tokens(client_with_models, text_model_id, setup_telemetry_metrics_data):
|
||||
"""Test that completion_tokens metrics are queryable."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
response = client_with_models.telemetry.query_metrics(
|
||||
metric_name="completion_tokens",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
|
||||
assert isinstance(response[0].values, list), "Should return a list of metric series"
|
||||
|
||||
assert response[0].metric == "completion_tokens"
|
||||
|
||||
# Use the actual values from setup instead of hardcoded values
|
||||
expected_values = setup_telemetry_metrics_data["completion_tokens"]
|
||||
assert response[0].values[-1].value in expected_values, (
|
||||
f"Expected one of {expected_values}, got {response[0].values[-1].value}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_total_tokens(client_with_models, text_model_id, setup_telemetry_metrics_data):
|
||||
"""Test that total_tokens metrics are queryable."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
response = client_with_models.telemetry.query_metrics(
|
||||
metric_name="total_tokens",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
|
||||
assert isinstance(response[0].values, list), "Should return a list of metric series"
|
||||
|
||||
assert response[0].metric == "total_tokens"
|
||||
|
||||
# Use the actual values from setup instead of hardcoded values
|
||||
expected_values = setup_telemetry_metrics_data["total_tokens"]
|
||||
assert response[0].values[-1].value in expected_values, (
|
||||
f"Expected one of {expected_values}, got {response[0].values[-1].value}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_with_time_range(llama_stack_client, text_model_id):
|
||||
"""Test that metrics are queryable with time range."""
|
||||
end_time = int(datetime.now(UTC).timestamp())
|
||||
start_time = end_time - 600 # 10 minutes ago
|
||||
|
||||
response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="prompt_tokens",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
|
||||
assert isinstance(response[0].values, list), "Should return a list of metric series"
|
||||
|
||||
assert response[0].metric == "prompt_tokens"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_with_label_matchers(llama_stack_client, text_model_id):
|
||||
"""Test that metrics are queryable with label matchers."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="prompt_tokens",
|
||||
start_time=start_time,
|
||||
label_matchers=[{"name": "model_id", "value": text_model_id, "operator": "="}],
|
||||
)
|
||||
|
||||
assert isinstance(response[0].values, list), "Should return a list of metric series"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_nonexistent_metric(llama_stack_client):
|
||||
"""Test that querying a nonexistent metric returns empty data."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="nonexistent_metric",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
assert isinstance(response, list), "Should return an empty list for nonexistent metric"
|
||||
assert len(response) == 0
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_with_granularity(llama_stack_client, text_model_id):
|
||||
"""Test that metrics are queryable with different granularity levels."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
# Test hourly granularity
|
||||
hourly_response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="total_tokens",
|
||||
start_time=start_time,
|
||||
granularity="1h",
|
||||
)
|
||||
|
||||
# Test daily granularity
|
||||
daily_response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="total_tokens",
|
||||
start_time=start_time,
|
||||
granularity="1d",
|
||||
)
|
||||
|
||||
# Test no granularity (raw data points)
|
||||
raw_response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="total_tokens",
|
||||
start_time=start_time,
|
||||
granularity=None,
|
||||
)
|
||||
|
||||
# All should return valid data
|
||||
assert isinstance(hourly_response[0].values, list), "Hourly granularity should return data"
|
||||
assert isinstance(daily_response[0].values, list), "Daily granularity should return data"
|
||||
assert isinstance(raw_response[0].values, list), "No granularity should return data"
|
||||
|
||||
# Verify that different granularities produce different aggregation levels
|
||||
# (The exact number depends on data distribution, but they should be queryable)
|
||||
assert len(hourly_response[0].values) >= 0, "Hourly granularity should be queryable"
|
||||
assert len(daily_response[0].values) >= 0, "Daily granularity should be queryable"
|
||||
assert len(raw_response[0].values) >= 0, "No granularity should be queryable"
|
||||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
|
@ -133,7 +132,6 @@ class TestInferenceRecording:
|
|||
# Test directory creation
|
||||
assert storage.test_dir.exists()
|
||||
assert storage.responses_dir.exists()
|
||||
assert storage.db_path.exists()
|
||||
|
||||
# Test storing and retrieving a recording
|
||||
request_hash = "test_hash_123"
|
||||
|
|
@ -147,15 +145,6 @@ class TestInferenceRecording:
|
|||
|
||||
storage.store_recording(request_hash, request_data, response_data)
|
||||
|
||||
# Verify SQLite record
|
||||
with sqlite3.connect(storage.db_path) as conn:
|
||||
result = conn.execute("SELECT * FROM recordings WHERE request_hash = ?", (request_hash,)).fetchone()
|
||||
|
||||
assert result is not None
|
||||
assert result[0] == request_hash # request_hash
|
||||
assert result[2] == "/v1/chat/completions" # endpoint
|
||||
assert result[3] == "llama3.2:3b" # model
|
||||
|
||||
# Verify file storage and retrieval
|
||||
retrieved = storage.find_recording(request_hash)
|
||||
assert retrieved is not None
|
||||
|
|
@ -185,10 +174,7 @@ class TestInferenceRecording:
|
|||
|
||||
# Verify recording was stored
|
||||
storage = ResponseStorage(temp_storage_dir)
|
||||
with sqlite3.connect(storage.db_path) as conn:
|
||||
recordings = conn.execute("SELECT COUNT(*) FROM recordings").fetchone()[0]
|
||||
|
||||
assert recordings == 1
|
||||
assert storage.responses_dir.exists()
|
||||
|
||||
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||
"""Test that replay mode returns stored responses without making real calls."""
|
||||
|
|
|
|||
|
|
@ -5,86 +5,121 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Unit tests for LlamaStackAsLibraryClient initialization error handling.
|
||||
Unit tests for LlamaStackAsLibraryClient automatic initialization.
|
||||
|
||||
These tests ensure that users get proper error messages when they forget to call
|
||||
initialize() on the library client, preventing AttributeError regressions.
|
||||
These tests ensure that the library client is automatically initialized
|
||||
and ready to use immediately after construction.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.library_client import (
|
||||
AsyncLlamaStackAsLibraryClient,
|
||||
LlamaStackAsLibraryClient,
|
||||
)
|
||||
from llama_stack.core.server.routes import RouteImpls
|
||||
|
||||
|
||||
class TestLlamaStackAsLibraryClientInitialization:
|
||||
"""Test proper error handling for uninitialized library clients."""
|
||||
class TestLlamaStackAsLibraryClientAutoInitialization:
|
||||
"""Test automatic initialization of library clients."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_call",
|
||||
[
|
||||
lambda client: client.models.list(),
|
||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
||||
lambda client: next(
|
||||
client.chat.completions.create(
|
||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
||||
)
|
||||
),
|
||||
],
|
||||
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
|
||||
)
|
||||
def test_sync_client_proper_error_without_initialization(self, api_call):
|
||||
"""Test that sync client raises ValueError with helpful message when not initialized."""
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
def test_sync_client_auto_initialization(self, monkeypatch):
|
||||
"""Test that sync client is automatically initialized after construction."""
|
||||
# Mock the stack construction to avoid dependency issues
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
api_call(client)
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_call",
|
||||
[
|
||||
lambda client: client.models.list(),
|
||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
||||
],
|
||||
ids=["models.list", "chat.completions.create"],
|
||||
)
|
||||
async def test_async_client_proper_error_without_initialization(self, api_call):
|
||||
"""Test that async client raises ValueError with helpful message when not initialized."""
|
||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await api_call(client)
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
assert client.async_client.route_impls is not None
|
||||
|
||||
async def test_async_client_streaming_error_without_initialization(self):
|
||||
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
|
||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
async def test_async_client_auto_initialization(self, monkeypatch):
|
||||
"""Test that async client can be initialized and works properly."""
|
||||
# Mock the stack construction to avoid dependency issues
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
stream = await client.chat.completions.create(
|
||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
||||
)
|
||||
await anext(stream)
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
def test_route_impls_initialized_to_none(self):
|
||||
"""Test that route_impls is initialized to None to prevent AttributeError."""
|
||||
# Test sync client
|
||||
sync_client = LlamaStackAsLibraryClient("nvidia")
|
||||
assert sync_client.async_client.route_impls is None
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
# Test async client directly
|
||||
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
assert async_client.route_impls is None
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
# Initialize the client
|
||||
result = await client.initialize()
|
||||
assert result is True
|
||||
assert client.route_impls is not None
|
||||
|
||||
def test_initialize_method_backward_compatibility(self, monkeypatch):
|
||||
"""Test that initialize() method still works for backward compatibility."""
|
||||
# Mock the stack construction to avoid dependency issues
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
result = client.initialize()
|
||||
assert result is None
|
||||
|
||||
result2 = client.initialize()
|
||||
assert result2 is None
|
||||
|
||||
async def test_async_initialize_method_idempotent(self, monkeypatch):
|
||||
"""Test that async initialize() method can be called multiple times safely."""
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
result1 = await client.initialize()
|
||||
assert result1 is True
|
||||
|
||||
result2 = await client.initialize()
|
||||
assert result2 is True
|
||||
|
||||
def test_route_impls_automatically_set(self, monkeypatch):
|
||||
"""Test that route_impls is automatically set during construction."""
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
sync_client = LlamaStackAsLibraryClient("ci-tests")
|
||||
assert sync_client.async_client.route_impls is not None
|
||||
|
|
|
|||
|
|
@ -115,18 +115,27 @@ class TestConvertResponseInputToChatMessages:
|
|||
|
||||
async def test_convert_function_tool_call_output(self):
|
||||
input_items = [
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_123",
|
||||
name="test_function",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="Tool output",
|
||||
call_id="call_123",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIToolMessageParam)
|
||||
assert result[0].content == "Tool output"
|
||||
assert result[0].tool_call_id == "call_123"
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], OpenAIAssistantMessageParam)
|
||||
assert result[0].tool_calls[0].id == "call_123"
|
||||
assert result[0].tool_calls[0].function.name == "test_function"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
assert isinstance(result[1], OpenAIToolMessageParam)
|
||||
assert result[1].content == "Tool output"
|
||||
assert result[1].tool_call_id == "call_123"
|
||||
|
||||
async def test_convert_function_tool_call(self):
|
||||
input_items = [
|
||||
|
|
@ -146,6 +155,47 @@ class TestConvertResponseInputToChatMessages:
|
|||
assert result[0].tool_calls[0].function.name == "test_function"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
|
||||
async def test_convert_function_call_ordering(self):
|
||||
input_items = [
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_123",
|
||||
name="test_function_a",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_456",
|
||||
name="test_function_b",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="AAA",
|
||||
call_id="call_123",
|
||||
),
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="BBB",
|
||||
call_id="call_456",
|
||||
),
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
assert len(result) == 4
|
||||
assert isinstance(result[0], OpenAIAssistantMessageParam)
|
||||
assert len(result[0].tool_calls) == 1
|
||||
assert result[0].tool_calls[0].id == "call_123"
|
||||
assert result[0].tool_calls[0].function.name == "test_function_a"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
assert isinstance(result[1], OpenAIToolMessageParam)
|
||||
assert result[1].content == "AAA"
|
||||
assert result[1].tool_call_id == "call_123"
|
||||
assert isinstance(result[2], OpenAIAssistantMessageParam)
|
||||
assert len(result[2].tool_calls) == 1
|
||||
assert result[2].tool_calls[0].id == "call_456"
|
||||
assert result[2].tool_calls[0].function.name == "test_function_b"
|
||||
assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
assert isinstance(result[3], OpenAIToolMessageParam)
|
||||
assert result[3].content == "BBB"
|
||||
assert result[3].tool_call_id == "call_456"
|
||||
|
||||
async def test_convert_response_message(self):
|
||||
input_items = [
|
||||
OpenAIResponseMessage(
|
||||
|
|
|
|||
54
tests/unit/providers/batches/conftest.py
Normal file
54
tests/unit/providers/batches/conftest.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Shared fixtures for batches provider unit tests."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def provider():
|
||||
"""Create a test provider instance with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_batches.db"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||
|
||||
# Create kvstore and mock APIs
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
mock_inference = AsyncMock()
|
||||
mock_files = AsyncMock()
|
||||
mock_models = AsyncMock()
|
||||
|
||||
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
|
||||
await provider.initialize()
|
||||
|
||||
# unit tests should not require background processing
|
||||
provider.process_batches = False
|
||||
|
||||
yield provider
|
||||
|
||||
await provider.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_batch_data():
|
||||
"""Sample batch data for testing."""
|
||||
return {
|
||||
"input_file_id": "file_abc123",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"metadata": {"test": "true", "priority": "high"},
|
||||
}
|
||||
|
|
@ -54,60 +54,17 @@ dependencies like inference, files, and models APIs.
|
|||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.batches import BatchObject
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class TestReferenceBatchesImpl:
|
||||
"""Test the reference implementation of the Batches API."""
|
||||
|
||||
@pytest.fixture
|
||||
async def provider(self):
|
||||
"""Create a test provider instance with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_batches.db"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||
|
||||
# Create kvstore and mock APIs
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
mock_inference = AsyncMock()
|
||||
mock_files = AsyncMock()
|
||||
mock_models = AsyncMock()
|
||||
|
||||
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
|
||||
await provider.initialize()
|
||||
|
||||
# unit tests should not require background processing
|
||||
provider.process_batches = False
|
||||
|
||||
yield provider
|
||||
|
||||
await provider.shutdown()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_batch_data(self):
|
||||
"""Sample batch data for testing."""
|
||||
return {
|
||||
"input_file_id": "file_abc123",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"metadata": {"test": "true", "priority": "high"},
|
||||
}
|
||||
|
||||
def _validate_batch_type(self, batch, expected_metadata=None):
|
||||
"""
|
||||
Helper function to validate batch object structure and field types.
|
||||
|
|
|
|||
128
tests/unit/providers/batches/test_reference_idempotency.py
Normal file
128
tests/unit/providers/batches/test_reference_idempotency.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Tests for idempotency functionality in the reference batches provider.
|
||||
|
||||
This module tests the optional idempotency feature that allows clients to provide
|
||||
an idempotency key (idempotency_key) to ensure that repeated requests with the same key
|
||||
and parameters return the same batch, while requests with the same key but different
|
||||
parameters result in a conflict error.
|
||||
|
||||
Test Categories:
|
||||
1. Core Idempotency: Same parameters with same key return same batch
|
||||
2. Parameter Independence: Different parameters without keys create different batches
|
||||
3. Conflict Detection: Same key with different parameters raises ConflictError
|
||||
|
||||
Tests by Category:
|
||||
|
||||
1. Core Idempotency:
|
||||
- test_idempotent_batch_creation_same_params
|
||||
- test_idempotent_batch_creation_metadata_order_independence
|
||||
|
||||
2. Parameter Independence:
|
||||
- test_non_idempotent_behavior_without_key
|
||||
- test_different_idempotency_keys_create_different_batches
|
||||
|
||||
3. Conflict Detection:
|
||||
- test_same_idempotency_key_different_params_conflict (parametrized: input_file_id, metadata values, metadata None vs {})
|
||||
|
||||
Key Behaviors Tested:
|
||||
- Idempotent batch creation when idempotency_key provided with identical parameters
|
||||
- Metadata order independence for consistent batch ID generation
|
||||
- Non-idempotent behavior when no idempotency_key provided (random UUIDs)
|
||||
- Conflict detection for parameter mismatches with same idempotency key
|
||||
- Deterministic ID generation based solely on idempotency key
|
||||
- Proper error handling with detailed conflict messages including key and error codes
|
||||
- Protection against idempotency key reuse with different request parameters
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import ConflictError
|
||||
|
||||
|
||||
class TestReferenceBatchesIdempotency:
|
||||
"""Test suite for idempotency functionality in the reference implementation."""
|
||||
|
||||
async def test_idempotent_batch_creation_same_params(self, provider, sample_batch_data):
|
||||
"""Test that creating batches with identical parameters returns the same batch when idempotency_key is provided."""
|
||||
|
||||
del sample_batch_data["metadata"]
|
||||
|
||||
batch1 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
metadata={"test": "value1", "other": "value2"},
|
||||
idempotency_key="unique-token-1",
|
||||
)
|
||||
|
||||
# sleep for 1 second to allow created_at timestamps to be different
|
||||
await asyncio.sleep(1)
|
||||
|
||||
batch2 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
metadata={"other": "value2", "test": "value1"}, # Different order
|
||||
idempotency_key="unique-token-1",
|
||||
)
|
||||
|
||||
assert batch1.id == batch2.id
|
||||
assert batch1.input_file_id == batch2.input_file_id
|
||||
assert batch1.metadata == batch2.metadata
|
||||
assert batch1.created_at == batch2.created_at
|
||||
|
||||
async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data):
|
||||
"""Test that different idempotency keys create different batches even with same params."""
|
||||
batch1 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
idempotency_key="token-A",
|
||||
)
|
||||
|
||||
batch2 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
idempotency_key="token-B",
|
||||
)
|
||||
|
||||
assert batch1.id != batch2.id
|
||||
|
||||
async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data):
|
||||
"""Test that batches without idempotency key create unique batches even with identical parameters."""
|
||||
batch1 = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
batch2 = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
assert batch1.id != batch2.id
|
||||
assert batch1.input_file_id == batch2.input_file_id
|
||||
assert batch1.endpoint == batch2.endpoint
|
||||
assert batch1.completion_window == batch2.completion_window
|
||||
assert batch1.metadata == batch2.metadata
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,first_value,second_value",
|
||||
[
|
||||
("input_file_id", "file_001", "file_002"),
|
||||
("metadata", {"test": "value1"}, {"test": "value2"}),
|
||||
("metadata", None, {}),
|
||||
],
|
||||
)
|
||||
async def test_same_idempotency_key_different_params_conflict(
|
||||
self, provider, sample_batch_data, param_name, first_value, second_value
|
||||
):
|
||||
"""Test that same idempotency_key with different parameters raises conflict error."""
|
||||
sample_batch_data["idempotency_key"] = "same-token"
|
||||
|
||||
sample_batch_data[param_name] = first_value
|
||||
|
||||
batch1 = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"):
|
||||
sample_batch_data[param_name] = second_value
|
||||
await provider.create_batch(**sample_batch_data)
|
||||
|
||||
retrieved_batch = await provider.retrieve_batch(batch1.id)
|
||||
assert retrieved_batch.id == batch1.id
|
||||
assert getattr(retrieved_batch, param_name) == first_value
|
||||
251
tests/unit/providers/files/test_s3_files.py
Normal file
251
tests/unit/providers/files/test_s3_files.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
from moto import mock_aws
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.providers.remote.files.s3 import (
|
||||
S3FilesImplConfig,
|
||||
get_adapter_impl,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
|
||||
self.content = content
|
||||
self.filename = filename
|
||||
self.content_type = content_type
|
||||
|
||||
async def read(self):
|
||||
return self.content
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_config(tmp_path):
|
||||
db_path = tmp_path / "s3_files_metadata.db"
|
||||
|
||||
return S3FilesImplConfig(
|
||||
bucket_name="test-bucket",
|
||||
region="not-a-region",
|
||||
auto_create_bucket=True,
|
||||
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_client():
|
||||
"""Create a mocked S3 client for testing."""
|
||||
# we use `with mock_aws()` because @mock_aws decorator does not support being a generator
|
||||
with mock_aws():
|
||||
# must yield or the mock will be reset before it is used
|
||||
yield boto3.client("s3")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def s3_provider(s3_config, s3_client):
|
||||
"""Create an S3 files provider with mocked S3 for testing."""
|
||||
provider = await get_adapter_impl(s3_config, {})
|
||||
yield provider
|
||||
await provider.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_text_file():
|
||||
content = b"Hello, this is a test file for the S3 Files API!"
|
||||
return MockUploadFile(content, "sample_text_file.txt")
|
||||
|
||||
|
||||
class TestS3FilesImpl:
|
||||
"""Test suite for S3 Files implementation."""
|
||||
|
||||
async def test_upload_file(self, s3_provider, sample_text_file, s3_client, s3_config):
|
||||
"""Test successful file upload."""
|
||||
sample_text_file.filename = "test_upload_file"
|
||||
result = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
assert result.filename == sample_text_file.filename
|
||||
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
assert result.bytes == len(sample_text_file.content)
|
||||
assert result.id.startswith("file-")
|
||||
|
||||
# Verify file exists in S3 backend
|
||||
response = s3_client.head_object(Bucket=s3_config.bucket_name, Key=result.id)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
async def test_list_files_empty(self, s3_provider):
|
||||
"""Test listing files when no files exist."""
|
||||
result = await s3_provider.openai_list_files()
|
||||
|
||||
assert len(result.data) == 0
|
||||
assert not result.has_more
|
||||
assert result.first_id == ""
|
||||
assert result.last_id == ""
|
||||
|
||||
async def test_retrieve_file(self, s3_provider, sample_text_file):
|
||||
"""Test retrieving file metadata."""
|
||||
sample_text_file.filename = "test_retrieve_file"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
retrieved = await s3_provider.openai_retrieve_file(uploaded.id)
|
||||
|
||||
assert retrieved.id == uploaded.id
|
||||
assert retrieved.filename == uploaded.filename
|
||||
assert retrieved.purpose == uploaded.purpose
|
||||
assert retrieved.bytes == uploaded.bytes
|
||||
|
||||
async def test_retrieve_file_content(self, s3_provider, sample_text_file):
|
||||
"""Test retrieving file content."""
|
||||
sample_text_file.filename = "test_retrieve_file_content"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
response = await s3_provider.openai_retrieve_file_content(uploaded.id)
|
||||
|
||||
assert response.body == sample_text_file.content
|
||||
assert response.headers["Content-Disposition"] == f'attachment; filename="{sample_text_file.filename}"'
|
||||
|
||||
async def test_delete_file(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test deleting a file."""
|
||||
sample_text_file.filename = "test_delete_file"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
delete_response = await s3_provider.openai_delete_file(uploaded.id)
|
||||
|
||||
assert delete_response.id == uploaded.id
|
||||
assert delete_response.deleted is True
|
||||
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file(uploaded.id)
|
||||
|
||||
# Verify file is gone from S3 backend
|
||||
with pytest.raises(ClientError) as exc_info:
|
||||
s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
|
||||
assert exc_info.value.response["Error"]["Code"] == "404"
|
||||
|
||||
async def test_list_files(self, s3_provider, sample_text_file):
|
||||
"""Test listing files after uploading some."""
|
||||
sample_text_file.filename = "test_list_files_with_content_file1"
|
||||
file1 = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
|
||||
file2 = await s3_provider.openai_upload_file(
|
||||
file=file2_content,
|
||||
purpose=OpenAIFilePurpose.BATCH,
|
||||
)
|
||||
|
||||
result = await s3_provider.openai_list_files()
|
||||
|
||||
assert len(result.data) == 2
|
||||
file_ids = {f.id for f in result.data}
|
||||
assert file1.id in file_ids
|
||||
assert file2.id in file_ids
|
||||
|
||||
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
|
||||
"""Test listing files with purpose filter."""
|
||||
sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
|
||||
file1 = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
|
||||
await s3_provider.openai_upload_file(
|
||||
file=file2_content,
|
||||
purpose=OpenAIFilePurpose.BATCH,
|
||||
)
|
||||
|
||||
result = await s3_provider.openai_list_files(purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == file1.id
|
||||
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
|
||||
async def test_nonexistent_file_retrieval(self, s3_provider):
|
||||
"""Test retrieving a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file("file-nonexistent")
|
||||
|
||||
async def test_nonexistent_file_content_retrieval(self, s3_provider):
|
||||
"""Test retrieving content of a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file_content("file-nonexistent")
|
||||
|
||||
async def test_nonexistent_file_deletion(self, s3_provider):
|
||||
"""Test deleting a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_delete_file("file-nonexistent")
|
||||
|
||||
async def test_upload_file_without_filename(self, s3_provider, sample_text_file):
|
||||
"""Test uploading a file without a filename uses the fallback."""
|
||||
del sample_text_file.filename
|
||||
result = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
assert result.bytes == len(sample_text_file.content)
|
||||
|
||||
retrieved = await s3_provider.openai_retrieve_file(result.id)
|
||||
assert retrieved.filename == result.filename
|
||||
|
||||
async def test_file_operations_when_s3_object_deleted(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test file operations when S3 object is deleted but metadata exists (negative test)."""
|
||||
sample_text_file.filename = "test_orphaned_metadata"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
# Directly delete the S3 object from the backend
|
||||
s3_client.delete_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
|
||||
|
||||
with pytest.raises(ResourceNotFoundError, match="not found") as exc_info:
|
||||
await s3_provider.openai_retrieve_file_content(uploaded.id)
|
||||
assert uploaded.id in str(exc_info).lower()
|
||||
|
||||
listed_files = await s3_provider.openai_list_files()
|
||||
assert uploaded.id not in [file.id for file in listed_files.data]
|
||||
|
||||
async def test_upload_file_s3_put_object_failure(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test that put_object failure results in exception and no orphaned metadata."""
|
||||
sample_text_file.filename = "test_s3_put_object_failure"
|
||||
|
||||
def failing_put_object(*args, **kwargs):
|
||||
raise ClientError(
|
||||
error_response={"Error": {"Code": "SolarRadiation", "Message": "Bloop"}}, operation_name="PutObject"
|
||||
)
|
||||
|
||||
with patch.object(s3_provider.client, "put_object", side_effect=failing_put_object):
|
||||
with pytest.raises(RuntimeError, match="Failed to upload file to S3"):
|
||||
await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
files_list = await s3_provider.openai_list_files()
|
||||
assert len(files_list.data) == 0, "No file metadata should remain after failed upload"
|
||||
105
tests/unit/server/test_cors.py
Normal file
105
tests/unit/server/test_cors.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.datatypes import CORSConfig, process_cors_config
|
||||
|
||||
|
||||
def test_cors_config_defaults():
|
||||
config = CORSConfig()
|
||||
|
||||
assert config.allow_origins == []
|
||||
assert config.allow_origin_regex is None
|
||||
assert config.allow_methods == ["OPTIONS"]
|
||||
assert config.allow_headers == []
|
||||
assert config.allow_credentials is False
|
||||
assert config.expose_headers == []
|
||||
assert config.max_age == 600
|
||||
|
||||
|
||||
def test_cors_config_explicit_config():
|
||||
config = CORSConfig(
|
||||
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
|
||||
)
|
||||
|
||||
assert config.allow_origins == ["https://example.com"]
|
||||
assert config.allow_credentials is True
|
||||
assert config.max_age == 3600
|
||||
assert config.allow_methods == ["GET", "POST"]
|
||||
|
||||
|
||||
def test_cors_config_regex():
|
||||
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
|
||||
|
||||
assert config.allow_origins == []
|
||||
assert config.allow_origin_regex == r"https?://localhost:\d+"
|
||||
|
||||
|
||||
def test_cors_config_wildcard_credentials_error():
|
||||
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
|
||||
CORSConfig(allow_origins=["*"], allow_credentials=True)
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
|
||||
CORSConfig(allow_origins=["https://example.com", "*"], allow_credentials=True)
|
||||
|
||||
|
||||
def test_process_cors_config_false():
|
||||
result = process_cors_config(False)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_process_cors_config_true():
|
||||
result = process_cors_config(True)
|
||||
|
||||
assert isinstance(result, CORSConfig)
|
||||
assert result.allow_origins == []
|
||||
assert result.allow_origin_regex == r"https?://localhost:\d+"
|
||||
assert result.allow_credentials is False
|
||||
expected_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
for method in expected_methods:
|
||||
assert method in result.allow_methods
|
||||
|
||||
|
||||
def test_process_cors_config_passthrough():
|
||||
original = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"])
|
||||
result = process_cors_config(original)
|
||||
|
||||
assert result is original
|
||||
|
||||
|
||||
def test_process_cors_config_invalid_type():
|
||||
with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"):
|
||||
process_cors_config("invalid")
|
||||
|
||||
|
||||
def test_cors_config_model_dump():
|
||||
cors_config = CORSConfig(
|
||||
allow_origins=["https://example.com"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Content-Type"],
|
||||
allow_credentials=True,
|
||||
max_age=3600,
|
||||
)
|
||||
|
||||
config_dict = cors_config.model_dump()
|
||||
|
||||
assert config_dict["allow_origins"] == ["https://example.com"]
|
||||
assert config_dict["allow_methods"] == ["GET", "POST"]
|
||||
assert config_dict["allow_headers"] == ["Content-Type"]
|
||||
assert config_dict["allow_credentials"] is True
|
||||
assert config_dict["max_age"] == 3600
|
||||
|
||||
expected_keys = {
|
||||
"allow_origins",
|
||||
"allow_origin_regex",
|
||||
"allow_methods",
|
||||
"allow_headers",
|
||||
"allow_credentials",
|
||||
"expose_headers",
|
||||
"max_age",
|
||||
}
|
||||
assert set(config_dict.keys()) == expected_keys
|
||||
|
|
@ -88,3 +88,10 @@ def test_nested_structures(setup_env_vars):
|
|||
}
|
||||
expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
|
||||
assert replace_env_vars(data) == expected
|
||||
|
||||
|
||||
def test_explicit_strings_preserved(setup_env_vars):
|
||||
# Explicit strings that look like numbers/booleans should remain strings
|
||||
data = {"port": "8080", "enabled": "true", "count": "123", "ratio": "3.14"}
|
||||
expected = {"port": "8080", "enabled": "true", "count": "123", "ratio": "3.14"}
|
||||
assert replace_env_vars(data) == expected
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue