mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 05:38:38 +00:00
Merge branch 'main' into fix/embedding-model-type
This commit is contained in:
commit
309f06829c
59 changed files with 1005 additions and 339 deletions
|
@ -146,6 +146,20 @@ class VectorDBImpl(Impl):
|
|||
async def unregister_vector_db(self, vector_db_id: str):
|
||||
return vector_db_id
|
||||
|
||||
async def openai_create_vector_store(self, **kwargs):
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
|
||||
|
||||
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
|
||||
return VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name=kwargs.get("name", vector_store_id),
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
|
||||
|
||||
async def test_models_routing_table(cached_disk_dist_registry):
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
|
@ -247,17 +261,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
|||
)
|
||||
|
||||
# Register multiple vector databases and verify listing
|
||||
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
||||
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
||||
vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
||||
vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
|
||||
assert len(vector_dbs.data) == 2
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
assert "test-vectordb" in vector_db_ids
|
||||
assert "test-vectordb-2" in vector_db_ids
|
||||
assert vdb1.identifier in vector_db_ids
|
||||
assert vdb2.identifier in vector_db_ids
|
||||
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
||||
# Verify they have UUID-based identifiers
|
||||
assert vdb1.identifier.startswith("vs_")
|
||||
assert vdb2.identifier.startswith("vs_")
|
||||
|
||||
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
|
||||
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 0
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
# Unit tests for the routing tables vector_dbs
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI
|
|||
class VectorDBImpl(Impl):
|
||||
def __init__(self):
|
||||
super().__init__(Api.vector_io)
|
||||
self.vector_stores = {}
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB):
|
||||
return vector_db
|
||||
|
@ -114,8 +116,35 @@ class VectorDBImpl(Impl):
|
|||
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
|
||||
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name=None,
|
||||
embedding_model=None,
|
||||
embedding_dimension=None,
|
||||
provider_id=None,
|
||||
provider_vector_db_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
|
||||
vector_store = VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name=name or vector_store_id,
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
self.vector_stores[vector_store_id] = vector_store
|
||||
return vector_store
|
||||
|
||||
async def openai_list_vector_stores(self, **kwargs):
|
||||
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
|
||||
|
||||
return VectorStoreListResponse(
|
||||
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
|
||||
)
|
||||
|
||||
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
n = 10
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
|
@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
|||
)
|
||||
|
||||
# Register multiple vector databases and verify listing
|
||||
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
|
||||
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
|
||||
vdb_dict = {}
|
||||
for i in range(n):
|
||||
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
|
||||
assert len(vector_dbs.data) == 2
|
||||
assert len(vector_dbs.data) == len(vdb_dict)
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
assert "test-vectordb" in vector_db_ids
|
||||
assert "test-vectordb-2" in vector_db_ids
|
||||
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
||||
for k in vdb_dict:
|
||||
assert vdb_dict[k].identifier in vector_db_ids
|
||||
for k in vdb_dict:
|
||||
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 0
|
||||
|
||||
|
||||
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
|
||||
n = 10
|
||||
impl = VectorDBImpl()
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
vdb_dict = {}
|
||||
for i in range(n):
|
||||
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
|
||||
vector_stores = await impl.openai_list_vector_stores()
|
||||
vector_store_ids = {v.id for v in vector_stores.data}
|
||||
|
||||
assert vector_db_ids == vector_store_ids, (
|
||||
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
|
||||
)
|
||||
|
||||
for vector_store in vector_stores.data:
|
||||
vector_db = await table.get_vector_db(vector_store.id)
|
||||
assert vector_store.name == vector_db.vector_db_name, (
|
||||
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
|
||||
)
|
||||
|
||||
for vector_db_id in vector_db_ids:
|
||||
await table.unregister_vector_db(vector_db_id)
|
||||
|
||||
assert len((await table.list_vector_dbs()).data) == 0
|
||||
|
||||
|
||||
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
user_provided_id = "my-custom-vector-db"
|
||||
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
|
||||
|
||||
vector_stores = await impl.openai_list_vector_stores()
|
||||
assert len(vector_stores.data) == 1
|
||||
|
||||
vector_store = vector_stores.data[0]
|
||||
|
||||
assert vector_store.name == user_provided_id
|
||||
|
||||
assert vector_store.id.startswith("vs_")
|
||||
assert vector_store.id != user_provided_id
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 1
|
||||
assert vector_dbs.data[0].identifier == vector_store.id
|
||||
|
||||
await table.unregister_vector_db(vector_store.id)
|
||||
|
||||
|
||||
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
|
||||
|
@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registr
|
|||
|
||||
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
||||
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
||||
authorized_table = registered_vdb.identifier # Use the actual generated ID
|
||||
|
||||
# Authorized reader
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
|
@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_regis
|
|||
)
|
||||
|
||||
with request_provider_data_context({}, admin_user):
|
||||
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
||||
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
||||
vector_db_id = registered_vdb.identifier # Use the actual generated ID
|
||||
|
||||
read_methods = [
|
||||
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
|
||||
|
|
|
@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
|
|||
* test_validate_input_url_mismatch (negative)
|
||||
* test_validate_input_multiple_errors_per_request (negative)
|
||||
* test_validate_input_invalid_request_format (negative)
|
||||
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
|
||||
* test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
|
||||
* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
|
||||
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
|
||||
|
||||
The tests use temporary SQLite databases for isolation and mock external
|
||||
|
@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
|
|||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/completions",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
|
@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
|
|||
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
|
||||
"""Test _validate_input when file contains request with missing required parameters."""
|
||||
async def test_validate_input_missing_parameters_chat_completions(
|
||||
self, provider, param_name, param_path, error_code, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with missing required parameters for chat completions."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
|
@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
|
|||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,error_code,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
|
||||
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
|
||||
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
|
||||
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
|
||||
("model", "body.model", "invalid_request", "Model parameter is required"),
|
||||
("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters_completions(
|
||||
self, provider, param_name, param_path, error_code, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with missing required parameters for text completions."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {"model": "test-model", "prompt": "Hello"},
|
||||
}
|
||||
|
||||
# Remove the specific parameter being tested
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
del base_request[top_level][nested_param]
|
||||
else:
|
||||
del base_request[param_name]
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/completions",
|
||||
input_file_id=f"missing_{param_name}_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == error_code
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_validate_input_url_mismatch(self, provider):
|
||||
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
|
|
63
tests/unit/providers/inference/bedrock/test_config.py
Normal file
63
tests/unit/providers/inference/bedrock/test_config.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||
|
||||
|
||||
class TestBedrockBaseConfig:
|
||||
def test_defaults_work_without_env_vars(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
config = BedrockBaseConfig()
|
||||
|
||||
# Basic creds should be None
|
||||
assert config.aws_access_key_id is None
|
||||
assert config.aws_secret_access_key is None
|
||||
assert config.region_name is None
|
||||
|
||||
# Timeouts get defaults
|
||||
assert config.connect_timeout == 60.0
|
||||
assert config.read_timeout == 60.0
|
||||
assert config.session_ttl == 3600
|
||||
|
||||
def test_env_vars_get_picked_up(self):
|
||||
env_vars = {
|
||||
"AWS_ACCESS_KEY_ID": "AKIATEST123",
|
||||
"AWS_SECRET_ACCESS_KEY": "secret123",
|
||||
"AWS_DEFAULT_REGION": "us-west-2",
|
||||
"AWS_MAX_ATTEMPTS": "5",
|
||||
"AWS_RETRY_MODE": "adaptive",
|
||||
"AWS_CONNECT_TIMEOUT": "30",
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
config = BedrockBaseConfig()
|
||||
|
||||
assert config.aws_access_key_id == "AKIATEST123"
|
||||
assert config.aws_secret_access_key == "secret123"
|
||||
assert config.region_name == "us-west-2"
|
||||
assert config.total_max_attempts == 5
|
||||
assert config.retry_mode == "adaptive"
|
||||
assert config.connect_timeout == 30.0
|
||||
|
||||
def test_partial_env_setup(self):
|
||||
# Just setting one timeout var
|
||||
with patch.dict(os.environ, {"AWS_CONNECT_TIMEOUT": "120"}, clear=True):
|
||||
config = BedrockBaseConfig()
|
||||
|
||||
assert config.connect_timeout == 120.0
|
||||
assert config.read_timeout == 60.0 # still default
|
||||
assert config.aws_access_key_id is None
|
||||
|
||||
def test_bad_max_attempts_breaks(self):
|
||||
with patch.dict(os.environ, {"AWS_MAX_ATTEMPTS": "not_a_number"}, clear=True):
|
||||
try:
|
||||
BedrockBaseConfig()
|
||||
raise AssertionError("Should have failed on bad int conversion")
|
||||
except ValueError:
|
||||
pass # expected
|
|
@ -19,12 +19,16 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
|
|||
|
||||
class TestRagQuery:
|
||||
async def test_query_raises_on_empty_vector_db_ids(self):
|
||||
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
|
||||
rag_tool = MemoryToolRuntimeImpl(
|
||||
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
|
||||
|
||||
async def test_query_chunk_metadata_handling(self):
|
||||
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
|
||||
rag_tool = MemoryToolRuntimeImpl(
|
||||
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||
)
|
||||
content = "test query content"
|
||||
vector_db_ids = ["db1"]
|
||||
|
||||
|
|
|
@ -113,6 +113,15 @@ class TestTranslateException:
|
|||
assert result.status_code == 504
|
||||
assert result.detail == "Operation timed out: "
|
||||
|
||||
def test_translate_connection_error(self):
|
||||
"""Test that ConnectionError is translated to 502 HTTP status."""
|
||||
exc = ConnectionError("Failed to connect to MCP server at http://localhost:9999/sse: Connection refused")
|
||||
result = translate_exception(exc)
|
||||
|
||||
assert isinstance(result, HTTPException)
|
||||
assert result.status_code == 502
|
||||
assert result.detail == "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused"
|
||||
|
||||
def test_translate_not_implemented_error(self):
|
||||
"""Test that NotImplementedError is translated to 501 HTTP status."""
|
||||
exc = NotImplementedError("Not implemented")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue