Merge branch 'main' into remove-deprecated-embeddings

This commit is contained in:
Matthew Farrellee 2025-09-27 15:01:32 -04:00
commit 5c44dcdf0e
770 changed files with 176834 additions and 27431 deletions

View file

@ -16,9 +16,11 @@ from llama_stack.apis.agents import (
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import Inference
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.tools import ListToolsResponse, Tool, ToolGroups, ToolParameter, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@ -75,11 +77,11 @@ def sample_agent_config():
},
input_shields=["string"],
output_shields=["string"],
toolgroups=["string"],
toolgroups=["mcp::my_mcp_server"],
client_tools=[
{
"name": "string",
"description": "string",
"name": "client_tool",
"description": "Client Tool",
"parameters": [
{
"name": "string",
@ -226,3 +228,83 @@ async def test_delete_agent(agents_impl, sample_agent_config):
# Verify the agent was deleted
with pytest.raises(ValueError):
await agents_impl.get_agent(agent_id)
async def test__initialize_tools(agents_impl, sample_agent_config):
# Mock tool_groups_api.list_tools()
agents_impl.tool_groups_api.list_tools.return_value = ListToolsResponse(
data=[
Tool(
identifier="story_maker",
provider_id="model-context-protocol",
type=ResourceType.tool,
toolgroup_id="mcp::my_mcp_server",
description="Make a story",
parameters=[
ToolParameter(
name="story_title",
parameter_type="string",
description="Title of the story",
required=True,
title="Story Title",
),
ToolParameter(
name="input_words",
parameter_type="array",
description="Input words",
required=False,
items={"type": "string"},
title="Input Words",
default=[],
),
],
)
]
)
create_response = await agents_impl.create_agent(sample_agent_config)
agent_id = create_response.agent_id
# Get an instance of ChatAgent
chat_agent = await agents_impl._get_agent_impl(agent_id)
assert chat_agent is not None
assert isinstance(chat_agent, ChatAgent)
# Initialize tool definitions
await chat_agent._initialize_tools()
assert len(chat_agent.tool_defs) == 2
# Verify the first tool, which is a client tool
first_tool = chat_agent.tool_defs[0]
assert first_tool.tool_name == "client_tool"
assert first_tool.description == "Client Tool"
# Verify the second tool, which is an MCP tool that has an array-type property
second_tool = chat_agent.tool_defs[1]
assert second_tool.tool_name == "story_maker"
assert second_tool.description == "Make a story"
parameters = second_tool.parameters
assert len(parameters) == 2
# Verify a string property
story_title = parameters.get("story_title")
assert story_title is not None
assert story_title.param_type == "string"
assert story_title.description == "Title of the story"
assert story_title.required
assert story_title.items is None
assert story_title.title == "Story Title"
assert story_title.default is None
# Verify an array property
input_words = parameters.get("input_words")
assert input_words is not None
assert input_words.param_type == "array"
assert input_words.description == "Input words"
assert not input_words.required
assert input_words.items is not None
assert len(input_words.items) == 1
assert input_words.items.get("type") == "string"
assert input_words.title == "Input Words"
assert input_words.default == []

View file

@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
* test_validate_input_url_mismatch (negative)
* test_validate_input_multiple_errors_per_request (negative)
* test_validate_input_invalid_request_format (negative)
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
* test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
The tests use temporary SQLite databases for isolation and mock external
@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
"endpoint",
[
"/v1/embeddings",
"/v1/completions",
"/v1/invalid/endpoint",
"",
],
@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
],
)
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
"""Test _validate_input when file contains request with missing required parameters."""
async def test_validate_input_missing_parameters_chat_completions(
self, provider, param_name, param_path, error_code, error_message
):
"""Test _validate_input when file contains request with missing required parameters for chat completions."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
assert errors[0].message == error_message
assert errors[0].param == param_path
@pytest.mark.parametrize(
"param_name,param_path,error_code,error_message",
[
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
("model", "body.model", "invalid_request", "Model parameter is required"),
("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
],
)
async def test_validate_input_missing_parameters_completions(
self, provider, param_name, param_path, error_code, error_message
):
"""Test _validate_input when file contains request with missing required parameters for text completions."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/completions",
"body": {"model": "test-model", "prompt": "Hello"},
}
# Remove the specific parameter being tested
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
del base_request[top_level][nested_param]
else:
del base_request[param_name]
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/completions",
input_file_id=f"missing_{param_name}_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == error_code
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_validate_input_url_mismatch(self, provider):
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
provider.files_api.openai_retrieve_file = AsyncMock()

View file

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from unittest.mock import patch
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
class TestBedrockBaseConfig:
def test_defaults_work_without_env_vars(self):
with patch.dict(os.environ, {}, clear=True):
config = BedrockBaseConfig()
# Basic creds should be None
assert config.aws_access_key_id is None
assert config.aws_secret_access_key is None
assert config.region_name is None
# Timeouts get defaults
assert config.connect_timeout == 60.0
assert config.read_timeout == 60.0
assert config.session_ttl == 3600
def test_env_vars_get_picked_up(self):
env_vars = {
"AWS_ACCESS_KEY_ID": "AKIATEST123",
"AWS_SECRET_ACCESS_KEY": "secret123",
"AWS_DEFAULT_REGION": "us-west-2",
"AWS_MAX_ATTEMPTS": "5",
"AWS_RETRY_MODE": "adaptive",
"AWS_CONNECT_TIMEOUT": "30",
}
with patch.dict(os.environ, env_vars, clear=True):
config = BedrockBaseConfig()
assert config.aws_access_key_id == "AKIATEST123"
assert config.aws_secret_access_key == "secret123"
assert config.region_name == "us-west-2"
assert config.total_max_attempts == 5
assert config.retry_mode == "adaptive"
assert config.connect_timeout == 30.0
def test_partial_env_setup(self):
# Just setting one timeout var
with patch.dict(os.environ, {"AWS_CONNECT_TIMEOUT": "120"}, clear=True):
config = BedrockBaseConfig()
assert config.connect_timeout == 120.0
assert config.read_timeout == 60.0 # still default
assert config.aws_access_key_id is None
def test_bad_max_attempts_breaks(self):
with patch.dict(os.environ, {"AWS_MAX_ATTEMPTS": "not_a_number"}, clear=True):
try:
BedrockBaseConfig()
raise AssertionError("Should have failed on bad int conversion")
except ValueError:
pass # expected

View file

@ -33,8 +33,7 @@ def test_groq_provider_openai_client_caching():
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
assert inference_adapter.client.api_key == api_key
def test_openai_provider_openai_client_caching():

View file

@ -26,7 +26,6 @@ class TestProviderDataValidator(BaseModel):
class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: TestConfig):
super().__init__(
model_entries=[],
litellm_provider_name="test",
api_key_from_config=config.api_key,
provider_data_api_key_field="test_api_key",

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import os
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch
from llama_stack.core.stack import replace_env_vars
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
@ -80,11 +80,22 @@ class TestOpenAIBaseURLConfig:
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")
# Mock the AsyncOpenAI client and its models.retrieve method
# Mock a model object that will be returned by models.list()
mock_model = MagicMock()
mock_model.id = "gpt-4"
# Create an async iterator that yields our mock model
async def mock_async_iterator():
yield mock_model
# Mock the AsyncOpenAI client and its models.list method
mock_client = MagicMock()
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
mock_client.models.list = MagicMock(return_value=mock_async_iterator())
mock_openai_class.return_value = mock_client
# Set the __provider_id__ attribute that's expected by list_models
adapter.__provider_id__ = "openai"
# Call check_model_availability and verify it returns True
assert await adapter.check_model_availability("gpt-4")
@ -94,8 +105,8 @@ class TestOpenAIBaseURLConfig:
base_url=custom_url,
)
# Verify the method was called and returned True
mock_client.models.retrieve.assert_called_once_with("gpt-4")
# Verify the models.list method was called
mock_client.models.list.assert_called_once()
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
@ -110,11 +121,22 @@ class TestOpenAIBaseURLConfig:
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")
# Mock the AsyncOpenAI client
# Mock a model object that will be returned by models.list()
mock_model = MagicMock()
mock_model.id = "gpt-4"
# Create an async iterator that yields our mock model
async def mock_async_iterator():
yield mock_model
# Mock the AsyncOpenAI client and its models.list method
mock_client = MagicMock()
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
mock_client.models.list = MagicMock(return_value=mock_async_iterator())
mock_openai_class.return_value = mock_client
# Set the __provider_id__ attribute that's expected by list_models
adapter.__provider_id__ = "openai"
# Call check_model_availability and verify it returns True
assert await adapter.check_model_availability("gpt-4")

