diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 4d21010b9..38278d396 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -205,12 +205,10 @@ API responses, specify the adapter here.
 def remote_provider_spec(
     api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
 ) -> RemoteProviderSpec:
-    if api_dependencies is None:
-        api_dependencies = []
     return RemoteProviderSpec(
         api=api,
         provider_type=f"remote::{adapter.adapter_type}",
         config_class=adapter.config_class,
         adapter=adapter,
-        api_dependencies=api_dependencies,
+        api_dependencies=api_dependencies or [],
     )
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 43132da72..e7abde227 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -84,7 +84,7 @@ class MetaReferenceInferenceImpl(
     async def register_model(self, model: Model) -> Model:
         model = await self.model_registry_helper.register_model(model)
         if model.model_type == ModelType.embedding_model:
-            self._get_embedding_model(model.provider_resource_id)
+            self._load_sentence_transformer_model(model.provider_resource_id)
         return model

     async def completion(
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 18f4d8cd6..bd21698f1 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -48,7 +48,7 @@ class SentenceTransformersInferenceImpl(
         )

     async def register_model(self, model: Model) -> None:
-        _ = self._get_embedding_model(model.provider_resource_id)
+        _ = self._load_sentence_transformer_model(model.provider_resource_id)
         return model

     async def unregister_model(self, model_id: str) -> None:
@@ -63,7 +63,7 @@ class SentenceTransformersInferenceImpl(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncGenerator]:
-        raise NotImplementedError("Sentence transformers don't support completion")
+        raise ValueError("Sentence transformers don't support completion")

     async def chat_completion(
         self,
@@ -77,4 +77,4 @@ class SentenceTransformersInferenceImpl(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError("Sentence transformers don't support chat completion")
+        raise ValueError("Sentence transformers don't support chat completion")
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index b11cef22b..96cbcaa67 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -20,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (

 from llama_stack.apis.inference import *  # noqa: F403

+
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media


 model_aliases = [
@@ -452,7 +454,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model = await self.model_store.get_model(model_id)
         embeddings = []
         for content in contents:
-            input_text = str(content) if not isinstance(content, str) else content
+            assert not content_has_media(
+                content
+            ), "Bedrock does not support media for embeddings"
+            input_text = interleaved_text_media_as_str(content)
             input_body = {"inputText": input_text}
             body = json.dumps(input_body)
             response = self.client.invoke_model(
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index e4836ba7c..b0e93305e 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -12,7 +12,6 @@ from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from openai import OpenAI
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
@@ -29,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -105,9 +105,6 @@ class FireworksInferenceAdapter(
         fireworks_api_key = self._get_api_key()
         return Fireworks(api_key=fireworks_api_key)

-    def _get_openai_client(self) -> OpenAI:
-        return OpenAI(base_url=self.config.url, api_key=self._get_api_key())
-
     async def completion(
         self,
         model_id: str,
@@ -272,12 +269,16 @@
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

-        client = self._get_openai_client()
         kwargs = {}
         if model.metadata.get("embedding_dimensions"):
             kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
-        response = client.embeddings.create(
-            model=model.provider_resource_id, input=contents, **kwargs
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Fireworks does not support media for embeddings"
+        response = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
         )

         embeddings = [data.embedding for data in response.data]
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 3e335e854..1ba4ad599 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_image_media_to_url,
     request_has_media,
 )
@@ -323,8 +324,12 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Ollama does not support media for embeddings"
         response = await self.client.embed(
-            model=model.provider_resource_id, input=contents
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
         )
         embeddings = response["embeddings"]

diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 52079ef85..7cd798d16 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -254,8 +255,12 @@ class TogetherInferenceAdapter(
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Together does not support media for embeddings"
         r = self._get_client().embeddings.create(
-            model=model.provider_resource_id, input=contents
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
         )
         embeddings = [item.embedding for item in r.data]
         return EmbeddingsResponse(embeddings=embeddings)
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index ca3d24c6a..7ad5cef0f 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -206,10 +207,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model = await self.model_store.get_model(model_id)

         kwargs = {}
-        if model.metadata.get("embedding_dimensions"):
-            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert model.model_type == ModelType.embedding_model
+        assert model.metadata.get("embedding_dimensions")
+        kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "VLLM does not support media for embeddings"
         response = self.client.embeddings.create(
-            model=model.provider_resource_id, input=contents, **kwargs
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
         )

         embeddings = [data.embedding for data in response.data]
diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py
index a195e4f3a..54ebcd83a 100644
--- a/llama_stack/providers/tests/inference/conftest.py
+++ b/llama_stack/providers/tests/inference/conftest.py
@@ -84,24 +84,3 @@ def pytest_generate_tests(metafunc):
         ):
             fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
         metafunc.parametrize("inference_stack", fixtures, indirect=True)
-
-    if "embedding_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--embedding-model")
-        if not model:
-            raise ValueError(
-                "No embedding model specified. Please provide a valid embedding model."
-            )
-        params = [pytest.param(model, id="")]
-
-        metafunc.parametrize("embedding_model", params, indirect=True)
-
-    if "embedding_stack" in metafunc.fixturenames:
-        fixtures = INFERENCE_FIXTURES
-        if filtered_stacks := get_provider_fixture_overrides(
-            metafunc.config,
-            {
-                "inference": INFERENCE_FIXTURES,
-            },
-        ):
-            fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
-        metafunc.parametrize("embedding_stack", fixtures, indirect=True)
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index 20aa94cf0..e0d36124f 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -37,13 +37,6 @@ def inference_model(request):
     return request.config.getoption("--inference-model", None)


-@pytest.fixture(scope="session")
-def embedding_model(request):
-    if hasattr(request, "param"):
-        return request.param
-    return request.config.getoption("--embedding-model", None)
-
-
 @pytest.fixture(scope="session")
 def inference_remote() -> ProviderFixture:
     return remote_stack_fixture()
@@ -239,31 +232,21 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
-    test_stack = await construct_stack_for_test(
-        [Api.inference],
-        {"inference": inference_fixture.providers},
-        inference_fixture.provider_data,
-        models=[ModelInput(model_id=inference_model)],
-    )
+    model_type = ModelType.llm
+    metadata = {}
+    if os.getenv("EMBEDDING_DIMENSION"):
+        model_type = ModelType.embedding_model
+        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")

-    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
-
-
-@pytest_asyncio.fixture(scope="session")
-async def embedding_stack(request, embedding_model):
-    fixture_name = request.param
-    inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
         models=[
             ModelInput(
-                model_id=embedding_model,
-                model_type=ModelType.embedding_model,
-                metadata={
-                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
-                },
+                model_id=inference_model,
+                model_type=model_type,
+                metadata=metadata,
             )
         ],
     )
diff --git a/llama_stack/providers/tests/inference/test_embeddings.py b/llama_stack/providers/tests/inference/test_embeddings.py
index 03aee0c28..3502c6b20 100644
--- a/llama_stack/providers/tests/inference/test_embeddings.py
+++ b/llama_stack/providers/tests/inference/test_embeddings.py
@@ -14,15 +14,15 @@ from llama_stack.apis.inference import EmbeddingsResponse, ModelType

 class TestEmbeddings:
     @pytest.mark.asyncio
-    async def test_embeddings(self, embedding_model, embedding_stack):
-        inference_impl, models_impl = embedding_stack
-        model = await models_impl.get_model(embedding_model)
+    async def test_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)

         if model.model_type != ModelType.embedding_model:
             pytest.skip("This test is only applicable for embedding models")

         response = await inference_impl.embeddings(
-            model_id=embedding_model,
+            model_id=inference_model,
             contents=["Hello, world!"],
         )
         assert isinstance(response, EmbeddingsResponse)
@@ -35,9 +35,9 @@
         )

     @pytest.mark.asyncio
-    async def test_batch_embeddings(self, embedding_model, embedding_stack):
-        inference_impl, models_impl = embedding_stack
-        model = await models_impl.get_model(embedding_model)
+    async def test_batch_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)

         if model.model_type != ModelType.embedding_model:
             pytest.skip("This test is only applicable for embedding models")
@@ -45,7 +45,7 @@ class TestEmbeddings:
         texts = ["Hello, world!", "This is a test", "Testing embeddings"]

         response = await inference_impl.embeddings(
-            model_id=embedding_model,
+            model_id=inference_model,
             contents=texts,
         )

diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py
index 141b02398..b53f8cd32 100644
--- a/llama_stack/providers/utils/inference/embedding_mixin.py
+++ b/llama_stack/providers/utils/inference/embedding_mixin.py
@@ -26,11 +26,13 @@ class SentenceTransformerEmbeddingMixin:
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._get_embedding_model(model.provider_resource_id)
+        embedding_model = self._load_sentence_transformer_model(
+            model.provider_resource_id
+        )
         embeddings = embedding_model.encode(contents)
         return EmbeddingsResponse(embeddings=embeddings)

-    def _get_embedding_model(self, model: str) -> "SentenceTransformer":
+    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS

         loaded_model = EMBEDDING_MODELS.get(model)