fix fireworks openai_embeddings

This commit is contained in:
Swapna Lekkala 2025-09-18 13:41:59 -07:00
parent dfbc61fb67
commit 0f5bef893a
2 changed files with 2 additions and 12 deletions

View file

@ -26,7 +26,6 @@ from llama_stack.apis.inference import (
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenAICompletion, OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam, OpenAIMessageParam,
OpenAIResponseFormatParam, OpenAIResponseFormatParam,
ResponseFormat, ResponseFormat,
@ -288,16 +287,6 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
embeddings = [data.embedding for data in response.data] embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings) return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion( async def openai_completion(
self, self,
model: str, model: str,

View file

@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
provider = provider_from_model(client, model_id) provider = provider_from_model(client, model_id)
if provider.provider_type in ( if provider.provider_type in (
"remote::together", # service returns 400 "remote::together", # service returns 400
"remote::fireworks", # service returns 400 malformed input
): ):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.") pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.")
@ -41,6 +42,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
provider = provider_from_model(client, model_id) provider = provider_from_model(client, model_id)
if provider.provider_type in ( if provider.provider_type in (
"remote::together", # param silently ignored, always returns floats "remote::together", # param silently ignored, always returns floats
"remote::fireworks", # param silently ignored, always returns list of floats
): ):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.") pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")
@ -287,7 +289,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
input=input_texts, input=input_texts,
encoding_format="base64", encoding_format="base64",
) )
# Validate response structure # Validate response structure
assert response.object == "list" assert response.object == "list"
assert response.model == embedding_model_id assert response.model == embedding_model_id