diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 48c163c87..58678a9cc 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -362,6 +362,39 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv user: Optional[str] = None, ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: model_obj = await self.model_store.get_model(model) + + # Divert Llama Models through Llama Stack inference APIs because + # Fireworks chat completions OpenAI-compatible API does not support + # tool calls properly. + llama_model = self.get_llama_model(model_obj.provider_resource_id) + if llama_model: + return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion( + self, + model=model, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + params = await prepare_openai_completion_params( messages=messages, frequency_penalty=frequency_penalty, @@ -387,11 +420,4 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv user=user, ) - # Divert Llama Models through Llama Stack inference APIs because - # Fireworks chat completions OpenAI-compatible API does not support - # tool calls properly. - llama_model = self.get_llama_model(model_obj.provider_resource_id) - if llama_model: - return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params) - return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)