mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
feat: Adding optional embeddings to content
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
97ccfb5e62
commit
aefbb6f9ea
20 changed files with 1314 additions and 132 deletions
|
|
@ -55,3 +55,65 @@ async def test_create_vector_stores_multiple_providers_missing_provider_id_error
|
|||
|
||||
with pytest.raises(ValueError, match="Multiple vector_io providers available"):
|
||||
await router.openai_create_vector_store(request)
|
||||
|
||||
|
||||
async def test_update_vector_store_provider_id_change_fails():
|
||||
"""Test that updating a vector store with a different provider_id fails with clear error."""
|
||||
mock_routing_table = Mock()
|
||||
|
||||
# Mock an existing vector store with provider_id "faiss"
|
||||
mock_existing_store = Mock()
|
||||
mock_existing_store.provider_id = "inline::faiss"
|
||||
mock_existing_store.identifier = "vs_123"
|
||||
|
||||
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
|
||||
mock_routing_table.get_provider_impl = AsyncMock(
|
||||
return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
|
||||
)
|
||||
|
||||
router = VectorIORouter(mock_routing_table)
|
||||
|
||||
# Try to update with different provider_id in metadata - this should fail
|
||||
with pytest.raises(ValueError, match="provider_id cannot be changed after vector store creation"):
|
||||
await router.openai_update_vector_store(
|
||||
vector_store_id="vs_123",
|
||||
name="updated_name",
|
||||
metadata={"provider_id": "inline::sqlite"}, # Different provider_id
|
||||
)
|
||||
|
||||
# Verify the existing store was looked up to check provider_id
|
||||
mock_routing_table.get_object_by_identifier.assert_called_once_with("vector_store", "vs_123")
|
||||
|
||||
# Provider should not be called since validation failed
|
||||
mock_routing_table.get_provider_impl.assert_not_called()
|
||||
|
||||
|
||||
async def test_update_vector_store_same_provider_id_succeeds():
|
||||
"""Test that updating a vector store with the same provider_id succeeds."""
|
||||
mock_routing_table = Mock()
|
||||
|
||||
# Mock an existing vector store with provider_id "faiss"
|
||||
mock_existing_store = Mock()
|
||||
mock_existing_store.provider_id = "inline::faiss"
|
||||
mock_existing_store.identifier = "vs_123"
|
||||
|
||||
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
|
||||
mock_routing_table.get_provider_impl = AsyncMock(
|
||||
return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
|
||||
)
|
||||
|
||||
router = VectorIORouter(mock_routing_table)
|
||||
|
||||
# Update with same provider_id should succeed
|
||||
await router.openai_update_vector_store(
|
||||
vector_store_id="vs_123",
|
||||
name="updated_name",
|
||||
metadata={"provider_id": "inline::faiss"}, # Same provider_id
|
||||
)
|
||||
|
||||
# Verify the provider update method was called
|
||||
mock_routing_table.get_provider_impl.assert_called_once_with("vs_123")
|
||||
provider = await mock_routing_table.get_provider_impl("vs_123")
|
||||
provider.openai_update_vector_store.assert_called_once_with(
|
||||
vector_store_id="vs_123", name="updated_name", expires_after=None, metadata={"provider_id": "inline::faiss"}
|
||||
)
|
||||
|
|
|
|||
86
tests/unit/server/test_query_params_middleware.py
Normal file
86
tests/unit/server/test_query_params_middleware.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from llama_stack.core.server.query_params_middleware import QueryParamsMiddleware
|
||||
|
||||
|
||||
class TestQueryParamsMiddleware:
|
||||
"""Test cases for the QueryParamsMiddleware."""
|
||||
|
||||
async def test_extracts_query_params_for_vector_store_content(self):
|
||||
"""Test that middleware extracts query params for vector store content endpoints."""
|
||||
middleware = QueryParamsMiddleware(Mock())
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
|
||||
# Mock the URL properly
|
||||
mock_url = Mock()
|
||||
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
|
||||
request.url = mock_url
|
||||
|
||||
request.query_params = {"include_embeddings": "true", "include_metadata": "false"}
|
||||
|
||||
# Create a fresh state object without any attributes
|
||||
class MockState:
|
||||
pass
|
||||
|
||||
request.state = MockState()
|
||||
|
||||
await middleware.dispatch(request, AsyncMock())
|
||||
|
||||
assert hasattr(request.state, "extra_query")
|
||||
assert request.state.extra_query == {"include_embeddings": True, "include_metadata": False}
|
||||
|
||||
async def test_ignores_non_vector_store_endpoints(self):
|
||||
"""Test that middleware ignores non-vector store endpoints."""
|
||||
middleware = QueryParamsMiddleware(Mock())
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
|
||||
# Mock the URL properly
|
||||
mock_url = Mock()
|
||||
mock_url.path = "/v1/inference/chat_completion"
|
||||
request.url = mock_url
|
||||
|
||||
request.query_params = {"include_embeddings": "true"}
|
||||
|
||||
# Create a fresh state object without any attributes
|
||||
class MockState:
|
||||
pass
|
||||
|
||||
request.state = MockState()
|
||||
|
||||
await middleware.dispatch(request, AsyncMock())
|
||||
|
||||
assert not hasattr(request.state, "extra_query")
|
||||
|
||||
async def test_handles_json_parsing(self):
|
||||
"""Test that middleware correctly parses JSON values and handles invalid JSON."""
|
||||
middleware = QueryParamsMiddleware(Mock())
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
|
||||
# Mock the URL properly
|
||||
mock_url = Mock()
|
||||
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
|
||||
request.url = mock_url
|
||||
|
||||
request.query_params = {"config": '{"key": "value"}', "invalid": "not-json{", "number": "42"}
|
||||
|
||||
# Create a fresh state object without any attributes
|
||||
class MockState:
|
||||
pass
|
||||
|
||||
request.state = MockState()
|
||||
|
||||
await middleware.dispatch(request, AsyncMock())
|
||||
|
||||
expected = {"config": {"key": "value"}, "invalid": "not-json{", "number": 42}
|
||||
assert request.state.extra_query == expected
|
||||
|
|
@ -104,12 +104,18 @@ async def test_paginated_response_url_setting():
|
|||
|
||||
route_handler = create_dynamic_typed_route(mock_api_method, "get", "/test/route")
|
||||
|
||||
# Mock minimal request
|
||||
# Mock minimal request with proper state object
|
||||
request = MagicMock()
|
||||
request.scope = {"user_attributes": {}, "principal": ""}
|
||||
request.headers = {}
|
||||
request.body = AsyncMock(return_value=b"")
|
||||
|
||||
# Create a simple state object without auto-generating attributes
|
||||
class MockState:
|
||||
pass
|
||||
|
||||
request.state = MockState()
|
||||
|
||||
result = await route_handler(request)
|
||||
|
||||
assert isinstance(result, PaginatedResponse)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue