mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00

commit 0f5bef893a (parent dfbc61fb67)

    fix fireworks openai_embeddings

2 changed files with 2 additions and 12 deletions
@@ -26,7 +26,6 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     ResponseFormat,
@@ -288,16 +287,6 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
 
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-
     async def openai_completion(
         self,
         model: str,
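Deleting the stub is the whole fix on the adapter side: the hunk header shows that `FireworksInferenceAdapter` inherits from `OpenAIMixin`, so once the class stops overriding `openai_embeddings` with `raise NotImplementedError()`, Python's method resolution order falls through to the mixin's working implementation. A minimal sketch of that mechanism, where the mixin body is a stand-in rather than llama-stack's real code:

```python
import asyncio

# Stand-in for llama-stack's OpenAIMixin; the real one calls the
# provider's OpenAI-compatible /embeddings endpoint.
class OpenAIMixin:
    async def openai_embeddings(self, model: str, input: str | list[str]) -> str:
        return f"embeddings for {input!r} from {model}"

class BeforeFix(OpenAIMixin):
    # The old adapter shadowed the mixin method with a stub.
    async def openai_embeddings(self, model: str, input: str | list[str]) -> str:
        raise NotImplementedError()

class AfterFix(OpenAIMixin):
    # No override: the MRO now resolves openai_embeddings to OpenAIMixin's.
    pass

print(asyncio.run(AfterFix().openai_embeddings("some/model", "hi")))  # works
# asyncio.run(BeforeFix().openai_embeddings("some/model", "hi"))  # raised NotImplementedError
```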
@@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
     provider = provider_from_model(client, model_id)
     if provider.provider_type in (
         "remote::together",  # service returns 400
+        "remote::fireworks",  # service returns 400 malformed input
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.")
 
@@ -41,6 +42,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
     provider = provider_from_model(client, model_id)
     if provider.provider_type in (
         "remote::together",  # param silently ignored, always returns floats
+        "remote::fireworks",  # param silently ignored, always returns list of floats
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")
 
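Both helpers follow the same pattern: a test calls them first and bails out early for providers with known service-side quirks. An illustrative skeleton of how the `user`-param helper might be consumed (the test name and body are a sketch, though the fixture names match those visible in the last hunk below):

```python
import pytest

# Illustrative skeleton, not the repository's exact test: it shows how the
# skip helpers above gate a test against providers with known quirks.
def test_openai_embeddings_with_user_param(compat_client, client_with_models, embedding_model_id):
    skip_if_model_doesnt_support_user_param(client_with_models, embedding_model_id)
    response = compat_client.embeddings.create(
        model=embedding_model_id,
        input="Hello, world!",
        user="test-user",  # the parameter Fireworks rejects with a 400
    )
    assert response.data
```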
@@ -287,7 +289,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
         input=input_texts,
         encoding_format="base64",
     )
-
     # Validate response structure
     assert response.object == "list"
     assert response.model == embedding_model_id
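The base64 test that loses Fireworks coverage exercises the round trip in which each returned `embedding` is a base64 string of packed little-endian float32 values, the convention the OpenAI embeddings API uses. A self-contained sketch of the decode step, assuming that packing:

```python
import base64
import struct

def decode_base64_embedding(b64: str) -> list[float]:
    """Decode a base64 embedding into float values.

    Assumes little-endian float32 packing, the OpenAI embeddings convention.
    """
    raw = base64.b64decode(b64)
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))

# Round-trip check on a tiny 2-dimensional vector.
packed = struct.pack("<2f", 0.5, -1.25)
assert decode_base64_embedding(base64.b64encode(packed).decode()) == [0.5, -1.25]
```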