View file

@ -6,19 +6,15 @@
import asyncio
import json
import logging # allow-direct-logging
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChoice,
Choice as OpenAIChoiceChunk,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
@ -35,6 +31,9 @@ from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
CompletionMessage,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
SystemMessage,
ToolChoice,
ToolConfig,
@ -61,52 +60,21 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
# -v -s --tb=short --disable-warnings
class MockInferenceAdapterWithSleep:
def __init__(self, sleep_time: int, response: dict[str, Any]):
self.httpd = None
class DelayedRequestHandler(BaseHTTPRequestHandler):
# ruff: noqa: N802
def do_POST(self):
time.sleep(sleep_time)
response_body = json.dumps(response).encode("utf-8")
self.send_response(code=200)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", len(response_body))
self.end_headers()
self.wfile.write(response_body)
self.request_handler = DelayedRequestHandler
def __enter__(self):
httpd = HTTPServer(("", 0), self.request_handler)
self.httpd = httpd
host, port = httpd.server_address
httpd_thread = threading.Thread(target=httpd.serve_forever)
httpd_thread.daemon = True # stop server if this thread terminates
httpd_thread.start()
config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
inference_adapter = VLLMInferenceAdapter(config)
return inference_adapter
def __exit__(self, _exc_type, _exc_value, _traceback):
if self.httpd:
self.httpd.shutdown()
self.httpd.server_close()
@pytest.fixture(scope="module")
def mock_openai_models_list():
with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
with patch("openai.resources.models.AsyncModels.list") as mock_list:
yield mock_list
@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config)
inference_adapter.model_store = AsyncMock()
# Mock the __provider_spec__ attribute that would normally be set by the resolver
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_type = "vllm-inference"
inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
await inference_adapter.initialize()
return inference_adapter
@ -150,10 +118,16 @@ async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format."""
# Patch the call to vllm so we can inspect the arguments sent were correct
with patch.object(
vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
) as mock_nonstream_completion:
# Patch the client property to avoid instantiating a real AsyncOpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock()
mock_create_client.return_value = mock_client
# Mock the model to return a proper provider_resource_id
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
messages = [
SystemMessage(content="You are a helpful assistant"),
UserMessage(content="How many?"),
@ -179,7 +153,7 @@ async def test_tool_call_response(vllm_inference_adapter):
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
assert mock_client.chat.completions.create.call_args.kwargs["messages"][2]["tool_calls"] == [
{
"id": "foo",
"type": "function",
@ -199,7 +173,7 @@ async def test_tool_call_delta_empty_tool_call_buf():
async def mock_stream():
delta = OpenAIChoiceDelta(content="", tool_calls=None)
choices = [OpenAIChoice(delta=delta, finish_reason="stop", index=0)]
choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
@ -225,7 +199,7 @@ async def test_tool_call_delta_streaming_arguments_dict():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
@ -250,7 +224,7 @@ async def test_tool_call_delta_streaming_arguments_dict():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
@ -275,7 +249,9 @@ async def test_tool_call_delta_streaming_arguments_dict():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@ -299,7 +275,7 @@ async def test_multiple_tool_calls():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
@ -324,7 +300,7 @@ async def test_multiple_tool_calls():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
@ -349,7 +325,9 @@ async def test_multiple_tool_calls():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@ -393,59 +371,6 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
assert chunks[0].event.event_type.value == "start"
@pytest.mark.allow_network
def test_chat_completion_doesnt_block_event_loop(caplog):
loop = asyncio.new_event_loop()
loop.set_debug(True)
caplog.set_level(logging.WARNING)
# Log when event loop is blocked for more than 200ms
loop.slow_callback_duration = 0.5
# Sleep for 500ms in our delayed http response
sleep_time = 0.5
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
mock_response = {
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1,
"modle": "mock-model",
"choices": [
{
"message": {"content": ""},
"logprobs": None,
"finish_reason": "stop",
"index": 0,
}
],
}
async def do_chat_completion():
await inference_adapter.chat_completion(
"mock-model",
[],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
inference_adapter.model_store = AsyncMock()
inference_adapter.model_store.get_model.return_value = mock_model
loop.run_until_complete(inference_adapter.initialize())
# Clear the logs so far and run the actual chat completion we care about
caplog.clear()
loop.run_until_complete(do_chat_completion())
# Ensure we don't have any asyncio warnings in the captured log
# records from our chat completion call. A message gets logged
# here any time we exceed the slow_callback_duration configured
# above.
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
assert not asyncio_warnings
async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest(
tools=[],
@ -638,33 +563,29 @@ async def test_health_status_success(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when the connection is successful.
This test verifies that the health method returns a HealthResponse with status OK, only
when the connection to the vLLM server is successful.
This test verifies that the health method returns a HealthResponse with status OK
when the /health endpoint responds successfully.
"""
# Set vllm_inference_adapter.client to None to ensure _create_client is called
vllm_inference_adapter.client = None
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
# Create mock client and models
mock_client = MagicMock()
mock_models = MagicMock()
with patch("httpx.AsyncClient") as mock_client_class:
# Create mock response
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
# Create a mock async iterator that yields a model when iterated
async def mock_list():
for model in [MagicMock()]:
yield model
# Set up the models.list to return our mock async iterator
mock_models.list.return_value = mock_list()
mock_client.models = mock_models
mock_create_client.return_value = mock_client
# Create mock client instance
mock_client_instance = MagicMock()
mock_client_instance.get = AsyncMock(return_value=mock_response)
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
# Call the health method
health_response = await vllm_inference_adapter.health()
# Verify the response
assert health_response["status"] == HealthStatus.OK
# Verify that models.list was called
mock_models.list.assert_called_once()
# Verify that the health endpoint was called
mock_client_instance.get.assert_called_once()
call_args = mock_client_instance.get.call_args[0]
assert call_args[0].endswith("/health")
async def test_health_status_failure(vllm_inference_adapter):
@ -674,26 +595,190 @@ async def test_health_status_failure(vllm_inference_adapter):
This test verifies that the health method returns a HealthResponse with status ERROR
and an appropriate error message when the connection to the vLLM server fails.
"""
vllm_inference_adapter.client = None
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
# Create mock client and models
mock_client = MagicMock()
mock_models = MagicMock()
# Create a mock async iterator that raises an exception when iterated
async def mock_list():
raise Exception("Connection failed")
yield # Unreachable code
# Set up the models.list to return our mock async iterator
mock_models.list.return_value = mock_list()
mock_client.models = mock_models
mock_create_client.return_value = mock_client
with patch("httpx.AsyncClient") as mock_client_class:
# Create mock client instance that raises an exception
mock_client_instance = MagicMock()
mock_client_instance.get.side_effect = Exception("Connection failed")
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
# Call the health method
health_response = await vllm_inference_adapter.health()
# Verify the response
assert health_response["status"] == HealthStatus.ERROR
assert "Health check failed: Connection failed" in health_response["message"]
mock_models.list.assert_called_once()
async def test_health_status_no_static_api_key(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when no static API key is provided.
This test verifies that the health method returns a HealthResponse with status OK
when the /health endpoint responds successfully, regardless of API token configuration.
"""
with patch("httpx.AsyncClient") as mock_client_class:
# Create mock response
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
# Create mock client instance
mock_client_instance = MagicMock()
mock_client_instance.get = AsyncMock(return_value=mock_response)
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
# Call the health method
health_response = await vllm_inference_adapter.health()
# Verify the response
assert health_response["status"] == HealthStatus.OK
async def test_openai_chat_completion_is_async(vllm_inference_adapter):
"""
Verify that openai_chat_completion is async and doesn't block the event loop.
To do this we mock the underlying inference with a sleep, start multiple
inference calls in parallel, and ensure the total time taken is less
than the sum of the individual sleep times.
"""
sleep_time = 0.5
async def mock_create(*args, **kwargs):
await asyncio.sleep(sleep_time)
return OpenAIChatCompletion(
id="chatcmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAIChoice(
message=OpenAIAssistantMessageParam(
content="nothing interesting",
),
finish_reason="stop",
index=0,
)
],
)
async def do_inference():
await vllm_inference_adapter.openai_chat_completion(
"mock-model", messages=["one fish", "two fish"], stream=False
)
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(side_effect=mock_create)
mock_create_client.return_value = mock_client
start_time = time.time()
await asyncio.gather(do_inference(), do_inference(), do_inference(), do_inference())
total_time = time.time() - start_time
assert mock_create_client.call_count == 4 # no cheating
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
async def test_should_refresh_models():
"""
Test the should_refresh_models method with different refresh_models configurations.
This test verifies that:
1. When refresh_models is True, should_refresh_models returns True regardless of api_token
2. When refresh_models is False, should_refresh_models returns False regardless of api_token
"""
# Test case 1: refresh_models is True, api_token is None
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
adapter1 = VLLMInferenceAdapter(config1)
result1 = await adapter1.should_refresh_models()
assert result1 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 2: refresh_models is True, api_token is empty string
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
adapter2 = VLLMInferenceAdapter(config2)
result2 = await adapter2.should_refresh_models()
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 3: refresh_models is True, api_token is "fake" (default)
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
adapter3 = VLLMInferenceAdapter(config3)
result3 = await adapter3.should_refresh_models()
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 4: refresh_models is True, api_token is real token
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
adapter4 = VLLMInferenceAdapter(config4)
result4 = await adapter4.should_refresh_models()
assert result4 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 5: refresh_models is False, api_token is real token
config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
adapter5 = VLLMInferenceAdapter(config5)
result5 = await adapter5.should_refresh_models()
assert result5 is False, "should_refresh_models should return False when refresh_models is False"
async def test_provider_data_var_context_propagation(vllm_inference_adapter):
"""
Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
This ensures that dynamic provider data (like API tokens) can be passed through context.
Note: The base URL is always taken from config.url, not from provider data.
"""
# Mock the AsyncOpenAI class to capture provider data
with (
patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
):
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock()
mock_openai_class.return_value = mock_client
# Mock provider data to return test data
mock_provider_data = MagicMock()
mock_provider_data.vllm_api_token = "test-token-123"
mock_provider_data.vllm_url = "http://test-server:8000/v1"
mock_get_provider_data.return_value = mock_provider_data
# Mock the model
mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
try:
# Execute chat completion
await vllm_inference_adapter.chat_completion(
"test-model",
[UserMessage(content="Hello")],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
# Verify that ALL client calls were made with the correct parameters
calls = mock_openai_class.call_args_list
incorrect_calls = []
for i, call in enumerate(calls):
api_key = call[1]["api_key"]
base_url = call[1]["base_url"]
if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
if incorrect_calls:
error_msg = (
f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
)
for incorrect_call in incorrect_calls:
error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
raise AssertionError(error_msg)
# Ensure at least one call was made
assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
# Verify that chat completion was called
mock_client.chat.completions.create.assert_called_once()
finally:
# Clean up context
pass

View file

@ -52,14 +52,19 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
self.evaluator_post_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
)
self.evaluator_delete_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_delete"
)
self.mock_evaluator_get = self.evaluator_get_patcher.start()
self.mock_evaluator_post = self.evaluator_post_patcher.start()
self.mock_evaluator_delete = self.evaluator_delete_patcher.start()
def tearDown(self):
"""Clean up after each test."""
self.evaluator_get_patcher.stop()
self.evaluator_post_patcher.stop()
self.evaluator_delete_patcher.stop()
def _assert_request_body(self, expected_json):
"""Helper method to verify request body in Evaluator POST request is correct"""
@ -115,6 +120,13 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
self.mock_evaluator_post.assert_called_once()
self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
def test_unregister_benchmark(self):
# Unregister the benchmark
self.run_async(self.eval_impl.unregister_benchmark(benchmark_id=MOCK_BENCHMARK_ID))
# Verify the Evaluator API was called correctly
self.mock_evaluator_delete.assert_called_once_with(f"/v1/evaluation/configs/nvidia/{MOCK_BENCHMARK_ID}")
def test_run_eval(self):
benchmark_config = BenchmarkConfig(
eval_candidate=ModelCandidate(
@ -138,7 +150,7 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
self._assert_request_body(
{
"config": f"nvidia/{MOCK_BENCHMARK_ID}",
"target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
"target": {"type": "model", "model": "Llama3.1-8B-Instruct"},
}
)

View file

@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.bedrock.bedrock import (
_get_region_prefix,
_to_inference_profile_id,
)
def test_region_prefixes():
assert _get_region_prefix("us-east-1") == "us."
assert _get_region_prefix("eu-west-1") == "eu."
assert _get_region_prefix("ap-south-1") == "ap."
assert _get_region_prefix("ca-central-1") == "us."
# Test case insensitive
assert _get_region_prefix("US-EAST-1") == "us."
assert _get_region_prefix("EU-WEST-1") == "eu."
assert _get_region_prefix("Ap-South-1") == "ap."
# Test None region
assert _get_region_prefix(None) == "us."
def test_model_id_conversion():
# Basic conversion
assert (
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0"
)
# Already has prefix
assert (
_to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1")
== "us.meta.llama3-1-70b-instruct-v1:0"
)
# ARN should be returned unchanged
arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0"
assert _to_inference_profile_id(arn, "us-east-1") == arn
# ARN should be returned unchanged even without region
assert _to_inference_profile_id(arn) == arn
# Optional region parameter defaults to us-east-1
assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0") == "us.meta.llama3-1-70b-instruct-v1:0"
# Different regions work with optional parameter
assert (
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0"
)

View file

@ -0,0 +1,368 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
class OpenAIMixinImpl(OpenAIMixin):
def __init__(self):
self.__provider_id__ = "test-provider"
def get_api_key(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
def get_base_url(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
"""Test implementation with embedding model metadata"""
embedding_model_metadata = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
__provider_id__ = "test-provider"
def get_api_key(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
def get_base_url(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
mixin_instance = OpenAIMixinImpl()
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
mock_model = MagicMock()
mock_model.provider_resource_id = "test-provider-resource-id"
mock_model_store.get_model = AsyncMock(return_value=mock_model)
mixin_instance.model_store = mock_model_store
return mixin_instance
@pytest.fixture
def mixin_with_embeddings():
"""Create a test instance of OpenAIMixin with embedding model metadata"""
return OpenAIMixinWithEmbeddingsImpl()
@pytest.fixture
def mock_models():
"""Create multiple mock OpenAI model objects"""
models = [MagicMock(id=id) for id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]]
return models
@pytest.fixture
def mock_client_with_models(mock_models):
"""Create a mock client with models.list() set up to return mock_models"""
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
return mock_client
@pytest.fixture
def mock_client_with_empty_models():
"""Create a mock client with models.list() set up to return empty list"""
mock_client = MagicMock()
async def mock_empty_models_list():
return
yield # Make it an async generator but don't yield anything
mock_client.models.list.return_value = mock_empty_models_list()
return mock_client
@pytest.fixture
def mock_client_with_exception():
"""Create a mock client with models.list() set up to raise an exception"""
mock_client = MagicMock()
mock_client.models.list.side_effect = Exception("API Error")
return mock_client
@pytest.fixture
def mock_client_context():
"""Fixture that provides a context manager for mocking the OpenAI client"""
def _mock_client_context(mixin, mock_client):
return patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client)
return _mock_client_context
class TestOpenAIMixinListModels:
"""Test cases for the list_models method"""
async def test_list_models_success(self, mixin, mock_client_with_models, mock_client_context):
"""Test successful model listing"""
assert len(mixin._model_cache) == 0
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.list_models()
assert result is not None
assert len(result) == 3
model_ids = [model.identifier for model in result]
assert "some-mock-model-id" in model_ids
assert "another-mock-model-id" in model_ids
assert "final-mock-model-id" in model_ids
for model in result:
assert model.provider_id == "test-provider"
assert model.model_type == ModelType.llm
assert model.provider_resource_id == model.identifier
assert len(mixin._model_cache) == 3
for model_id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]:
assert model_id in mixin._model_cache
cached_model = mixin._model_cache[model_id]
assert cached_model.identifier == model_id
assert cached_model.provider_resource_id == model_id
async def test_list_models_empty_response(self, mixin, mock_client_with_empty_models, mock_client_context):
"""Test handling of empty model list"""
with mock_client_context(mixin, mock_client_with_empty_models):
result = await mixin.list_models()
assert result is not None
assert len(result) == 0
assert len(mixin._model_cache) == 0
class TestOpenAIMixinCheckModelAvailability:
"""Test cases for the check_model_availability method"""
async def test_check_model_availability_with_cache(self, mixin, mock_client_with_models, mock_client_context):
"""Test model availability check when cache is populated"""
with mock_client_context(mixin, mock_client_with_models):
mock_client_with_models.models.list.assert_not_called()
await mixin.list_models()
mock_client_with_models.models.list.assert_called_once()
assert await mixin.check_model_availability("some-mock-model-id")
assert await mixin.check_model_availability("another-mock-model-id")
assert await mixin.check_model_availability("final-mock-model-id")
assert not await mixin.check_model_availability("non-existent-model")
mock_client_with_models.models.list.assert_called_once()
async def test_check_model_availability_without_cache(self, mixin, mock_client_with_models, mock_client_context):
"""Test model availability check when cache is empty (calls list_models)"""
assert len(mixin._model_cache) == 0
with mock_client_context(mixin, mock_client_with_models):
mock_client_with_models.models.list.assert_not_called()
assert await mixin.check_model_availability("some-mock-model-id")
mock_client_with_models.models.list.assert_called_once()
assert len(mixin._model_cache) == 3
assert "some-mock-model-id" in mixin._model_cache
async def test_check_model_availability_model_not_found(self, mixin, mock_client_with_models, mock_client_context):
"""Test model availability check for non-existent model"""
with mock_client_context(mixin, mock_client_with_models):
mock_client_with_models.models.list.assert_not_called()
assert not await mixin.check_model_availability("non-existent-model")
mock_client_with_models.models.list.assert_called_once()
assert len(mixin._model_cache) == 3
class TestOpenAIMixinCacheBehavior:
"""Test cases for cache behavior and edge cases"""
async def test_cache_overwrites_on_list_models_call(self, mixin, mock_client_with_models, mock_client_context):
"""Test that calling list_models overwrites existing cache"""
initial_model = Model(
provider_id="test-provider",
provider_resource_id="old-model",
identifier="old-model",
model_type=ModelType.llm,
)
mixin._model_cache = {"old-model": initial_model}
with mock_client_context(mixin, mock_client_with_models):
await mixin.list_models()
assert len(mixin._model_cache) == 3
assert "old-model" not in mixin._model_cache
assert "some-mock-model-id" in mixin._model_cache
assert "another-mock-model-id" in mixin._model_cache
assert "final-mock-model-id" in mixin._model_cache
class TestOpenAIMixinImagePreprocessing:
"""Test cases for image preprocessing functionality"""
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
"""Test that image URLs are converted to base64 when download_images is True"""
mixin.download_images = True
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_called_once_with("http://example.com/image.jpg")
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[0]["type"] == "text"
assert content[1]["type"] == "image_url"
assert content[1]["image_url"]["url"] == ""
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
"""Test that image URLs are not modified when download_images is False"""
mixin.download_images = False # explicitly set to False
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_not_called()
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"
class TestOpenAIMixinEmbeddingModelMetadata:
"""Test cases for embedding_model_metadata attribute functionality"""
async def test_embedding_model_identified_and_augmented(self, mixin_with_embeddings, mock_client_context):
"""Test that models in embedding_model_metadata are correctly identified as embeddings with metadata"""
# Create mock models: 1 embedding model and 1 LLM, while there are 2 known embedding models
mock_embedding_model = MagicMock(id="text-embedding-3-small")
mock_llm_model = MagicMock(id="gpt-4")
mock_models = [mock_embedding_model, mock_llm_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
with mock_client_context(mixin_with_embeddings, mock_client):
result = await mixin_with_embeddings.list_models()
assert result is not None
assert len(result) == 2
# Find the models in the result
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
llm_model = next(m for m in result if m.identifier == "gpt-4")
# Check embedding model
assert embedding_model.model_type == ModelType.embedding
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
assert embedding_model.provider_id == "test-provider"
assert embedding_model.provider_resource_id == "text-embedding-3-small"
# Check LLM model
assert llm_model.model_type == ModelType.llm
assert llm_model.metadata == {} # No metadata for LLMs
assert llm_model.provider_id == "test-provider"
assert llm_model.provider_resource_id == "gpt-4"
class TestOpenAIMixinAllowedModels:
"""Test cases for allowed_models filtering functionality"""
async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
"""Test that list_models filters models based on allowed_models set"""
mixin.allowed_models = {"some-mock-model-id", "another-mock-model-id"}
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.list_models()
assert result is not None
assert len(result) == 2
model_ids = [model.identifier for model in result]
assert "some-mock-model-id" in model_ids
assert "another-mock-model-id" in model_ids
assert "final-mock-model-id" not in model_ids
async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
"""Test that empty allowed_models set allows all models"""
assert len(mixin.allowed_models) == 0
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.list_models()
assert result is not None
assert len(result) == 3 # All models should be included
model_ids = [model.identifier for model in result]
assert "some-mock-model-id" in model_ids
assert "another-mock-model-id" in model_ids
assert "final-mock-model-id" in model_ids
async def test_check_model_availability_with_allowed_models(
self, mixin, mock_client_with_models, mock_client_context
):
"""Test that check_model_availability respects allowed_models"""
mixin.allowed_models = {"final-mock-model-id"}
with mock_client_context(mixin, mock_client_with_models):
assert await mixin.check_model_availability("final-mock-model-id")
assert not await mixin.check_model_availability("some-mock-model-id")
assert not await mixin.check_model_availability("another-mock-model-id")

View file

@ -178,3 +178,41 @@ def test_content_from_data_and_mime_type_both_encodings_fail():
# Should raise an exception instead of returning empty string
with pytest.raises(UnicodeDecodeError):
content_from_data_and_mime_type(data, mime_type)
async def test_memory_tool_error_handling():
"""Test that memory tool handles various failures gracefully without crashing."""
from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
config = RagToolRuntimeConfig()
memory_tool = MemoryToolRuntimeImpl(
config=config,
vector_io_api=AsyncMock(),
inference_api=AsyncMock(),
files_api=AsyncMock(),
)
docs = [
RAGDocument(document_id="good_doc", content="Good content", metadata={}),
RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}),
RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}),
]
mock_file1 = MagicMock()
mock_file1.id = "file_good1"
mock_file2 = MagicMock()
mock_file2.id = "file_good2"
memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2]
with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.side_effect = Exception("Bad URL")
mock_client.return_value.__aenter__.return_value = mock_instance
# won't raise exception despite one document failing
await memory_tool.insert(docs, "vector_store_123")
# processed 2 documents successfully, skipped 1
assert memory_tool.files_api.openai_upload_file.call_count == 2
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2

