From 602e949a4612423ec96f2dde0b5bc627cee45fbe Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Mon, 21 Apr 2025 14:49:12 -0400
Subject: [PATCH] fix: OpenAI Completions API and Fireworks (#1997)

# What does this PR do?

We were passing a dict into the compat mixin for OpenAI Completions when
using Llama models with Fireworks, and that was breaking some of the
strongly typed code that was added in openai_compat.py. We shouldn't have
been converting these params to a dict in that case anyway, so this
adjusts things to pass the params in with their actual original types
when calling OpenAIChatCompletionToLlamaStackMixin.

## Test Plan

All of the fireworks provider verification tests were failing due to some
OpenAI compatibility cleanup in #1962. The changes in that PR were good to
make, and this one just cleans up the fireworks provider code to stop
passing untyped dicts into some of those `openai_compat.py` methods, since
we still have the original strongly typed parameters and can pass them in
directly.

```
llama stack run --image-type venv tests/verifications/openai-api-verification-run.yaml
```

```
python -m pytest -s -v tests/verifications/openai_api/test_chat_completion.py --provider=fireworks-llama-stack
```

Before this PR, all of the fireworks OpenAI verification tests were
failing. Now, most of them pass.

Signed-off-by: Ben Browning
---
 .../remote/inference/fireworks/fireworks.py | 40 +++++++++++++++----
 1 file changed, 33 insertions(+), 7 deletions(-)

diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 48c163c87..58678a9cc 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -362,6 +362,39 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         user: Optional[str] = None,
     ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         model_obj = await self.model_store.get_model(model)
+
+        # Divert Llama Models through Llama Stack inference APIs because
+        # Fireworks chat completions OpenAI-compatible API does not support
+        # tool calls properly.
+        llama_model = self.get_llama_model(model_obj.provider_resource_id)
+        if llama_model:
+            return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
+                self,
+                model=model,
+                messages=messages,
+                frequency_penalty=frequency_penalty,
+                function_call=function_call,
+                functions=functions,
+                logit_bias=logit_bias,
+                logprobs=logprobs,
+                max_completion_tokens=max_completion_tokens,
+                max_tokens=max_tokens,
+                n=n,
+                parallel_tool_calls=parallel_tool_calls,
+                presence_penalty=presence_penalty,
+                response_format=response_format,
+                seed=seed,
+                stop=stop,
+                stream=stream,
+                stream_options=stream_options,
+                temperature=temperature,
+                tool_choice=tool_choice,
+                tools=tools,
+                top_logprobs=top_logprobs,
+                top_p=top_p,
+                user=user,
+            )
+
         params = await prepare_openai_completion_params(
             messages=messages,
             frequency_penalty=frequency_penalty,
@@ -387,11 +420,4 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             user=user,
         )
 
-        # Divert Llama Models through Llama Stack inference APIs because
-        # Fireworks chat completions OpenAI-compatible API does not support
-        # tool calls properly.
-        llama_model = self.get_llama_model(model_obj.provider_resource_id)
-        if llama_model:
-            return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
-
         return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
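
For readers landing here without the surrounding adapter code, below is a
minimal, standalone sketch of the failure mode this patch avoids. Every name
in it (`ResponseFormat`, `prepare_params`, `typed_consumer`) is a hypothetical
stand-in rather than a Llama Stack or Fireworks API; it only illustrates how
round-tripping strongly typed keyword arguments through a plain dict can break
a callee that checks the original types.

```
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ResponseFormat:
    # Stand-in for one of the strongly typed OpenAI request parameters.
    type: str


def prepare_params(**kwargs: Any) -> dict[str, Any]:
    # Hypothetical analogue of a prepare-params helper: drops unset values
    # and flattens typed objects into plain JSON-friendly dicts.
    prepared = {}
    for key, value in kwargs.items():
        if value is None:
            continue
        prepared[key] = value.__dict__ if hasattr(value, "__dict__") else value
    return prepared


def typed_consumer(response_format: Optional[ResponseFormat] = None) -> str:
    # Stand-in for a strongly typed openai_compat-style code path: it
    # expects the original object, not its serialized dict form.
    if response_format is not None and not isinstance(response_format, ResponseFormat):
        raise TypeError(f"expected ResponseFormat, got {type(response_format).__name__}")
    return response_format.type if response_format else "text"


fmt = ResponseFormat(type="json_object")

# Pre-PR pattern: serialize everything into a dict, then unpack it.
try:
    typed_consumer(**prepare_params(response_format=fmt))
except TypeError as exc:
    print(f"dict round-trip breaks typing: {exc}")

# Post-PR pattern: pass the original typed value straight through.
print(typed_consumer(response_format=fmt))  # -> json_object
```

The patch follows the second pattern: the Llama model branch now receives the
caller's original keyword arguments directly, while the dict built by
`prepare_openai_completion_params` is reserved for the OpenAI client call,
which expects plain values.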