From 0f5bef893a3c04879c95637cb0cbeb77430bbeb5 Mon Sep 17 00:00:00 2001 From: Swapna Lekkala Date: Thu, 18 Sep 2025 13:41:59 -0700 Subject: [PATCH] fix fireworks openai_embeddings --- .../providers/remote/inference/fireworks/fireworks.py | 11 ----------- tests/integration/inference/test_openai_embeddings.py | 3 ++- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 314e5c390..00f3f5418 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -26,7 +26,6 @@ from llama_stack.apis.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, - OpenAIEmbeddingsResponse, OpenAIMessageParam, OpenAIResponseFormatParam, ResponseFormat, @@ -288,16 +287,6 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee embeddings = [data.embedding for data in response.data] return EmbeddingsResponse(embeddings=embeddings) - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - async def openai_completion( self, model: str, diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py index 622b97287..ce3d2a8ea 100644 --- a/tests/integration/inference/test_openai_embeddings.py +++ b/tests/integration/inference/test_openai_embeddings.py @@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id): provider = provider_from_model(client, model_id) if provider.provider_type in ( "remote::together", # service returns 400 + "remote::fireworks", # service returns 400 malformed input ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.") @@ -41,6 +42,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id): provider = provider_from_model(client, model_id) if provider.provider_type in ( "remote::together", # param silently ignored, always returns floats + "remote::fireworks", # param silently ignored, always returns list of floats ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.") @@ -287,7 +289,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo input=input_texts, encoding_format="base64", ) - # Validate response structure assert response.object == "list" assert response.model == embedding_model_id