View file

@ -84,14 +84,14 @@ def unknown_model() -> Model:
@pytest.fixture
def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper:
return ModelRegistryHelper([known_provider_model, known_provider_model2])
return ModelRegistryHelper(model_entries=[known_provider_model, known_provider_model2])
class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
"""Test helper that simulates a provider with dynamically available models."""
def __init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]):
super().__init__(model_entries)
super().__init__(model_entries=model_entries)
self._available_models = available_models
async def check_model_availability(self, model: str) -> bool:

View file

@ -54,7 +54,9 @@ def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = 384
mock_vector_db.model_dump_json.return_value = (
'{"identifier": "' + vector_db_id + '", "embedding_model": "embedding_model", "embedding_dimension": 384}'
'{"identifier": "'
+ vector_db_id
+ '", "provider_id": "qdrant", "embedding_model": "embedding_model", "embedding_dimension": 384}'
)
return mock_vector_db

View file

@ -26,9 +26,9 @@ def test_generate_chunk_id():
chunk_ids = sorted([chunk.chunk_id for chunk in chunks])
assert chunk_ids == [
"177a1368-f6a8-0c50-6e92-18677f2c3de3",
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
"f68df25d-d9aa-ab4d-5684-64a233add20d",
"31d1f9a3-c8d2-66e7-3c37-af2acd329778",
"d07dade7-29c0-cda7-df29-0249a1dcbc3e",
"d14f75a1-5855-7f72-2c78-d9fc4275a346",
]
@ -36,14 +36,14 @@ def test_generate_chunk_id_with_window():
chunk = Chunk(content="test", metadata={"document_id": "doc-1"})
chunk_id1 = generate_chunk_id("doc-1", chunk, chunk_window="0-1")
chunk_id2 = generate_chunk_id("doc-1", chunk, chunk_window="1-2")
assert chunk_id1 == "149018fe-d0eb-0f8d-5f7f-726bdd2aeedb"
assert chunk_id2 == "4562c1ee-9971-1f3b-51a6-7d05e5211154"
assert chunk_id1 == "8630321a-d9cb-2bb6-cd28-ebf68dafd866"
assert chunk_id2 == "13a1c09a-cbda-b61a-2d1a-7baa90888685"
def test_chunk_id():
# Test with existing chunk ID
chunk_with_id = Chunk(content="test", metadata={"document_id": "existing-id"})
assert chunk_with_id.chunk_id == "84ededcc-b80b-a83e-1a20-ca6515a11350"
assert chunk_with_id.chunk_id == "11704f92-42b6-61df-bf85-6473e7708fbd"
# Test with document ID in metadata
chunk_with_doc_id = Chunk(content="test", metadata={"document_id": "doc-1"})