diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 639e5ee73..5790c498b 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -230,12 +230,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if Api.telemetry in self.impls: setup_logger(self.impls[Api.telemetry]) - console = Console() - console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") - - # Redact sensitive information before printing - safe_config = redact_sensitive_fields(self.config.model_dump()) - console.print(yaml.dump(safe_config, indent=2)) + if not os.environ.get("PYTEST_CURRENT_TEST"): + console = Console() + console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") + safe_config = redact_sensitive_fields(self.config.model_dump()) + console.print(yaml.dump(safe_config, indent=2)) endpoints = get_all_api_endpoints() endpoint_impls = {} diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 32aa5da3f..ac421475f 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -14,6 +14,7 @@ from llama_stack.apis.inference import ( ModelStore, TextTruncation, ) +from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str EMBEDDING_MODELS = {} @@ -34,7 +35,7 @@ class SentenceTransformerEmbeddingMixin: ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embedding_model = self._load_sentence_transformer_model(model.provider_resource_id) - embeddings = embedding_model.encode(contents) + embeddings = embedding_model.encode([interleaved_content_as_str(content) for content in contents]) return EmbeddingsResponse(embeddings=embeddings) def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": diff --git a/tests/client-sdk/vector_io/conftest.py b/tests/client-sdk/vector_io/conftest.py deleted file mode 100644 index 64cac27d2..000000000 --- a/tests/client-sdk/vector_io/conftest.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -def pytest_addoption(parser): - parser.addoption( - "--embedding-model", - action="store", - default="all-MiniLM-L6-v2", - help="Specify the embedding model to use for testing", - ) - - -def pytest_generate_tests(metafunc): - if "embedding_model" in metafunc.fixturenames: - metafunc.parametrize( - "embedding_model", - [metafunc.config.getoption("--embedding-model")], - ) diff --git a/tests/client-sdk/vector_io/test_vector_io.py b/tests/client-sdk/vector_io/test_vector_io.py index c7e4040b6..e093548b5 100644 --- a/tests/client-sdk/vector_io/test_vector_io.py +++ b/tests/client-sdk/vector_io/test_vector_io.py @@ -36,12 +36,12 @@ def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry @pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS) -def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id): +def test_vector_db_retrieve(llama_stack_client, embedding_model_id, empty_vector_db_registry, provider_id): # Register a memory bank first vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, - embedding_model=embedding_model, + embedding_model=embedding_model_id, embedding_dimension=384, provider_id=provider_id, ) @@ -50,7 +50,7 @@ def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db response = llama_stack_client.vector_dbs.retrieve(vector_db_id=vector_db_id) assert response is not None assert response.identifier == vector_db_id - assert response.embedding_model == embedding_model + assert response.embedding_model == embedding_model_id assert response.provider_id == provider_id assert response.provider_resource_id == vector_db_id @@ -61,11 +61,11 @@ def test_vector_db_list(llama_stack_client, empty_vector_db_registry): @pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS) -def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id): +def test_vector_db_register(llama_stack_client, embedding_model_id, empty_vector_db_registry, provider_id): vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, - embedding_model=embedding_model, + embedding_model=embedding_model_id, embedding_dimension=384, provider_id=provider_id, )