mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-25 16:41:59 +00:00
Merge branch 'main' into opengauss-add
This commit is contained in:
commit
39e49ab97a
807 changed files with 79555 additions and 26772 deletions
176
tests/unit/providers/agent/test_get_raw_document_text.py
Normal file
176
tests/unit/providers/agent/test_get_raw_document_text.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import Document
|
||||
from llama_stack.apis.common.content_types import URL, TextContentItem
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text
|
||||
|
||||
|
||||
async def test_get_raw_document_text_supports_text_mime_types():
|
||||
"""Test that the function accepts text/* mime types."""
|
||||
document = Document(content="Sample text content", mime_type="text/plain")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == "Sample text content"
|
||||
|
||||
|
||||
async def test_get_raw_document_text_supports_yaml_mime_type():
|
||||
"""Test that the function accepts application/yaml mime type."""
|
||||
yaml_content = """
|
||||
name: test
|
||||
version: 1.0
|
||||
items:
|
||||
- item1
|
||||
- item2
|
||||
"""
|
||||
|
||||
document = Document(content=yaml_content, mime_type="application/yaml")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == yaml_content
|
||||
|
||||
|
||||
async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning():
|
||||
"""Test that the function accepts text/yaml but emits a deprecation warning."""
|
||||
yaml_content = """
|
||||
name: test
|
||||
version: 1.0
|
||||
items:
|
||||
- item1
|
||||
- item2
|
||||
"""
|
||||
|
||||
document = Document(content=yaml_content, mime_type="text/yaml")
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = await get_raw_document_text(document)
|
||||
|
||||
# Check that result is correct
|
||||
assert result == yaml_content
|
||||
|
||||
# Check that exactly one warning was issued
|
||||
assert len(w) == 1
|
||||
assert issubclass(w[0].category, DeprecationWarning)
|
||||
assert "text/yaml" in str(w[0].message)
|
||||
assert "application/yaml" in str(w[0].message)
|
||||
assert "deprecated" in str(w[0].message).lower()
|
||||
|
||||
|
||||
async def test_get_raw_document_text_deprecated_text_yaml_with_url():
|
||||
"""Test that text/yaml works with URL content and emits warning."""
|
||||
yaml_content = "name: test\nversion: 1.0"
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
|
||||
mock_load.return_value = yaml_content
|
||||
|
||||
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="text/yaml")
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = await get_raw_document_text(document)
|
||||
|
||||
# Check that result is correct
|
||||
assert result == yaml_content
|
||||
mock_load.assert_called_once_with("https://example.com/config.yaml")
|
||||
|
||||
# Check that deprecation warning was issued
|
||||
assert len(w) == 1
|
||||
assert issubclass(w[0].category, DeprecationWarning)
|
||||
assert "text/yaml" in str(w[0].message)
|
||||
|
||||
|
||||
async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item():
|
||||
"""Test that text/yaml works with TextContentItem and emits warning."""
|
||||
yaml_content = "key: value\nlist:\n - item1\n - item2"
|
||||
|
||||
document = Document(content=TextContentItem(text=yaml_content), mime_type="text/yaml")
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = await get_raw_document_text(document)
|
||||
|
||||
# Check that result is correct
|
||||
assert result == yaml_content
|
||||
|
||||
# Check that deprecation warning was issued
|
||||
assert len(w) == 1
|
||||
assert issubclass(w[0].category, DeprecationWarning)
|
||||
assert "text/yaml" in str(w[0].message)
|
||||
|
||||
|
||||
async def test_get_raw_document_text_rejects_unsupported_mime_types():
|
||||
"""Test that the function rejects unsupported mime types."""
|
||||
document = Document(
|
||||
content="Some content",
|
||||
mime_type="application/json", # Not supported
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unexpected document mime type: application/json"):
|
||||
await get_raw_document_text(document)
|
||||
|
||||
|
||||
async def test_get_raw_document_text_with_url_content():
|
||||
"""Test that the function handles URL content correctly."""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.text = "Content from URL"
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
|
||||
mock_load.return_value = "Content from URL"
|
||||
|
||||
document = Document(content=URL(uri="https://example.com/test.txt"), mime_type="text/plain")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == "Content from URL"
|
||||
mock_load.assert_called_once_with("https://example.com/test.txt")
|
||||
|
||||
|
||||
async def test_get_raw_document_text_with_yaml_url():
|
||||
"""Test that the function handles YAML URLs correctly."""
|
||||
yaml_content = "name: test\nversion: 1.0"
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
|
||||
mock_load.return_value = yaml_content
|
||||
|
||||
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="application/yaml")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == yaml_content
|
||||
mock_load.assert_called_once_with("https://example.com/config.yaml")
|
||||
|
||||
|
||||
async def test_get_raw_document_text_with_text_content_item():
|
||||
"""Test that the function handles TextContentItem correctly."""
|
||||
document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == "Text content item"
|
||||
|
||||
|
||||
async def test_get_raw_document_text_with_yaml_text_content_item():
|
||||
"""Test that the function handles YAML TextContentItem correctly."""
|
||||
yaml_content = "key: value\nlist:\n - item1\n - item2"
|
||||
|
||||
document = Document(content=TextContentItem(text=yaml_content), mime_type="application/yaml")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == yaml_content
|
||||
|
||||
|
||||
async def test_get_raw_document_text_rejects_unexpected_content_type():
|
||||
"""Test that the function rejects unexpected document content types."""
|
||||
# Create a mock document that bypasses Pydantic validation
|
||||
mock_document = MagicMock(spec=Document)
|
||||
mock_document.mime_type = "text/plain"
|
||||
mock_document.content = 123 # Unexpected content type (not str, URL, or TextContentItem)
|
||||
|
||||
with pytest.raises(ValueError, match="Unexpected document content type: <class 'int'>"):
|
||||
await get_raw_document_text(mock_document)
|
||||
|
|
@ -8,7 +8,6 @@ from datetime import datetime
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
|
|
@ -50,7 +49,7 @@ def config(tmp_path):
|
|||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest.fixture
|
||||
async def agents_impl(config, mock_apis):
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
config,
|
||||
|
|
@ -117,7 +116,6 @@ def sample_agent_config():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent(agents_impl, sample_agent_config):
|
||||
response = await agents_impl.create_agent(sample_agent_config)
|
||||
|
||||
|
|
@ -132,7 +130,6 @@ async def test_create_agent(agents_impl, sample_agent_config):
|
|||
assert isinstance(agent_info.created_at, datetime)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent(agents_impl, sample_agent_config):
|
||||
create_response = await agents_impl.create_agent(sample_agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
|
@ -146,7 +143,6 @@ async def test_get_agent(agents_impl, sample_agent_config):
|
|||
assert isinstance(agent.created_at, datetime)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents(agents_impl, sample_agent_config):
|
||||
agent1_response = await agents_impl.create_agent(sample_agent_config)
|
||||
agent2_response = await agents_impl.create_agent(sample_agent_config)
|
||||
|
|
@ -160,7 +156,6 @@ async def test_list_agents(agents_impl, sample_agent_config):
|
|||
assert agent2_response.agent_id in agent_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
||||
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
|
||||
# Create an agent with specified persistence setting
|
||||
|
|
@ -188,7 +183,6 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config
|
|||
await agents_impl.get_agents_session(agent_id, session_response.session_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
||||
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
|
||||
# Create an agent with specified persistence setting
|
||||
|
|
@ -221,7 +215,6 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config,
|
|||
assert session2.session_id in {s["session_id"] for s in sessions.data}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_agent(agents_impl, sample_agent_config):
|
||||
# Create an agent
|
||||
response = await agents_impl.create_agent(sample_agent_config)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
||||
from llama_stack.distribution.access_control.access_control import default_policy
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
|
@ -122,7 +122,6 @@ async def fake_stream(fixture: str = "simple_chat_completion.yaml"):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
|
||||
"""Test creating an OpenAI response with a simple string input."""
|
||||
# Setup
|
||||
|
|
@ -155,7 +154,6 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
assert result.output[0].content[0].text == "Dublin"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
|
||||
"""Test creating an OpenAI response with a simple string input and tools."""
|
||||
# Setup
|
||||
|
|
@ -224,7 +222,6 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
|
|||
assert result.output[1].content[0].annotations == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api):
|
||||
"""Test creating an OpenAI response with a tool call response that has a type of None."""
|
||||
# Setup
|
||||
|
|
@ -294,7 +291,6 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
|||
assert chunks[1].response.output[0].name == "get_weather"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
|
||||
"""Test creating an OpenAI response with multiple messages."""
|
||||
# Setup
|
||||
|
|
@ -340,7 +336,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
|
|||
assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepend_previous_response_none(openai_responses_impl):
|
||||
"""Test prepending no previous response to a new response."""
|
||||
|
||||
|
|
@ -348,7 +343,6 @@ async def test_prepend_previous_response_none(openai_responses_impl):
|
|||
assert input == "fake_input"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
|
||||
"""Test prepending a basic previous response to a new response."""
|
||||
|
||||
|
|
@ -388,7 +382,6 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
|
|||
assert input[2].content == "fake_input"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store):
|
||||
"""Test prepending a web search previous response to a new response."""
|
||||
input_item_message = OpenAIResponseMessage(
|
||||
|
|
@ -434,7 +427,6 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
|
|||
assert input[3].content == "fake_input"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
|
||||
# Setup
|
||||
input_text = "What is the capital of Ireland?"
|
||||
|
|
@ -463,7 +455,6 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
|
|||
assert sent_messages[1].content == input_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_openai_response_with_instructions_and_multiple_messages(
|
||||
openai_responses_impl, mock_inference_api
|
||||
):
|
||||
|
|
@ -508,7 +499,6 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
|
|||
assert sent_messages[3].content == "Which is the largest?"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_openai_response_with_instructions_and_previous_response(
|
||||
openai_responses_impl, mock_responses_store, mock_inference_api
|
||||
):
|
||||
|
|
@ -565,7 +555,6 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
assert sent_messages[3].content == "Which is the largest?"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
|
||||
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
|
||||
# Setup
|
||||
|
|
@ -601,7 +590,6 @@ async def test_list_openai_response_input_items_delegation(openai_responses_impl
|
|||
assert result.data[0].id == "msg_123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_responses_store_list_input_items_logic():
|
||||
"""Test ResponsesStore list_response_input_items logic - mocks get_response_object to test actual ordering/limiting."""
|
||||
|
||||
|
|
@ -680,7 +668,6 @@ async def test_responses_store_list_input_items_logic():
|
|||
assert len(result.data) == 0 # Should return no items
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_response_uses_rehydrated_input_with_previous_response(
|
||||
openai_responses_impl, mock_responses_store, mock_inference_api
|
||||
):
|
||||
|
|
@ -747,7 +734,6 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
|
|||
assert result.status == "completed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"text_format, response_format",
|
||||
[
|
||||
|
|
@ -787,7 +773,6 @@ async def test_create_openai_response_with_text_format(
|
|||
assert first_call.kwargs["response_format"] == response_format
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
|
||||
"""Test creating an OpenAI response with an invalid text format."""
|
||||
# Setup
|
||||
|
|
|
|||
|
|
@ -9,21 +9,19 @@ from datetime import datetime
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.agents import Turn
|
||||
from llama_stack.apis.inference import CompletionMessage, StopReason
|
||||
from llama_stack.distribution.datatypes import User
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest.fixture
|
||||
async def test_setup(sqlite_kvstore):
|
||||
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
|
||||
yield agent_persistence
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
|
@ -44,7 +42,6 @@ async def test_session_creation_with_access_attributes(mock_get_authenticated_us
|
|||
assert session_info.owner.attributes["teams"] == ["ai-team"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_session_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
|
@ -79,7 +76,6 @@ async def test_session_access_control(mock_get_authenticated_user, test_setup):
|
|||
assert retrieved_session is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
|
@ -133,7 +129,6 @@ async def test_turn_access_control(mock_get_authenticated_user, test_setup):
|
|||
await agent_persistence.get_session_turns(session_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
|
|
|||
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
||||
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
|
||||
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
|
||||
|
||||
|
||||
def test_groq_provider_openai_client_caching():
|
||||
"""Ensure the Groq provider does not cache api keys across client requests"""
|
||||
|
||||
config = GroqConfig()
|
||||
inference_adapter = GroqInferenceAdapter(config)
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = (
|
||||
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
|
||||
)
|
||||
|
||||
for api_key in ["test1", "test2"]:
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||
):
|
||||
openai_client = inference_adapter._get_openai_client()
|
||||
assert openai_client.api_key == api_key
|
||||
|
||||
|
||||
def test_openai_provider_openai_client_caching():
|
||||
"""Ensure the OpenAI provider does not cache api keys across client requests"""
|
||||
|
||||
config = OpenAIConfig()
|
||||
inference_adapter = OpenAIInferenceAdapter(config)
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = (
|
||||
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator"
|
||||
)
|
||||
|
||||
for api_key in ["test1", "test2"]:
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||
):
|
||||
openai_client = inference_adapter.client
|
||||
assert openai_client.api_key == api_key
|
||||
|
||||
|
||||
def test_together_provider_openai_client_caching():
|
||||
"""Ensure the Together provider does not cache api keys across client requests"""
|
||||
|
||||
config = TogetherImplConfig()
|
||||
inference_adapter = TogetherInferenceAdapter(config)
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = (
|
||||
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator"
|
||||
)
|
||||
|
||||
for api_key in ["test1", "test2"]:
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"together_api_key": api_key})}):
|
||||
together_client = inference_adapter._get_client()
|
||||
assert together_client.client.api_key == api_key
|
||||
openai_client = inference_adapter._get_openai_client()
|
||||
assert openai_client.api_key == api_key
|
||||
|
||||
|
||||
def test_llama_compat_provider_openai_client_caching():
|
||||
"""Ensure the LlamaCompat provider does not cache api keys across client requests"""
|
||||
config = LlamaCompatConfig()
|
||||
inference_adapter = LlamaCompatInferenceAdapter(config)
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = (
|
||||
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator"
|
||||
)
|
||||
|
||||
for api_key in ["test1", "test2"]:
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
|
||||
assert inference_adapter.client.api_key == api_key
|
||||
112
tests/unit/providers/inference/test_litellm_openai_mixin.py
Normal file
112
tests/unit/providers/inference/test_litellm_openai_mixin.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
|
||||
# Test fixtures and helper classes
|
||||
class TestConfig(BaseModel):
|
||||
api_key: str | None = Field(default=None)
|
||||
|
||||
|
||||
class TestProviderDataValidator(BaseModel):
|
||||
test_api_key: str | None = Field(default=None)
|
||||
|
||||
|
||||
class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: TestConfig):
|
||||
super().__init__(
|
||||
model_entries=[],
|
||||
litellm_provider_name="test",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="test_api_key",
|
||||
openai_compat_api_base=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter_with_config_key():
|
||||
"""Fixture to create adapter with API key in config"""
|
||||
config = TestConfig(api_key="config-api-key")
|
||||
adapter = TestLiteLLMAdapter(config)
|
||||
adapter.__provider_spec__ = MagicMock()
|
||||
adapter.__provider_spec__.provider_data_validator = (
|
||||
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
|
||||
)
|
||||
return adapter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter_without_config_key():
|
||||
"""Fixture to create adapter without API key in config"""
|
||||
config = TestConfig(api_key=None)
|
||||
adapter = TestLiteLLMAdapter(config)
|
||||
adapter.__provider_spec__ = MagicMock()
|
||||
adapter.__provider_spec__.provider_data_validator = (
|
||||
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
|
||||
)
|
||||
return adapter
|
||||
|
||||
|
||||
def test_api_key_from_config_when_no_provider_data(adapter_with_config_key):
|
||||
"""Test that adapter uses config API key when no provider data is available"""
|
||||
api_key = adapter_with_config_key.get_api_key()
|
||||
assert api_key == "config-api-key"
|
||||
|
||||
|
||||
def test_provider_data_takes_priority_over_config(adapter_with_config_key):
|
||||
"""Test that provider data API key overrides config API key"""
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})}
|
||||
):
|
||||
api_key = adapter_with_config_key.get_api_key()
|
||||
assert api_key == "provider-data-key"
|
||||
|
||||
|
||||
def test_fallback_to_config_when_provider_data_missing_key(adapter_with_config_key):
|
||||
"""Test fallback to config when provider data doesn't have the required key"""
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||
api_key = adapter_with_config_key.get_api_key()
|
||||
assert api_key == "config-api-key"
|
||||
|
||||
|
||||
def test_error_when_no_api_key_available(adapter_without_config_key):
|
||||
"""Test that ValueError is raised when neither config nor provider data have API key"""
|
||||
with pytest.raises(ValueError, match="API key is not set"):
|
||||
adapter_without_config_key.get_api_key()
|
||||
|
||||
|
||||
def test_error_when_provider_data_has_wrong_key(adapter_without_config_key):
|
||||
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||
with pytest.raises(ValueError, match="API key is not set"):
|
||||
adapter_without_config_key.get_api_key()
|
||||
|
||||
|
||||
def test_provider_data_works_when_config_is_none(adapter_without_config_key):
|
||||
"""Test that provider data works even when config has no API key"""
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-only-key"})}
|
||||
):
|
||||
api_key = adapter_without_config_key.get_api_key()
|
||||
assert api_key == "provider-only-key"
|
||||
|
||||
|
||||
def test_error_message_includes_correct_field_names(adapter_without_config_key):
|
||||
"""Test that error message includes correct field name and header information"""
|
||||
try:
|
||||
adapter_without_config_key.get_api_key()
|
||||
raise AssertionError("Should have raised ValueError")
|
||||
except ValueError as e:
|
||||
assert "test_api_key" in str(e) # Should mention the correct field name
|
||||
assert "x-llamastack-provider-data" in str(e) # Should mention header name
|
||||
125
tests/unit/providers/inference/test_openai_base_url_config.py
Normal file
125
tests/unit/providers/inference/test_openai_base_url_config.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from llama_stack.core.stack import replace_env_vars
|
||||
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
||||
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
|
||||
|
||||
|
||||
class TestOpenAIBaseURLConfig:
|
||||
"""Test that OPENAI_BASE_URL environment variable properly configures the OpenAI adapter."""
|
||||
|
||||
def test_default_base_url_without_env_var(self):
|
||||
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
|
||||
config = OpenAIConfig(api_key="test-key")
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
|
||||
assert adapter.get_base_url() == "https://api.openai.com/v1"
|
||||
|
||||
def test_custom_base_url_from_config(self):
|
||||
"""Test that the adapter uses a custom base URL when provided in config."""
|
||||
custom_url = "https://custom.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
|
||||
assert adapter.get_base_url() == custom_url
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
|
||||
def test_base_url_from_environment_variable(self):
|
||||
"""Test that the adapter uses base URL from OPENAI_BASE_URL environment variable."""
|
||||
# Use sample_run_config which has proper environment variable syntax
|
||||
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
|
||||
processed_config = replace_env_vars(config_data)
|
||||
config = OpenAIConfig.model_validate(processed_config)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
|
||||
assert adapter.get_base_url() == "https://env.openai.com/v1"
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
|
||||
def test_config_overrides_environment_variable(self):
|
||||
"""Test that explicit config value overrides environment variable."""
|
||||
custom_url = "https://config.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
|
||||
# Config should take precedence over environment variable
|
||||
assert adapter.get_base_url() == custom_url
|
||||
|
||||
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
|
||||
def test_client_uses_configured_base_url(self, mock_openai_class):
|
||||
"""Test that the OpenAI client is initialized with the configured base URL."""
|
||||
custom_url = "https://test.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
|
||||
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
|
||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
||||
|
||||
# Access the client property to trigger AsyncOpenAI initialization
|
||||
_ = adapter.client
|
||||
|
||||
# Verify AsyncOpenAI was called with the correct base_url
|
||||
mock_openai_class.assert_called_once_with(
|
||||
api_key="test-key",
|
||||
base_url=custom_url,
|
||||
)
|
||||
|
||||
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
|
||||
async def test_check_model_availability_uses_configured_url(self, mock_openai_class):
|
||||
"""Test that check_model_availability uses the configured base URL."""
|
||||
custom_url = "https://test.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
|
||||
# Mock the get_api_key method
|
||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
||||
|
||||
# Mock the AsyncOpenAI client and its models.retrieve method
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
# Call check_model_availability and verify it returns True
|
||||
assert await adapter.check_model_availability("gpt-4")
|
||||
|
||||
# Verify the client was created with the custom URL
|
||||
mock_openai_class.assert_called_with(
|
||||
api_key="test-key",
|
||||
base_url=custom_url,
|
||||
)
|
||||
|
||||
# Verify the method was called and returned True
|
||||
mock_client.models.retrieve.assert_called_once_with("gpt-4")
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
|
||||
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
|
||||
async def test_environment_variable_affects_model_availability_check(self, mock_openai_class):
|
||||
"""Test that setting OPENAI_BASE_URL environment variable affects where model availability is checked."""
|
||||
# Use sample_run_config which has proper environment variable syntax
|
||||
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
|
||||
processed_config = replace_env_vars(config_data)
|
||||
config = OpenAIConfig.model_validate(processed_config)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
|
||||
# Mock the get_api_key method
|
||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
||||
|
||||
# Mock the AsyncOpenAI client
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
# Call check_model_availability and verify it returns True
|
||||
assert await adapter.check_model_availability("gpt-4")
|
||||
|
||||
# Verify the client was created with the environment variable URL
|
||||
mock_openai_class.assert_called_with(
|
||||
api_key="test-key",
|
||||
base_url="https://proxy.openai.com/v1",
|
||||
)
|
||||
|
|
@ -14,7 +14,6 @@ from typing import Any
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
|
|
@ -103,7 +102,7 @@ def mock_openai_models_list():
|
|||
yield mock_list
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
@pytest.fixture(scope="module")
|
||||
async def vllm_inference_adapter():
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
inference_adapter = VLLMInferenceAdapter(config)
|
||||
|
|
@ -112,7 +111,6 @@ async def vllm_inference_adapter():
|
|||
return inference_adapter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
|
||||
async def mock_openai_models():
|
||||
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
|
||||
|
|
@ -125,7 +123,6 @@ async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inferenc
|
|||
mock_openai_models_list.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
||||
"""
|
||||
Test that we set tool_choice to none when no tools are in use
|
||||
|
|
@ -149,7 +146,6 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
|||
assert request.tool_config.tool_choice == ToolChoice.none
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_response(vllm_inference_adapter):
|
||||
"""Verify that tool call arguments from a CompletionMessage are correctly converted
|
||||
into the expected JSON format."""
|
||||
|
|
@ -192,7 +188,6 @@ async def test_tool_call_response(vllm_inference_adapter):
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_delta_empty_tool_call_buf():
|
||||
"""
|
||||
Test that we don't generate extra chunks when processing a
|
||||
|
|
@ -222,7 +217,6 @@ async def test_tool_call_delta_empty_tool_call_buf():
|
|||
assert chunks[1].event.stop_reason == StopReason.end_of_turn
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_delta_streaming_arguments_dict():
|
||||
async def mock_stream():
|
||||
mock_chunk_1 = OpenAIChatCompletionChunk(
|
||||
|
|
@ -297,7 +291,6 @@ async def test_tool_call_delta_streaming_arguments_dict():
|
|||
assert chunks[2].event.event_type.value == "complete"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_tool_calls():
|
||||
async def mock_stream():
|
||||
mock_chunk_1 = OpenAIChatCompletionChunk(
|
||||
|
|
@ -376,7 +369,6 @@ async def test_multiple_tool_calls():
|
|||
assert chunks[3].event.event_type.value == "complete"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_vllm_chat_completion_stream_response_no_choices():
|
||||
"""
|
||||
Test that we don't error out when vLLM returns no choices for a
|
||||
|
|
@ -401,6 +393,7 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
|
|||
assert chunks[0].event.event_type.value == "start"
|
||||
|
||||
|
||||
@pytest.mark.allow_network
|
||||
def test_chat_completion_doesnt_block_event_loop(caplog):
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.set_debug(True)
|
||||
|
|
@ -453,7 +446,6 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
|
|||
assert not asyncio_warnings
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_params_empty_tools(vllm_inference_adapter):
|
||||
request = ChatCompletionRequest(
|
||||
tools=[],
|
||||
|
|
@ -464,7 +456,6 @@ async def test_get_params_empty_tools(vllm_inference_adapter):
|
|||
assert "tools" not in params
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
|
||||
"""
|
||||
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
|
||||
|
|
@ -543,7 +534,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
|
|||
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
|
||||
"""
|
||||
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
|
||||
|
|
@ -596,7 +586,6 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
|
|||
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
|
||||
"""
|
||||
Tests the edge case where no arguments are provided for the tool call.
|
||||
|
|
@ -645,7 +634,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
|
|||
assert chunks[-2].event.delta.tool_call.arguments == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_status_success(vllm_inference_adapter):
|
||||
"""
|
||||
Test the health method of VLLM InferenceAdapter when the connection is successful.
|
||||
|
|
@ -679,7 +667,6 @@ async def test_health_status_success(vllm_inference_adapter):
|
|||
mock_models.list.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_status_failure(vllm_inference_adapter):
|
||||
"""
|
||||
Test the health method of VLLM InferenceAdapter when the connection fails.
|
||||
|
|
|
|||
|
|
@ -5,103 +5,110 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig
|
||||
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
|
||||
|
||||
|
||||
class TestNvidiaDatastore(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
|
||||
@pytest.fixture
|
||||
def nvidia_adapter():
|
||||
"""Fixture to set up NvidiaDatasetIOAdapter with mocked requests."""
|
||||
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
|
||||
|
||||
config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
|
||||
)
|
||||
self.adapter = NvidiaDatasetIOAdapter(config)
|
||||
self.make_request_patcher = patch(
|
||||
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
|
||||
)
|
||||
self.mock_make_request = self.make_request_patcher.start()
|
||||
config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
|
||||
)
|
||||
adapter = NvidiaDatasetIOAdapter(config)
|
||||
|
||||
def tearDown(self):
|
||||
self.make_request_patcher.stop()
|
||||
with patch(
|
||||
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
|
||||
) as mock_make_request:
|
||||
yield adapter, mock_make_request
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
|
||||
def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None):
|
||||
"""Helper method to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
def _assert_request(mock_call, expected_method, expected_path, expected_json=None):
|
||||
"""Helper function to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
assert call_args[0][0] == expected_method
|
||||
assert call_args[0][1] == expected_path
|
||||
assert call_args[0][0] == expected_method
|
||||
assert call_args[0][1] == expected_path
|
||||
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
def test_register_dataset(self):
|
||||
self.mock_make_request.return_value = {
|
||||
"id": "dataset-123456",
|
||||
|
||||
def test_register_dataset(nvidia_adapter, run_async):
|
||||
adapter, mock_make_request = nvidia_adapter
|
||||
mock_make_request.return_value = {
|
||||
"id": "dataset-123456",
|
||||
"name": "test-dataset",
|
||||
"namespace": "default",
|
||||
}
|
||||
|
||||
dataset_def = Dataset(
|
||||
identifier="test-dataset",
|
||||
type=ResourceType.dataset,
|
||||
provider_resource_id="",
|
||||
provider_id="",
|
||||
purpose=DatasetPurpose.post_training_messages,
|
||||
source=URIDataSource(uri="https://example.com/data.jsonl"),
|
||||
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
|
||||
)
|
||||
|
||||
run_async(adapter.register_dataset(dataset_def))
|
||||
|
||||
mock_make_request.assert_called_once()
|
||||
_assert_request(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
"/v1/datasets",
|
||||
expected_json={
|
||||
"name": "test-dataset",
|
||||
"namespace": "default",
|
||||
}
|
||||
"files_url": "https://example.com/data.jsonl",
|
||||
"project": "default",
|
||||
"format": "jsonl",
|
||||
"description": "Test dataset description",
|
||||
},
|
||||
)
|
||||
|
||||
dataset_def = Dataset(
|
||||
identifier="test-dataset",
|
||||
type="dataset",
|
||||
provider_resource_id="",
|
||||
provider_id="",
|
||||
purpose=DatasetPurpose.post_training_messages,
|
||||
source=URIDataSource(uri="https://example.com/data.jsonl"),
|
||||
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
|
||||
)
|
||||
|
||||
self.run_async(self.adapter.register_dataset(dataset_def))
|
||||
def test_unregister_dataset(nvidia_adapter, run_async):
|
||||
adapter, mock_make_request = nvidia_adapter
|
||||
mock_make_request.return_value = {
|
||||
"message": "Resource deleted successfully.",
|
||||
"id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
|
||||
"deleted_at": None,
|
||||
}
|
||||
dataset_id = "test-dataset"
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"POST",
|
||||
"/v1/datasets",
|
||||
expected_json={
|
||||
"name": "test-dataset",
|
||||
"namespace": "default",
|
||||
"files_url": "https://example.com/data.jsonl",
|
||||
"project": "default",
|
||||
"format": "jsonl",
|
||||
"description": "Test dataset description",
|
||||
},
|
||||
)
|
||||
run_async(adapter.unregister_dataset(dataset_id))
|
||||
|
||||
def test_unregister_dataset(self):
|
||||
self.mock_make_request.return_value = {
|
||||
"message": "Resource deleted successfully.",
|
||||
"id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
|
||||
"deleted_at": None,
|
||||
}
|
||||
dataset_id = "test-dataset"
|
||||
mock_make_request.assert_called_once()
|
||||
_assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
|
||||
|
||||
self.run_async(self.adapter.unregister_dataset(dataset_id))
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
|
||||
def test_register_dataset_with_custom_namespace_project(run_async):
|
||||
"""Test with custom namespace and project configuration."""
|
||||
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
|
||||
|
||||
def test_register_dataset_with_custom_namespace_project(self):
|
||||
custom_config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"],
|
||||
dataset_namespace="custom-namespace",
|
||||
project_id="custom-project",
|
||||
)
|
||||
custom_adapter = NvidiaDatasetIOAdapter(custom_config)
|
||||
custom_config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"],
|
||||
dataset_namespace="custom-namespace",
|
||||
project_id="custom-project",
|
||||
)
|
||||
custom_adapter = NvidiaDatasetIOAdapter(custom_config)
|
||||
|
||||
self.mock_make_request.return_value = {
|
||||
with patch(
|
||||
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
|
||||
) as mock_make_request:
|
||||
mock_make_request.return_value = {
|
||||
"id": "dataset-123456",
|
||||
"name": "test-dataset",
|
||||
"namespace": "custom-namespace",
|
||||
|
|
@ -109,7 +116,7 @@ class TestNvidiaDatastore(unittest.TestCase):
|
|||
|
||||
dataset_def = Dataset(
|
||||
identifier="test-dataset",
|
||||
type="dataset",
|
||||
type=ResourceType.dataset,
|
||||
provider_resource_id="",
|
||||
provider_id="",
|
||||
purpose=DatasetPurpose.post_training_messages,
|
||||
|
|
@ -117,11 +124,11 @@ class TestNvidiaDatastore(unittest.TestCase):
|
|||
metadata={"format": "jsonl"},
|
||||
)
|
||||
|
||||
self.run_async(custom_adapter.register_dataset(dataset_def))
|
||||
run_async(custom_adapter.register_dataset(dataset_def))
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
mock_make_request.assert_called_once()
|
||||
_assert_request(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
"/v1/datasets",
|
||||
expected_json={
|
||||
|
|
@ -132,7 +139,3 @@ class TestNvidiaDatastore(unittest.TestCase):
|
|||
"format": "jsonl",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -20,21 +19,20 @@ from llama_stack.apis.post_training.post_training import (
|
|||
OptimizerType,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||
NvidiaPostTrainingAdapter,
|
||||
NvidiaPostTrainingConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestNvidiaParameters(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
|
||||
class TestNvidiaParameters:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown(self):
|
||||
"""Setup and teardown for each test method."""
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||
|
||||
config = NvidiaPostTrainingConfig(
|
||||
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
|
||||
)
|
||||
config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
|
||||
self.adapter = NvidiaPostTrainingAdapter(config)
|
||||
|
||||
self.make_request_patcher = patch(
|
||||
|
|
@ -48,7 +46,8 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
"updated_at": "2025-03-04T13:07:47.543605",
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
yield
|
||||
|
||||
self.make_request_patcher.stop()
|
||||
|
||||
def _assert_request_params(self, expected_json):
|
||||
|
|
@ -166,8 +165,8 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
|
||||
self.run_async(
|
||||
self.adapter.supervised_fine_tune(
|
||||
job_uuid=required_job_uuid, # Required parameter
|
||||
model=required_model, # Required parameter
|
||||
job_uuid=required_job_uuid,
|
||||
model=required_model,
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
|
|
@ -198,7 +197,6 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
data_config = DataConfig(
|
||||
dataset_id="test-dataset",
|
||||
batch_size=8,
|
||||
# Unsupported parameters
|
||||
shuffle=True,
|
||||
data_format=DatasetFormat.instruct,
|
||||
validation_dataset_id="val-dataset",
|
||||
|
|
@ -207,20 +205,16 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
optimizer_config = OptimizerConfig(
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
# Unsupported parameters
|
||||
optimizer_type=OptimizerType.adam,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
efficiency_config = EfficiencyConfig(
|
||||
enable_activation_checkpointing=True # Unsupported parameter
|
||||
)
|
||||
efficiency_config = EfficiencyConfig(enable_activation_checkpointing=True)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=1,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
# Unsupported parameters
|
||||
efficiency_config=efficiency_config,
|
||||
max_steps_per_epoch=1000,
|
||||
gradient_accumulation_steps=4,
|
||||
|
|
@ -228,7 +222,6 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
dtype="bf16",
|
||||
)
|
||||
|
||||
# Capture warnings
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
|
|
@ -236,7 +229,7 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
self.adapter.supervised_fine_tune(
|
||||
job_uuid="test-job",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
checkpoint_dir="test-dir", # Unsupported parameter
|
||||
checkpoint_dir="test-dir",
|
||||
algorithm_config=LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
|
|
@ -246,8 +239,8 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
),
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={"test": "value"}, # Unsupported parameter
|
||||
hyperparam_search_config={"test": "value"}, # Unsupported parameter
|
||||
logger_config={"test": "value"},
|
||||
hyperparam_search_config={"test": "value"},
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -265,7 +258,6 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
"gradient_accumulation_steps",
|
||||
"max_validation_steps",
|
||||
"dtype",
|
||||
# required unsupported parameters
|
||||
"rank",
|
||||
"apply_lora_to_output",
|
||||
"lora_attn_modules",
|
||||
|
|
@ -273,7 +265,3 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
]
|
||||
for field in fields:
|
||||
assert any(field in text for text in warning_texts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -5,321 +5,353 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import CompletionMessage, UserMessage
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
|
||||
|
||||
|
||||
class TestNVIDIASafetyAdapter(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
class TestNVIDIASafetyAdapter(NVIDIASafetyAdapter):
|
||||
"""Test implementation that provides the required shield_store."""
|
||||
|
||||
# Initialize the adapter
|
||||
self.config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
)
|
||||
self.adapter = NVIDIASafetyAdapter(config=self.config)
|
||||
self.shield_store = AsyncMock()
|
||||
self.adapter.shield_store = self.shield_store
|
||||
def __init__(self, config: NVIDIASafetyConfig, shield_store):
|
||||
super().__init__(config)
|
||||
self.shield_store = shield_store
|
||||
|
||||
# Mock the HTTP request methods
|
||||
self.guardrails_post_patcher = patch(
|
||||
"llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
|
||||
)
|
||||
self.mock_guardrails_post = self.guardrails_post_patcher.start()
|
||||
self.mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after each test."""
|
||||
self.guardrails_post_patcher.stop()
|
||||
@pytest.fixture
|
||||
def nvidia_adapter():
|
||||
"""Set up the NVIDIASafetyAdapter for testing."""
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
# Initialize the adapter
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
)
|
||||
|
||||
def _assert_request(
|
||||
self,
|
||||
mock_call: MagicMock,
|
||||
expected_url: str,
|
||||
expected_headers: dict[str, str] | None = None,
|
||||
expected_json: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to verify request details in mock API calls.
|
||||
# Create a mock shield store that implements the ShieldStore protocol
|
||||
shield_store = AsyncMock()
|
||||
shield_store.get_shield = AsyncMock()
|
||||
|
||||
Args:
|
||||
mock_call: The MagicMock object that was called
|
||||
expected_url: The expected URL to which the request was made
|
||||
expected_headers: Optional dictionary of expected request headers
|
||||
expected_json: Optional dictionary of expected JSON payload
|
||||
"""
|
||||
call_args = mock_call.call_args
|
||||
adapter = TestNVIDIASafetyAdapter(config=config, shield_store=shield_store)
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == expected_url
|
||||
return adapter
|
||||
|
||||
# Check headers if provided
|
||||
if expected_headers:
|
||||
for key, value in expected_headers.items():
|
||||
assert call_args[1]["headers"][key] == value
|
||||
|
||||
# Check JSON if provided
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert call_args[1]["json"][key][nested_key] == nested_value
|
||||
else:
|
||||
assert call_args[1]["json"][key] == value
|
||||
@pytest.fixture
|
||||
def mock_guardrails_post():
|
||||
"""Mock the HTTP request methods."""
|
||||
with patch("llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post") as mock_post:
|
||||
mock_post.return_value = {"status": "allowed"}
|
||||
yield mock_post
|
||||
|
||||
def test_register_shield_with_valid_id(self):
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
|
||||
# Register the shield
|
||||
self.run_async(self.adapter.register_shield(shield))
|
||||
def _assert_request(
|
||||
mock_call: MagicMock,
|
||||
expected_url: str,
|
||||
expected_headers: dict[str, str] | None = None,
|
||||
expected_json: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to verify request details in mock API calls.
|
||||
|
||||
def test_register_shield_without_id(self):
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="",
|
||||
)
|
||||
Args:
|
||||
mock_call: The MagicMock object that was called
|
||||
expected_url: The expected URL to which the request was made
|
||||
expected_headers: Optional dictionary of expected request headers
|
||||
expected_json: Optional dictionary of expected JSON payload
|
||||
"""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
# Register the shield should raise a ValueError
|
||||
with self.assertRaises(ValueError):
|
||||
self.run_async(self.adapter.register_shield(shield))
|
||||
# Check URL
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
def test_run_shield_allowed(self):
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
# Check headers if provided
|
||||
if expected_headers:
|
||||
for key, value in expected_headers.items():
|
||||
assert call_args[1]["headers"][key] == value
|
||||
|
||||
# Mock Guardrails API response
|
||||
self.mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
# Check JSON if provided
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert call_args[1]["json"][key][nested_key] == nested_value
|
||||
else:
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
async def test_register_shield_with_valid_id(nvidia_adapter):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier="test-shield",
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
|
||||
# Register the shield
|
||||
await adapter.register_shield(shield)
|
||||
|
||||
|
||||
async def test_register_shield_without_id(nvidia_adapter):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier="test-shield",
|
||||
provider_resource_id="",
|
||||
)
|
||||
|
||||
# Register the shield should raise a ValueError
|
||||
with pytest.raises(ValueError):
|
||||
await adapter.register_shield(shield)
|
||||
|
||||
|
||||
async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
adapter.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API response
|
||||
mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = await adapter.run_shield(shield_id, messages)
|
||||
|
||||
# Verify the shield store was called
|
||||
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation is None
|
||||
# Verify the result
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation is None
|
||||
|
||||
def test_run_shield_blocked(self):
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API response
|
||||
self.mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
|
||||
async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
adapter.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
# Mock Guardrails API response
|
||||
mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = await adapter.run_shield(shield_id, messages)
|
||||
|
||||
# Verify the shield store was called
|
||||
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result.violation is not None
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation.user_message == "Sorry I cannot do this."
|
||||
assert result.violation.violation_level == ViolationLevel.ERROR
|
||||
assert result.violation.metadata == {"reason": "harmful_content"}
|
||||
# Verify the result
|
||||
assert result.violation is not None
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation.user_message == "Sorry I cannot do this."
|
||||
assert result.violation.violation_level == ViolationLevel.ERROR
|
||||
assert result.violation.metadata == {"reason": "harmful_content"}
|
||||
|
||||
def test_run_shield_not_found(self):
|
||||
# Set up shield store to return None
|
||||
shield_id = "non-existent-shield"
|
||||
self.shield_store.get_shield.return_value = None
|
||||
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
]
|
||||
async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
# Set up shield store to return None
|
||||
shield_id = "non-existent-shield"
|
||||
adapter.shield_store.get_shield.return_value = None
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
]
|
||||
|
||||
# Verify the Guardrails API was not called
|
||||
self.mock_guardrails_post.assert_not_called()
|
||||
with pytest.raises(ValueError):
|
||||
await adapter.run_shield(shield_id, messages)
|
||||
|
||||
def test_run_shield_http_error(self):
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
# Verify the shield store was called
|
||||
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Mock Guardrails API to raise an exception
|
||||
error_msg = "API Error: 500 Internal Server Error"
|
||||
self.mock_guardrails_post.side_effect = Exception(error_msg)
|
||||
# Verify the Guardrails API was not called
|
||||
mock_guardrails_post.assert_not_called()
|
||||
|
||||
# Running the shield should raise an exception
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
with self.assertRaises(Exception) as context:
|
||||
self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
adapter.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API to raise an exception
|
||||
error_msg = "API Error: 500 Internal Server Error"
|
||||
mock_guardrails_post.side_effect = Exception(error_msg)
|
||||
|
||||
# Running the shield should raise an exception
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await adapter.run_shield(shield_id, messages)
|
||||
|
||||
# Verify the shield store was called
|
||||
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
)
|
||||
# Verify the exception message
|
||||
assert error_msg in str(context.exception)
|
||||
},
|
||||
)
|
||||
# Verify the exception message
|
||||
assert error_msg in str(exc_info.value)
|
||||
|
||||
def test_init_nemo_guardrails(self):
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
|
||||
test_config_id = "test-custom-config-id"
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id=test_config_id,
|
||||
)
|
||||
# Initialize with default parameters
|
||||
test_model = "test-model"
|
||||
guardrails = NeMoGuardrails(config, test_model)
|
||||
def test_init_nemo_guardrails():
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.9 # Default value
|
||||
assert guardrails.temperature == 1.0 # Default value
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
|
||||
# Initialize with custom parameters
|
||||
guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)
|
||||
test_config_id = "test-custom-config-id"
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id=test_config_id,
|
||||
)
|
||||
# Initialize with default parameters
|
||||
test_model = "test-model"
|
||||
guardrails = NeMoGuardrails(config, test_model)
|
||||
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.8
|
||||
assert guardrails.temperature == 0.7
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.9 # Default value
|
||||
assert guardrails.temperature == 1.0 # Default value
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
|
||||
def test_init_nemo_guardrails_invalid_temperature(self):
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
# Initialize with custom parameters
|
||||
guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)
|
||||
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id="test-custom-config-id",
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
NeMoGuardrails(config, "test-model", temperature=0)
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.8
|
||||
assert guardrails.temperature == 0.7
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
|
||||
|
||||
def test_init_nemo_guardrails_invalid_temperature():
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id="test-custom-config-id",
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
NeMoGuardrails(config, "test-model", temperature=0)
|
||||
|
|
|
|||
|
|
@ -5,13 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.post_training.post_training import (
|
||||
DataConfig,
|
||||
DatasetFormat,
|
||||
|
|
@ -21,8 +19,7 @@ from llama_stack.apis.post_training.post_training import (
|
|||
QATFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||
ListNvidiaPostTrainingJobs,
|
||||
NvidiaPostTrainingAdapter,
|
||||
|
|
@ -32,331 +29,297 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
|||
)
|
||||
|
||||
|
||||
class TestNvidiaPostTraining(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
|
||||
@pytest.fixture
|
||||
def nvidia_post_training_adapter():
|
||||
"""Fixture to create and configure the NVIDIA post training adapter."""
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
|
||||
|
||||
config = NvidiaPostTrainingConfig(
|
||||
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
|
||||
config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
|
||||
adapter = NvidiaPostTrainingAdapter(config)
|
||||
|
||||
with patch.object(adapter, "_make_request") as mock_make_request:
|
||||
yield adapter, mock_make_request
|
||||
|
||||
|
||||
def _assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
|
||||
"""Helper method to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
if expected_method and expected_path:
|
||||
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
|
||||
assert call_args[0] == (expected_method, expected_path)
|
||||
else:
|
||||
assert call_args[1]["method"] == expected_method
|
||||
assert call_args[1]["path"] == expected_path
|
||||
|
||||
if expected_params:
|
||||
assert call_args[1]["params"] == expected_params
|
||||
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
|
||||
async def test_supervised_fine_tune(nvidia_post_training_adapter):
|
||||
"""Test the supervised fine-tuning API call."""
|
||||
adapter, mock_make_request = nvidia_post_training_adapter
|
||||
mock_make_request.return_value = {
|
||||
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||
"created_at": "2024-12-09T04:06:28.542884",
|
||||
"updated_at": "2024-12-09T04:06:28.542884",
|
||||
"config": {
|
||||
"schema_version": "1.0",
|
||||
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.569837",
|
||||
"custom_fields": {},
|
||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model_path": "llama-3_1-8b-instruct",
|
||||
"training_types": [],
|
||||
"finetuning_types": ["lora"],
|
||||
"precision": "bf16",
|
||||
"num_gpus": 4,
|
||||
"num_nodes": 1,
|
||||
"micro_batch_size": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
"max_seq_length": 4096,
|
||||
},
|
||||
"dataset": {
|
||||
"schema_version": "1.0",
|
||||
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.542660",
|
||||
"custom_fields": {},
|
||||
"name": "sample-basic-test",
|
||||
"version_id": "main",
|
||||
"version_tags": [],
|
||||
},
|
||||
"hyperparameters": {
|
||||
"finetuning_type": "lora",
|
||||
"training_type": "sft",
|
||||
"batch_size": 16,
|
||||
"epochs": 2,
|
||||
"learning_rate": 0.0001,
|
||||
"lora": {"alpha": 16},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "created",
|
||||
"project": "default",
|
||||
"custom_fields": {},
|
||||
"ownership": {"created_by": "me", "access_policies": {}},
|
||||
}
|
||||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True):
|
||||
warnings.simplefilter("always")
|
||||
training_job = await adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
self.adapter = NvidiaPostTrainingAdapter(config)
|
||||
self.make_request_patcher = patch(
|
||||
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
|
||||
)
|
||||
self.mock_make_request = self.make_request_patcher.start()
|
||||
|
||||
# Mock the inference client
|
||||
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
|
||||
self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
|
||||
# check the output is a PostTrainingJob
|
||||
assert isinstance(training_job, NvidiaPostTrainingJob)
|
||||
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
self.mock_client = unittest.mock.MagicMock()
|
||||
self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
|
||||
self.inference_mock_make_request = self.mock_client.chat.completions.create
|
||||
self.inference_make_request_patcher = patch(
|
||||
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
|
||||
return_value=self.mock_client,
|
||||
)
|
||||
self.inference_make_request_patcher.start()
|
||||
|
||||
def tearDown(self):
|
||||
self.make_request_patcher.stop()
|
||||
self.inference_make_request_patcher.stop()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
|
||||
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
|
||||
"""Helper method to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
if expected_method and expected_path:
|
||||
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
|
||||
assert call_args[0] == (expected_method, expected_path)
|
||||
else:
|
||||
assert call_args[1]["method"] == expected_method
|
||||
assert call_args[1]["path"] == expected_path
|
||||
|
||||
if expected_params:
|
||||
assert call_args[1]["params"] == expected_params
|
||||
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
def test_supervised_fine_tune(self):
|
||||
"""Test the supervised fine-tuning API call."""
|
||||
self.mock_make_request.return_value = {
|
||||
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||
"created_at": "2024-12-09T04:06:28.542884",
|
||||
"updated_at": "2024-12-09T04:06:28.542884",
|
||||
"config": {
|
||||
"schema_version": "1.0",
|
||||
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.569837",
|
||||
"custom_fields": {},
|
||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model_path": "llama-3_1-8b-instruct",
|
||||
"training_types": [],
|
||||
"finetuning_types": ["lora"],
|
||||
"precision": "bf16",
|
||||
"num_gpus": 4,
|
||||
"num_nodes": 1,
|
||||
"micro_batch_size": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
"max_seq_length": 4096,
|
||||
},
|
||||
"dataset": {
|
||||
"schema_version": "1.0",
|
||||
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.542660",
|
||||
"custom_fields": {},
|
||||
"name": "sample-basic-test",
|
||||
"version_id": "main",
|
||||
"version_tags": [],
|
||||
},
|
||||
mock_make_request.assert_called_once()
|
||||
_assert_request(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
"/v1/customization/jobs",
|
||||
expected_json={
|
||||
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
"dataset": {"name": "sample-basic-test", "namespace": "default"},
|
||||
"hyperparameters": {
|
||||
"finetuning_type": "lora",
|
||||
"training_type": "sft",
|
||||
"batch_size": 16,
|
||||
"finetuning_type": "lora",
|
||||
"epochs": 2,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0001,
|
||||
"weight_decay": 0.01,
|
||||
"lora": {"alpha": 16},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "created",
|
||||
"project": "default",
|
||||
"custom_fields": {},
|
||||
"ownership": {"created_by": "me", "access_policies": {}},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def test_supervised_fine_tune_with_qat(nvidia_post_training_adapter):
|
||||
"""Test that QAT configuration raises NotImplementedError."""
|
||||
adapter, mock_make_request = nvidia_post_training_adapter
|
||||
|
||||
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
|
||||
# This will raise NotImplementedError since QAT is not supported
|
||||
with pytest.raises(NotImplementedError):
|
||||
await adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
|
||||
|
||||
async def test_get_training_job_status(nvidia_post_training_adapter):
|
||||
"""Test getting training job status with different statuses."""
|
||||
adapter, mock_make_request = nvidia_post_training_adapter
|
||||
|
||||
customizer_status_to_job_status = [
|
||||
("running", "in_progress"),
|
||||
("completed", "completed"),
|
||||
("failed", "failed"),
|
||||
("cancelled", "cancelled"),
|
||||
("pending", "scheduled"),
|
||||
("unknown", "scheduled"),
|
||||
]
|
||||
|
||||
for customizer_status, expected_status in customizer_status_to_job_status:
|
||||
mock_make_request.return_value = {
|
||||
"created_at": "2024-12-09T04:06:28.580220",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"status": customizer_status,
|
||||
"steps_completed": 1210,
|
||||
"epochs_completed": 2,
|
||||
"percentage_done": 100.0,
|
||||
"best_epoch": 2,
|
||||
"train_loss": 1.718016266822815,
|
||||
"val_loss": 1.8661999702453613,
|
||||
}
|
||||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True):
|
||||
warnings.simplefilter("always")
|
||||
training_job = self.run_async(
|
||||
self.adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
)
|
||||
|
||||
# check the output is a PostTrainingJob
|
||||
assert isinstance(training_job, NvidiaPostTrainingJob)
|
||||
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"POST",
|
||||
"/v1/customization/jobs",
|
||||
expected_json={
|
||||
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
"dataset": {"name": "sample-basic-test", "namespace": "default"},
|
||||
"hyperparameters": {
|
||||
"training_type": "sft",
|
||||
"finetuning_type": "lora",
|
||||
"epochs": 2,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0001,
|
||||
"weight_decay": 0.01,
|
||||
"lora": {"alpha": 16},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_supervised_fine_tune_with_qat(self):
|
||||
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
# This will raise NotImplementedError since QAT is not supported
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.run_async(
|
||||
self.adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
)
|
||||
|
||||
def test_get_training_job_status(self):
|
||||
customizer_status_to_job_status = [
|
||||
("running", "in_progress"),
|
||||
("completed", "completed"),
|
||||
("failed", "failed"),
|
||||
("cancelled", "cancelled"),
|
||||
("pending", "scheduled"),
|
||||
("unknown", "scheduled"),
|
||||
]
|
||||
|
||||
for customizer_status, expected_status in customizer_status_to_job_status:
|
||||
with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
|
||||
self.mock_make_request.return_value = {
|
||||
"created_at": "2024-12-09T04:06:28.580220",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"status": customizer_status,
|
||||
"steps_completed": 1210,
|
||||
"epochs_completed": 2,
|
||||
"percentage_done": 100.0,
|
||||
"best_epoch": 2,
|
||||
"train_loss": 1.718016266822815,
|
||||
"val_loss": 1.8661999702453613,
|
||||
}
|
||||
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
|
||||
|
||||
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
||||
assert status.status.value == expected_status
|
||||
assert status.steps_completed == 1210
|
||||
assert status.epochs_completed == 2
|
||||
assert status.percentage_done == 100.0
|
||||
assert status.best_epoch == 2
|
||||
assert status.train_loss == 1.718016266822815
|
||||
assert status.val_loss == 1.8661999702453613
|
||||
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"GET",
|
||||
f"/v1/customization/jobs/{job_id}/status",
|
||||
expected_params={"job_id": job_id},
|
||||
)
|
||||
|
||||
def test_get_training_jobs(self):
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
self.mock_make_request.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"id": job_id,
|
||||
"created_at": "2024-12-09T04:06:28.542884",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"config": {
|
||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
},
|
||||
"dataset": {"name": "default/sample-basic-test"},
|
||||
"hyperparameters": {
|
||||
"finetuning_type": "lora",
|
||||
"training_type": "sft",
|
||||
"batch_size": 16,
|
||||
"epochs": 2,
|
||||
"learning_rate": 0.0001,
|
||||
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "completed",
|
||||
"project": "default",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
jobs = self.run_async(self.adapter.get_training_jobs())
|
||||
status = await adapter.get_training_job_status(job_uuid=job_id)
|
||||
|
||||
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
|
||||
assert len(jobs.data) == 1
|
||||
job = jobs.data[0]
|
||||
assert job.job_uuid == job_id
|
||||
assert job.status.value == "completed"
|
||||
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
||||
assert status.status.value == expected_status
|
||||
# Note: The response object inherits extra fields via ConfigDict(extra="allow")
|
||||
# So these attributes should be accessible using getattr with defaults
|
||||
assert getattr(status, "steps_completed", None) == 1210
|
||||
assert getattr(status, "epochs_completed", None) == 2
|
||||
assert getattr(status, "percentage_done", None) == 100.0
|
||||
assert getattr(status, "best_epoch", None) == 2
|
||||
assert getattr(status, "train_loss", None) == 1.718016266822815
|
||||
assert getattr(status, "val_loss", None) == 1.8661999702453613
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
_assert_request(
|
||||
mock_make_request,
|
||||
"GET",
|
||||
"/v1/customization/jobs",
|
||||
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
|
||||
)
|
||||
|
||||
def test_cancel_training_job(self):
|
||||
self.mock_make_request.return_value = {} # Empty response for successful cancellation
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
|
||||
|
||||
assert result is None
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"POST",
|
||||
f"/v1/customization/jobs/{job_id}/cancel",
|
||||
f"/v1/customization/jobs/{job_id}/status",
|
||||
expected_params={"job_id": job_id},
|
||||
)
|
||||
|
||||
def test_inference_register_model(self):
|
||||
model_id = "default/job-1234"
|
||||
model_type = ModelType.llm
|
||||
model = Model(
|
||||
identifier=model_id,
|
||||
provider_id="nvidia",
|
||||
provider_model_id=model_id,
|
||||
provider_resource_id=model_id,
|
||||
model_type=model_type,
|
||||
)
|
||||
result = self.run_async(self.inference_adapter.register_model(model))
|
||||
assert result == model
|
||||
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
|
||||
assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
|
||||
|
||||
with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
|
||||
self.run_async(
|
||||
self.inference_adapter.chat_completion(
|
||||
model_id=model_id,
|
||||
messages=[{"role": "user", "content": "Hello, model"}],
|
||||
)
|
||||
)
|
||||
|
||||
mock_chat_completion.assert_called()
|
||||
mock_make_request.reset_mock()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
async def test_get_training_jobs(nvidia_post_training_adapter):
|
||||
"""Test getting list of training jobs."""
|
||||
adapter, mock_make_request = nvidia_post_training_adapter
|
||||
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
mock_make_request.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"id": job_id,
|
||||
"created_at": "2024-12-09T04:06:28.542884",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"config": {
|
||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
},
|
||||
"dataset": {"name": "default/sample-basic-test"},
|
||||
"hyperparameters": {
|
||||
"finetuning_type": "lora",
|
||||
"training_type": "sft",
|
||||
"batch_size": 16,
|
||||
"epochs": 2,
|
||||
"learning_rate": 0.0001,
|
||||
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "completed",
|
||||
"project": "default",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
jobs = await adapter.get_training_jobs()
|
||||
|
||||
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
|
||||
assert len(jobs.data) == 1
|
||||
job = jobs.data[0]
|
||||
assert job.job_uuid == job_id
|
||||
assert job.status.value == "completed"
|
||||
|
||||
mock_make_request.assert_called_once()
|
||||
_assert_request(
|
||||
mock_make_request,
|
||||
"GET",
|
||||
"/v1/customization/jobs",
|
||||
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
|
||||
)
|
||||
|
||||
|
||||
async def test_cancel_training_job(nvidia_post_training_adapter):
|
||||
"""Test canceling a training job."""
|
||||
adapter, mock_make_request = nvidia_post_training_adapter
|
||||
|
||||
mock_make_request.return_value = {} # Empty response for successful cancellation
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
result = await adapter.cancel_training_job(job_uuid=job_id)
|
||||
|
||||
assert result is None
|
||||
|
||||
mock_make_request.assert_called_once()
|
||||
_assert_request(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
f"/v1/customization/jobs/{job_id}/cancel",
|
||||
expected_params={"job_id": job_id},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@
|
|||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry, providable_apis
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.distribution import get_provider_registry, providable_apis
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
|
||||
|
||||
class TestProviderConfigurations:
|
||||
|
|
|
|||
|
|
@ -5,13 +5,18 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from llama_stack.apis.common.content_types import TextContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
|
|
@ -23,7 +28,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_message_to_openai_dict():
|
||||
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
|
||||
assert await convert_message_to_openai_dict(message) == {
|
||||
|
|
@ -33,7 +37,6 @@ async def test_convert_message_to_openai_dict():
|
|||
|
||||
|
||||
# Test convert_message_to_openai_dict with a tool call
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_message_to_openai_dict_with_tool_call():
|
||||
message = CompletionMessage(
|
||||
content="",
|
||||
|
|
@ -54,7 +57,6 @@ async def test_convert_message_to_openai_dict_with_tool_call():
|
|||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
|
||||
message = CompletionMessage(
|
||||
content="",
|
||||
|
|
@ -80,7 +82,6 @@ async def test_convert_message_to_openai_dict_with_builtin_tool_call():
|
|||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_messages_to_messages_with_content_str():
|
||||
openai_messages = [
|
||||
OpenAISystemMessageParam(content="system message"),
|
||||
|
|
@ -98,7 +99,6 @@ async def test_openai_messages_to_messages_with_content_str():
|
|||
assert llama_messages[2].content == "assistant message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_messages_to_messages_with_content_list():
|
||||
openai_messages = [
|
||||
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
|
||||
|
|
@ -114,3 +114,71 @@ async def test_openai_messages_to_messages_with_content_list():
|
|||
assert llama_messages[0].content[0].text == "system message"
|
||||
assert llama_messages[1].content[0].text == "user message"
|
||||
assert llama_messages[2].content[0].text == "assistant message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message_class,kwargs",
|
||||
[
|
||||
(OpenAISystemMessageParam, {}),
|
||||
(OpenAIAssistantMessageParam, {}),
|
||||
(OpenAIDeveloperMessageParam, {}),
|
||||
(OpenAIUserMessageParam, {}),
|
||||
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
|
||||
],
|
||||
)
|
||||
def test_message_accepts_text_string(message_class, kwargs):
|
||||
"""Test that messages accept string text content."""
|
||||
msg = message_class(content="Test message", **kwargs)
|
||||
assert msg.content == "Test message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message_class,kwargs",
|
||||
[
|
||||
(OpenAISystemMessageParam, {}),
|
||||
(OpenAIAssistantMessageParam, {}),
|
||||
(OpenAIDeveloperMessageParam, {}),
|
||||
(OpenAIUserMessageParam, {}),
|
||||
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
|
||||
],
|
||||
)
|
||||
def test_message_accepts_text_list(message_class, kwargs):
|
||||
"""Test that messages accept list of text content parts."""
|
||||
content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
|
||||
msg = message_class(content=content_list, **kwargs)
|
||||
assert len(msg.content) == 1
|
||||
assert msg.content[0].text == "Test message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message_class,kwargs",
|
||||
[
|
||||
(OpenAISystemMessageParam, {}),
|
||||
(OpenAIAssistantMessageParam, {}),
|
||||
(OpenAIDeveloperMessageParam, {}),
|
||||
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
|
||||
],
|
||||
)
|
||||
def test_message_rejects_images(message_class, kwargs):
|
||||
"""Test that system, assistant, developer, and tool messages reject image content."""
|
||||
with pytest.raises(ValidationError):
|
||||
message_class(
|
||||
content=[
|
||||
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
|
||||
],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def test_user_message_accepts_images():
|
||||
"""Test that user messages accept image content (unlike other message types)."""
|
||||
# List with images should work
|
||||
msg = OpenAIUserMessageParam(
|
||||
content=[
|
||||
OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
|
||||
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
|
||||
]
|
||||
)
|
||||
assert len(msg.content) == 2
|
||||
assert msg.content[0].text == "Describe this image:"
|
||||
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from llama_stack.apis.tools import RAGDocument
|
|||
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_url():
|
||||
"""Test extracting content from RAGDocument with URL content."""
|
||||
mock_url = URL(uri="https://example.com")
|
||||
|
|
@ -33,7 +32,6 @@ async def test_content_from_doc_with_url():
|
|||
mock_instance.get.assert_called_once_with(mock_url.uri)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_pdf_url():
|
||||
"""Test extracting content from RAGDocument with URL pointing to a PDF."""
|
||||
mock_url = URL(uri="https://example.com/document.pdf")
|
||||
|
|
@ -58,7 +56,6 @@ async def test_content_from_doc_with_pdf_url():
|
|||
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_data_url():
|
||||
"""Test extracting content from RAGDocument with data URL content."""
|
||||
data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded
|
||||
|
|
@ -74,7 +71,6 @@ async def test_content_from_doc_with_data_url():
|
|||
mock_content_from_data.assert_called_once_with(data_url)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_string():
|
||||
"""Test extracting content from RAGDocument with string content."""
|
||||
content_string = "This is plain text content"
|
||||
|
|
@ -85,7 +81,6 @@ async def test_content_from_doc_with_string():
|
|||
assert result == content_string
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_string_url():
|
||||
"""Test extracting content from RAGDocument with string URL content."""
|
||||
url_string = "https://example.com"
|
||||
|
|
@ -105,7 +100,6 @@ async def test_content_from_doc_with_string_url():
|
|||
mock_instance.get.assert_called_once_with(url_string)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_string_pdf_url():
|
||||
"""Test extracting content from RAGDocument with string URL pointing to a PDF."""
|
||||
url_string = "https://example.com/document.pdf"
|
||||
|
|
@ -130,7 +124,6 @@ async def test_content_from_doc_with_string_pdf_url():
|
|||
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_interleaved_content():
|
||||
"""Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit)."""
|
||||
interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")]
|
||||
|
|
|
|||
|
|
@ -87,18 +87,46 @@ def helper(known_provider_model: ProviderModelEntry, known_provider_model2: Prov
|
|||
return ModelRegistryHelper([known_provider_model, known_provider_model2])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
|
||||
"""Test helper that simulates a provider with dynamically available models."""
|
||||
|
||||
def __init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]):
|
||||
super().__init__(model_entries)
|
||||
self._available_models = available_models
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
return model in self._available_models
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dynamic_model() -> Model:
|
||||
"""A model that's not in static config but available dynamically."""
|
||||
return Model(
|
||||
provider_id="provider",
|
||||
identifier="dynamic-model",
|
||||
provider_resource_id="dynamic-provider-id",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def helper_with_dynamic_models(
|
||||
known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry, dynamic_model: Model
|
||||
) -> MockModelRegistryHelperWithDynamicModels:
|
||||
"""Helper that includes dynamically available models."""
|
||||
return MockModelRegistryHelperWithDynamicModels(
|
||||
[known_provider_model, known_provider_model2], [dynamic_model.provider_resource_id]
|
||||
)
|
||||
|
||||
|
||||
async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||
assert helper.get_provider_model_id(unknown_model.model_id) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
await helper.register_model(unknown_model)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
model = Model(
|
||||
provider_id=known_model.provider_id,
|
||||
|
|
@ -110,7 +138,6 @@ async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -
|
|||
assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
model = Model(
|
||||
provider_id=known_model.provider_id,
|
||||
|
|
@ -122,13 +149,11 @@ async def test_register_model_from_alias(helper: ModelRegistryHelper, known_mode
|
|||
assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
await helper.register_model(known_model)
|
||||
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_existing_different(
|
||||
helper: ModelRegistryHelper, known_model: Model, known_model2: Model
|
||||
) -> None:
|
||||
|
|
@ -137,27 +162,86 @@ async def test_register_model_existing_different(
|
|||
await helper.register_model(known_model)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
await helper.register_model(known_model) # duplicate entry
|
||||
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
|
||||
await helper.unregister_model(known_model.model_id)
|
||||
assert helper.get_provider_model_id(known_model.model_id) is None
|
||||
# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
|
||||
# async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
# await helper.register_model(known_model) # duplicate entry
|
||||
# assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
|
||||
# await helper.unregister_model(known_model.model_id)
|
||||
# assert helper.get_provider_model_id(known_model.model_id) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
await helper.unregister_model(unknown_model.model_id)
|
||||
# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
|
||||
# async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||
# with pytest.raises(ValueError):
|
||||
# await helper.unregister_model(unknown_model.model_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||
await helper.unregister_model(known_model.provider_resource_id)
|
||||
assert helper.get_provider_model_id(known_model.provider_resource_id) is None
|
||||
# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
|
||||
# async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
# assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||
# await helper.unregister_model(known_model.provider_resource_id)
|
||||
# assert helper.get_provider_model_id(known_model.provider_resource_id) is None
|
||||
|
||||
|
||||
async def test_register_model_from_check_model_availability(
|
||||
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
|
||||
) -> None:
|
||||
"""Test that models returned by check_model_availability can be registered."""
|
||||
# Verify the model is not in static config
|
||||
assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None
|
||||
|
||||
# But it should be available via check_model_availability
|
||||
is_available = await helper_with_dynamic_models.check_model_availability(dynamic_model.provider_resource_id)
|
||||
assert is_available
|
||||
|
||||
# Registration should succeed
|
||||
registered_model = await helper_with_dynamic_models.register_model(dynamic_model)
|
||||
assert registered_model == dynamic_model
|
||||
|
||||
# Model should now be registered and accessible
|
||||
assert (
|
||||
helper_with_dynamic_models.get_provider_model_id(dynamic_model.model_id) == dynamic_model.provider_resource_id
|
||||
)
|
||||
|
||||
|
||||
async def test_register_model_not_in_static_or_dynamic(
|
||||
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, unknown_model: Model
|
||||
) -> None:
|
||||
"""Test that models not in static config or dynamic models are rejected."""
|
||||
# Verify the model is not in static config
|
||||
assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None
|
||||
|
||||
# And not available via check_model_availability
|
||||
is_available = await helper_with_dynamic_models.check_model_availability(unknown_model.provider_resource_id)
|
||||
assert not is_available
|
||||
|
||||
# Registration should fail with comprehensive error message
|
||||
with pytest.raises(Exception) as exc_info: # UnsupportedModelError
|
||||
await helper_with_dynamic_models.register_model(unknown_model)
|
||||
|
||||
# Error should include static models and "..." for dynamic models
|
||||
error_str = str(exc_info.value)
|
||||
assert "..." in error_str # "..." should be in error message
|
||||
|
||||
|
||||
async def test_register_alias_for_dynamic_model(
|
||||
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
|
||||
) -> None:
|
||||
"""Test that we can register an alias that maps to a dynamically available model."""
|
||||
# Create a model with a different identifier but same provider_resource_id
|
||||
alias_model = Model(
|
||||
provider_id=dynamic_model.provider_id,
|
||||
identifier="dynamic-model-alias",
|
||||
provider_resource_id=dynamic_model.provider_resource_id,
|
||||
)
|
||||
|
||||
# Registration should succeed since the provider_resource_id is available dynamically
|
||||
registered_model = await helper_with_dynamic_models.register_model(alias_model)
|
||||
assert registered_model == alias_model
|
||||
|
||||
# Both the original provider_resource_id and the new alias should work
|
||||
assert helper_with_dynamic_models.get_provider_model_id(alias_model.model_id) == dynamic_model.provider_resource_id
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import pytest
|
|||
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scheduler_unknown_backend():
|
||||
with pytest.raises(ValueError):
|
||||
Scheduler(backend="unknown")
|
||||
|
|
@ -26,7 +25,6 @@ async def wait_for_job_completed(sched: Scheduler, job_id: str) -> None:
|
|||
raise TimeoutError(f"Job {job_id} did not complete in time.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scheduler_naive():
|
||||
sched = Scheduler()
|
||||
|
||||
|
|
@ -87,7 +85,6 @@ async def test_scheduler_naive():
|
|||
assert job.logs[0][0] < job.logs[1][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scheduler_naive_handler_raises():
|
||||
sched = Scheduler()
|
||||
|
||||
|
|
|
|||
|
|
@ -8,20 +8,32 @@ import random
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from chromadb import PersistentClient
|
||||
from pymilvus import MilvusClient, connections
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
|
||||
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
EMBEDDING_DIMENSION = 384
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
MILVUS_ALIAS = "test_milvus"
|
||||
|
||||
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
|
||||
def vector_provider(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_db_id() -> str:
|
||||
return f"test-vector-db-{random.randint(1, 100)}"
|
||||
|
|
@ -90,11 +102,6 @@ def sample_embeddings_with_metadata(sample_chunks_with_metadata):
|
|||
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
|
||||
|
||||
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec"])
|
||||
def vector_provider(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_inference_api(embedding_dimension):
|
||||
class MockInferenceAPI:
|
||||
|
|
@ -116,7 +123,7 @@ async def unique_kvstore_config(tmp_path_factory):
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def sqlite_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test.db")
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_sqlite_vec.db")
|
||||
return db_path
|
||||
|
||||
|
||||
|
|
@ -198,13 +205,145 @@ async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def faiss_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def faiss_vec_index(embedding_dimension):
|
||||
index = FaissIndex(embedding_dimension)
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = FaissVectorIOConfig(
|
||||
kvstore=unique_kvstore_config,
|
||||
)
|
||||
adapter = FaissVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chroma_vec_db_path(tmp_path_factory):
|
||||
persist_dir = tmp_path_factory.mktemp(f"chroma_{np.random.randint(1e6)}")
|
||||
return str(persist_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
|
||||
client = PersistentClient(path=chroma_vec_db_path)
|
||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||
collection = await maybe_await(client.get_or_create_collection(name))
|
||||
index = ChromaIndex(client=client, collection=collection)
|
||||
await index.initialize()
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
|
||||
config = ChromaVectorIOConfig(
|
||||
db_path=chroma_vec_db_path,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
adapter = ChromaVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=f"chroma_test_collection_{random.randint(1, 1_000_000)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_vec_db_path(tmp_path_factory):
|
||||
import uuid
|
||||
|
||||
db_path = str(tmp_path_factory.getbasetemp() / f"test_qdrant_{uuid.uuid4()}.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
|
||||
import uuid
|
||||
|
||||
config = QdrantVectorIOConfig(
|
||||
db_path=qdrant_vec_db_path,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
adapter = QdrantVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
adapter.test_collection_id = collection_id
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
|
||||
import uuid
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex
|
||||
|
||||
client = AsyncQdrantClient(path=qdrant_vec_db_path)
|
||||
collection_name = f"qdrant_test_collection_{uuid.uuid4()}"
|
||||
index = QdrantIndex(client, collection_name)
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_io_adapter(vector_provider, request):
|
||||
"""Returns the appropriate vector IO adapter based on the provider parameter."""
|
||||
if vector_provider == "milvus":
|
||||
return request.getfixturevalue("milvus_vec_adapter")
|
||||
else:
|
||||
return request.getfixturevalue("sqlite_vec_adapter")
|
||||
vector_provider_dict = {
|
||||
"milvus": "milvus_vec_adapter",
|
||||
"faiss": "faiss_vec_adapter",
|
||||
"sqlite_vec": "sqlite_vec_adapter",
|
||||
"chroma": "chroma_vec_adapter",
|
||||
"qdrant": "qdrant_vec_adapter",
|
||||
}
|
||||
return request.getfixturevalue(vector_provider_dict[vector_provider])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
326
tests/unit/providers/vector_io/remote/test_milvus.py
Normal file
326
tests/unit/providers/vector_io/remote/test_milvus.py
Normal file
|
|
@ -0,0 +1,326 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
|
||||
# Mock the entire pymilvus module
|
||||
pymilvus_mock = MagicMock()
|
||||
pymilvus_mock.DataType = MagicMock()
|
||||
pymilvus_mock.MilvusClient = MagicMock
|
||||
pymilvus_mock.RRFRanker = MagicMock
|
||||
pymilvus_mock.WeightedRanker = MagicMock
|
||||
pymilvus_mock.AnnSearchRequest = MagicMock
|
||||
|
||||
# Apply the mock before importing MilvusIndex
|
||||
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
||||
|
||||
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
# tests/integration/vector_io/
|
||||
#
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest tests/unit/providers/vector_io/test_milvus.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
MILVUS_PROVIDER = "milvus"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_milvus_client() -> MagicMock:
|
||||
"""Create a mock Milvus client with common method behaviors."""
|
||||
client = MagicMock()
|
||||
|
||||
# Mock collection operations
|
||||
client.has_collection.return_value = False # Initially no collection
|
||||
client.create_collection.return_value = None
|
||||
client.drop_collection.return_value = None
|
||||
|
||||
# Mock insert operation
|
||||
client.insert.return_value = {"insert_count": 10}
|
||||
|
||||
# Mock search operation - return mock results (data should be dict, not JSON string)
|
||||
client.search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Mock query operation for keyword search (data should be dict, not JSON string)
|
||||
client.query.return_value = [
|
||||
{
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||
"score": 0.9,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
|
||||
"score": 0.8,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk3",
|
||||
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
|
||||
"score": 0.7,
|
||||
},
|
||||
]
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_index(mock_milvus_client):
|
||||
"""Create a MilvusIndex with mocked client."""
|
||||
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
|
||||
yield index
|
||||
# No real cleanup needed since we're using mocks
|
||||
|
||||
|
||||
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
# Setup: collection doesn't exist initially, then exists after creation
|
||||
mock_milvus_client.has_collection.side_effect = [False, True]
|
||||
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Verify collection was created and data was inserted
|
||||
mock_milvus_client.create_collection.assert_called_once()
|
||||
mock_milvus_client.insert.assert_called_once()
|
||||
|
||||
# Verify the insert call had the right number of chunks
|
||||
insert_call = mock_milvus_client.insert.call_args
|
||||
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
||||
|
||||
|
||||
async def test_query_chunks_vector(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
# Setup: Add chunks first
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test vector search
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
mock_milvus_client.search.assert_called_once()
|
||||
|
||||
|
||||
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test keyword search
|
||||
query_string = "Sentence 5"
|
||||
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
|
||||
|
||||
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Force BM25 search to fail
|
||||
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
|
||||
|
||||
# Mock simple text search results
|
||||
mock_milvus_client.query.return_value = [
|
||||
{
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
|
||||
},
|
||||
]
|
||||
|
||||
# Test keyword search that should fall back to simple text search
|
||||
query_string = "Python"
|
||||
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
||||
|
||||
# Verify response structure
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) > 0, "Fallback search should return results"
|
||||
|
||||
# Verify that simple text search was used (query method called instead of search)
|
||||
mock_milvus_client.query.assert_called_once()
|
||||
mock_milvus_client.search.assert_called_once() # Called once but failed
|
||||
|
||||
# Verify the query uses parameterized filter with filter_params
|
||||
query_call_args = mock_milvus_client.query.call_args
|
||||
assert "filter" in query_call_args[1], "Query should include filter for text search"
|
||||
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
|
||||
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
|
||||
|
||||
# Verify all returned chunks have score 1.0 (simple binary scoring)
|
||||
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
|
||||
|
||||
|
||||
async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||
# Test collection deletion
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
|
||||
await milvus_index.delete()
|
||||
|
||||
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
||||
|
||||
|
||||
async def test_query_hybrid_search_rrf(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with RRF reranker."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with RRF reranker
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=2,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
assert len(response.scores) == 2
|
||||
|
||||
# Verify hybrid search was called with correct parameters
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
|
||||
# Check that the request contains both vector and BM25 search requests
|
||||
reqs = call_args[1]["reqs"]
|
||||
assert len(reqs) == 2
|
||||
assert reqs[0].anns_field == "vector"
|
||||
assert reqs[1].anns_field == "sparse"
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
||||
|
||||
async def test_query_hybrid_search_weighted(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with weighted reranker."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with weighted reranker
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=2,
|
||||
score_threshold=0.0,
|
||||
reranker_type="weighted",
|
||||
reranker_params={"alpha": 0.7},
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
assert len(response.scores) == 2
|
||||
|
||||
# Verify hybrid search was called with correct parameters
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
||||
|
||||
async def test_query_hybrid_search_default_rrf(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with default reranker (should be RRF)
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="unknown_type", # Should default to RRF
|
||||
reranker_params=None, # Should use default impact_factor
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 1
|
||||
|
||||
# Verify hybrid search was called with RRF reranker
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
|
@ -9,7 +9,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
||||
|
|
@ -91,13 +90,13 @@ def faiss_config():
|
|||
return config
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest.fixture
|
||||
async def faiss_index(embedding_dimension):
|
||||
index = await FaissIndex.create(dimension=embedding_dimension)
|
||||
yield index
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest.fixture
|
||||
async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
|
||||
# Create the adapter
|
||||
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
|
||||
|
|
@ -113,7 +112,6 @@ async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> Fai
|
|||
yield adapter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
|
||||
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
|
||||
):
|
||||
|
|
@ -136,7 +134,6 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
|
|||
assert response.chunks[1] == sample_chunks[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_success():
|
||||
"""Test that the health check returns OK status when faiss is working correctly."""
|
||||
# Create a fresh instance of FaissVectorIOAdapter for testing
|
||||
|
|
@ -160,7 +157,6 @@ async def test_health_success():
|
|||
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_failure():
|
||||
"""Test that the health check returns ERROR status when faiss encounters an error."""
|
||||
# Create a fresh instance of FaissVectorIOAdapter for testing
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from typing import Any
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
||||
from llama_stack.apis.vector_io import (
|
||||
|
|
@ -24,6 +23,7 @@ from llama_stack.providers.inline.vector_io.qdrant.config import (
|
|||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
|
||||
QdrantVectorIOAdapter,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
|
|
@ -37,7 +37,8 @@ from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
|
|||
|
||||
@pytest.fixture
|
||||
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
|
||||
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"))
|
||||
kvstore_config = SqliteKVStoreConfig(db_name=os.path.join(tmp_path, "test_kvstore.db"))
|
||||
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"), kvstore=kvstore_config)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
@ -51,6 +52,9 @@ def mock_vector_db(vector_db_id) -> MagicMock:
|
|||
mock_vector_db.embedding_model = "embedding_model"
|
||||
mock_vector_db.identifier = vector_db_id
|
||||
mock_vector_db.embedding_dimension = 384
|
||||
mock_vector_db.model_dump_json.return_value = (
|
||||
'{"identifier": "' + vector_db_id + '", "embedding_model": "embedding_model", "embedding_dimension": 384}'
|
||||
)
|
||||
return mock_vector_db
|
||||
|
||||
|
||||
|
|
@ -68,9 +72,9 @@ def mock_api_service(sample_embeddings):
|
|||
return mock_api_service
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest.fixture
|
||||
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
|
||||
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
|
||||
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
|
||||
adapter.vector_db_store = mock_vector_db_store
|
||||
await adapter.initialize()
|
||||
yield adapter
|
||||
|
|
@ -80,7 +84,6 @@ async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service,
|
|||
__QUERY = "Sample query"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
|
||||
async def test_qdrant_adapter_returns_expected_chunks(
|
||||
qdrant_adapter: QdrantVectorIOAdapter,
|
||||
|
|
@ -111,7 +114,6 @@ def _prepare_for_json(value: Any) -> str:
|
|||
|
||||
|
||||
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
|
||||
@pytest.mark.asyncio
|
||||
async def test_qdrant_register_and_unregister_vector_db(
|
||||
qdrant_adapter: QdrantVectorIOAdapter,
|
||||
mock_vector_db,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import asyncio
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
|
||||
|
|
@ -34,23 +33,21 @@ def loop():
|
|||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest.fixture
|
||||
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / "test_sqlite.db")
|
||||
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
|
||||
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123")
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)
|
||||
response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0)
|
||||
assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_string = "Sentence 5"
|
||||
|
|
@ -68,7 +65,6 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
|
|||
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
|
|
@ -90,7 +86,6 @@ async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embed
|
|||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
# Re-initialize with a clean index
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
|
@ -103,7 +98,6 @@ async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_i
|
|||
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
|
||||
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
|
||||
# Reduce batch size to force multiple batches for same document
|
||||
|
|
@ -116,7 +110,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
|
|||
cur = connection.cursor()
|
||||
|
||||
# Retrieve all chunk IDs to check for duplicates
|
||||
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
|
||||
cur.execute(f"SELECT id FROM [{sqlite_vec_index.metadata_table}]")
|
||||
chunk_ids = [row[0] for row in cur.fetchall()]
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
|
@ -134,7 +128,6 @@ async def sqlite_vec_adapter(sqlite_connection):
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
"""Test hybrid search when keyword search returns no matches - should still return vector results."""
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
|
@ -163,7 +156,6 @@ async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_c
|
|||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
"""Test hybrid search with a high score threshold."""
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
|
@ -185,7 +177,6 @@ async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chun
|
|||
assert len(response.chunks) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_different_embedding(
|
||||
sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
|
||||
):
|
||||
|
|
@ -211,7 +202,6 @@ async def test_query_chunks_hybrid_different_embedding(
|
|||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
"""Test that RRF properly combines rankings when documents appear in both search methods."""
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
|
@ -236,7 +226,6 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
|
|||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
|
|
@ -284,7 +273,6 @@ async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chun
|
|||
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
"""Test hybrid search with documents that appear in only one search method."""
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
|
@ -313,7 +301,6 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
|
|||
assert "document-2" in doc_ids # From keyword search
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_weighted_reranker_parametrization(
|
||||
sqlite_vec_index, sample_chunks, sample_embeddings
|
||||
):
|
||||
|
|
@ -369,7 +356,6 @@ async def test_query_chunks_hybrid_weighted_reranker_parametrization(
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
"""Test RRFReRanker with different impact factors."""
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
|
@ -401,7 +387,6 @@ async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_ch
|
|||
assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
|
|
@ -445,7 +430,6 @@ async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, s
|
|||
assert len(response.chunks) <= 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_tie_breaking(
|
||||
sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
|
||||
):
|
||||
|
|
|
|||
|
|
@ -25,12 +25,10 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREF
|
|||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_index(vector_index):
|
||||
await vector_index.initialize()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
|
||||
vector_index.delete()
|
||||
vector_index.initialize()
|
||||
|
|
@ -40,7 +38,6 @@ async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embed
|
|||
vector_index.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
|
||||
embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
|
||||
await vector_index.add_chunks(sample_chunks, embeddings)
|
||||
|
|
@ -54,7 +51,6 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
|
|||
assert len(contents) == len(set(contents))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
||||
key = f"{VECTOR_DBS_PREFIX}db1"
|
||||
dummy = VectorDB(
|
||||
|
|
@ -65,7 +61,6 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
|||
await vector_io_adapter.initialize()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
||||
await vector_io_adapter.initialize()
|
||||
dummy = VectorDB(
|
||||
|
|
@ -79,7 +74,6 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
|||
await vector_io_adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_and_unregister_vector_db(vector_io_adapter):
|
||||
unique_id = f"foo_db_{np.random.randint(1e6)}"
|
||||
dummy = VectorDB(
|
||||
|
|
@ -92,17 +86,19 @@ async def test_register_and_unregister_vector_db(vector_io_adapter):
|
|||
assert dummy.identifier not in vector_io_adapter.cache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_unregistered_raises(vector_io_adapter):
|
||||
async def test_query_unregistered_raises(vector_io_adapter, vector_provider):
|
||||
fake_emb = np.zeros(8, dtype=np.float32)
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
if vector_provider == "chroma":
|
||||
with pytest.raises(AttributeError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
else:
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
||||
fake_index = AsyncMock()
|
||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
|
||||
vector_io_adapter.cache["db1"] = fake_index
|
||||
|
||||
chunks = ["chunk1", "chunk2"]
|
||||
await vector_io_adapter.insert_chunks("db1", chunks)
|
||||
|
|
@ -110,7 +106,6 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
|||
fake_index.insert_chunks.assert_awaited_once_with(chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
|
||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||
|
||||
|
|
@ -118,11 +113,10 @@ async def test_insert_chunks_missing_db_raises(vector_io_adapter):
|
|||
await vector_io_adapter.insert_chunks("db_not_exist", [])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
|
||||
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
|
||||
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
|
||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
|
||||
vector_io_adapter.cache["db1"] = fake_index
|
||||
|
||||
response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1})
|
||||
|
||||
|
|
@ -130,7 +124,6 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
|
|||
assert response is expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_missing_db_raises(vector_io_adapter):
|
||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||
|
||||
|
|
@ -138,7 +131,6 @@ async def test_query_chunks_missing_db_raises(vector_io_adapter):
|
|||
await vector_io_adapter.query_chunks("db_missing", "q", None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_openai_vector_store(vector_io_adapter):
|
||||
store_id = "vs_1234"
|
||||
openai_vector_store = {
|
||||
|
|
@ -155,7 +147,6 @@ async def test_save_openai_vector_store(vector_io_adapter):
|
|||
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_openai_vector_store(vector_io_adapter):
|
||||
store_id = "vs_1234"
|
||||
openai_vector_store = {
|
||||
|
|
@ -172,7 +163,6 @@ async def test_update_openai_vector_store(vector_io_adapter):
|
|||
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_openai_vector_store(vector_io_adapter):
|
||||
store_id = "vs_1234"
|
||||
openai_vector_store = {
|
||||
|
|
@ -188,7 +178,6 @@ async def test_delete_openai_vector_store(vector_io_adapter):
|
|||
assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_openai_vector_stores(vector_io_adapter):
|
||||
store_id = "vs_1234"
|
||||
openai_vector_store = {
|
||||
|
|
@ -204,7 +193,6 @@ async def test_load_openai_vector_stores(vector_io_adapter):
|
|||
assert loaded_stores[store_id] == openai_vector_store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
|
||||
store_id = "vs_1234"
|
||||
file_id = "file_1234"
|
||||
|
|
@ -226,7 +214,6 @@ async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory
|
|||
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
|
||||
store_id = "vs_1234"
|
||||
file_id = "file_1234"
|
||||
|
|
@ -260,7 +247,6 @@ async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_facto
|
|||
assert loaded_contents != file_info
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory):
|
||||
store_id = "vs_1234"
|
||||
file_id = "file_1234"
|
||||
|
|
@ -284,7 +270,6 @@ async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_pat
|
|||
assert loaded_contents == file_contents
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory):
|
||||
store_id = "vs_1234"
|
||||
file_id = "file_1234"
|
||||
|
|
@ -305,5 +290,7 @@ async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, t
|
|||
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
||||
await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)
|
||||
|
||||
loaded_file_info = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id)
|
||||
assert loaded_file_info == {}
|
||||
loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
|
||||
assert loaded_contents == []
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
|
||||
from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
|
||||
# This test is a unit test for the chunk_utils.py helpers. This should only contain
|
||||
# tests which are specific to this file. More general (API-level) tests should be placed in
|
||||
Loading…
Add table
Add a link
Reference in a new issue