mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
chore: Refactor fireworks to use OpenAIMixin (#3480)
Some checks failed
Python Package Build Test / build (3.12) (push) Failing after 2s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 4s
Python Package Build Test / build (3.13) (push) Failing after 2s
API Conformance Tests / check-schema-compatibility (push) Successful in 6s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 38s
Pre-commit / pre-commit (push) Successful in 1m17s
Some checks failed
Python Package Build Test / build (3.12) (push) Failing after 2s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 4s
Python Package Build Test / build (3.13) (push) Failing after 2s
API Conformance Tests / check-schema-compatibility (push) Successful in 6s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 38s
Pre-commit / pre-commit (push) Successful in 1m17s
# What does this PR do? Refactor Fireworks to use OpenAIMixin Closes https://github.com/llamastack/llama-stack/issues/3391 Related to https://github.com/llamastack/llama-stack/issues/3387 ## Test Plan ``` (llama-stack) (base) swapna942@swapna942-mac llama-stack % FIREWORKS_API_KEY=**** ./scripts/integration-tests.sh --stack-config server:ci-tests --setup fireworks --subdirs inference --pattern openai tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_single_string[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] instantiating llama_stack_client Port 8321 is already in use, assuming server is already running... llama_stack_client instantiated in 0.031s PASSED [ 2%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_multiple_strings[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 4%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_float[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 6%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_dimensions[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 8%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_user_parameter[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 10%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_empty_list_error[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 12%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_invalid_model_error[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 14%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_different_inputs_different_outputs[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 17%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_base64[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 19%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_base64_batch_processing[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 21%] tests/integration/inference/test_openai_completion.py::test_openai_completion_non_streaming[txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:completion:sanity] PASSED [ 23%] tests/integration/inference/test_openai_completion.py::test_openai_completion_non_streaming_suffix[txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:completion:suffix] SKIPPED [ 25%] tests/integration/inference/test_openai_completion.py::test_openai_completion_streaming[txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:completion:sanity] PASSED [ 27%] tests/integration/inference/test_openai_completion.py::test_openai_completion_prompt_logprobs[txt=accounts/fireworks/models/llama-v3p1-8b-instruct-1] SKIPPED [ 29%] tests/integration/inference/test_openai_completion.py::test_openai_completion_guided_choice[txt=accounts/fireworks/models/llama-v3p1-8b-instruct] SKIPPED [ 31%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:non_streaming_01] PASSED [ 34%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_01] PASSED [ 36%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_01] PASSED [ 38%] tests/integration/inference/test_openai_completion.py::test_inference_store[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-True] PASSED [ 40%] tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-True] PASSED [ 42%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming_with_file[txt=accounts/fireworks/models/llama-v3p1-8b-instruct] SKIPPED [ 44%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_single_string[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 46%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_multiple_strings[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 48%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_float[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 51%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_dimensions[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 53%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_user_parameter[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 55%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_empty_list_error[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 57%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_invalid_model_error[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 59%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_different_inputs_different_outputs[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] PASSED [ 61%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_base64[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 63%] tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_base64_batch_processing[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] SKIPPED [ 65%] tests/integration/inference/test_openai_completion.py::test_openai_completion_prompt_logprobs[txt=accounts/fireworks/models/llama-v3p1-8b-instruct-0] SKIPPED [ 68%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:non_streaming_02] PASSED [ 70%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_02] PASSED [ 72%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_02] PASSED [ 74%] tests/integration/inference/test_openai_completion.py::test_inference_store[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-False] PASSED [ 76%] tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-False] PASSED [ 78%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:non_streaming_01] PASSED [ 80%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_01] PASSED [ 82%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_01] PASSED [ 85%] tests/integration/inference/test_openai_completion.py::test_inference_store[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-True] PASSED [ 87%] tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-True] PASSED [ 89%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:non_streaming_02] PASSED [ 91%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_02] PASSED [ 93%] tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:streaming_02] PASSED [ 95%] tests/integration/inference/test_openai_completion.py::test_inference_store[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-False] PASSED [ 97%] tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-False] PASSED [100%] ========================================== slowest 10 durations ========================================== 30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_multiple_strings[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] 30.01s teardown tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-False] 30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_different_inputs_different_outputs[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] 30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_user_parameter[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] 30.01s teardown tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-True] 30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_different_inputs_different_outputs[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] 30.01s teardown tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[openai_client-txt=accounts/fireworks/models/llama-v3p1-8b-instruct-inference:chat_completion:non_streaming_02] 30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_single_string[llama_stack_client-emb=nomic-ai/nomic-embed-text-v1.5] 30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_base64_batch_processing[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] 30.01s teardown tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_invalid_model_error[openai_client-emb=nomic-ai/nomic-embed-text-v1.5] ================= 36 passed, 11 skipped, 50 deselected, 4 warnings in 1429.05s (0:23:49) ================= + exit_code=0 + set +x ✅ All tests completed successfully ```
This commit is contained in:
parent
e3fd70c321
commit
8d8261961e
3 changed files with 22 additions and 168 deletions
|
@ -4,11 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@ -24,12 +22,6 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
|
@ -45,15 +37,14 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
|
@ -68,7 +59,7 @@ from .models import MODEL_ENTRIES
|
|||
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
self.config = config
|
||||
|
@ -79,7 +70,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
def get_api_key(self) -> str:
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
return config_api_key
|
||||
|
@ -91,15 +82,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
)
|
||||
return provider_data.fireworks_api_key
|
||||
|
||||
def _get_base_url(self) -> str:
|
||||
def get_base_url(self) -> str:
|
||||
return "https://api.fireworks.ai/inference/v1"
|
||||
|
||||
def _get_client(self) -> Fireworks:
|
||||
fireworks_api_key = self._get_api_key()
|
||||
fireworks_api_key = self.get_api_key()
|
||||
return Fireworks(api_key=fireworks_api_key)
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
|
||||
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
|
||||
"""Remove BOS token as Fireworks automatically prepends it"""
|
||||
if prompt.startswith("<|begin_of_text|>"):
|
||||
return prompt[len("<|begin_of_text|>") :]
|
||||
return prompt
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
@ -285,153 +279,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
|
||||
prompt = prompt[len("<|begin_of_text|>") :]
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return await self._get_openai_client().completions.create(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Divert Llama Models through Llama Stack inference APIs because
|
||||
# Fireworks chat completions OpenAI-compatible API does not support
|
||||
# tool calls properly.
|
||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||
|
||||
if llama_model:
|
||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
|
||||
self,
|
||||
model=model,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
logger.debug(f"fireworks params: {params}")
|
||||
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
|
||||
|
|
|
@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
|
|||
provider = provider_from_model(client, model_id)
|
||||
if provider.provider_type in (
|
||||
"remote::together", # service returns 400
|
||||
"remote::fireworks", # service returns 400 malformed input
|
||||
):
|
||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.")
|
||||
|
||||
|
@ -41,6 +42,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
|
|||
provider = provider_from_model(client, model_id)
|
||||
if provider.provider_type in (
|
||||
"remote::together", # param silently ignored, always returns floats
|
||||
"remote::fireworks", # param silently ignored, always returns list of floats
|
||||
):
|
||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")
|
||||
|
||||
|
@ -287,7 +289,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
|
|||
input=input_texts,
|
||||
encoding_format="base64",
|
||||
)
|
||||
|
||||
# Validate response structure
|
||||
assert response.object == "list"
|
||||
assert response.model == embedding_model_id
|
||||
|
|
|
@ -108,6 +108,15 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
|
|||
"embedding_model": "together/togethercomputer/m2-bert-80M-32k-retrieval",
|
||||
},
|
||||
),
|
||||
"fireworks": Setup(
|
||||
name="fireworks",
|
||||
description="Fireworks provider with a text model",
|
||||
defaults={
|
||||
"text_model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||
"vision_model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
|
||||
"embedding_model": "nomic-ai/nomic-embed-text-v1.5",
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue