From dfbc61fb67a03e1c008906485eeae4dd4a71bcfb Mon Sep 17 00:00:00 2001
From: Swapna Lekkala
Date: Wed, 17 Sep 2025 18:24:34 -0700
Subject: [PATCH] fix

---
 .../remote/inference/fireworks/fireworks.py | 32 +++++++++++++------
 1 file changed, 22 insertions(+), 10 deletions(-)

diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 8f6338afc..314e5c390 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -130,13 +130,19 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self._get_client().completions.create(**params)
+        r = await self._get_client().completion.acreate(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        stream = await self._get_client().completions.create(**params)
+        # Wrap the synchronous streaming call in an async generator
+        async def _to_async_generator():
+            stream = self._get_client().completion.create(**params)
+            for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
@@ -205,18 +211,23 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self._get_client().chat.completions.create(**params)
+            r = await self._get_client().chat.completions.acreate(**params)
         else:
-            r = await self._get_client().completions.create(**params)
+            r = await self._get_client().completion.acreate(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        if "messages" in params:
-            stream = await self._get_client().chat.completions.create(**params)
-        else:
-            stream = await self._get_client().completions.create(**params)
+        async def _to_async_generator():
+            if "messages" in params:
+                stream = self._get_client().chat.completions.acreate(**params)
+            else:
+                stream = self._get_client().completion.acreate(**params)
+            async for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
@@ -239,7 +250,8 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
 
         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
-            input_dict["prompt"] = self._preprocess_prompt_for_fireworks(input_dict["prompt"])
+            if input_dict["prompt"].startswith("<|begin_of_text|>"):
+                input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
 
         params = {
             "model": request.model,
@@ -267,7 +279,7 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
         assert all(not content_has_media(content) for content in contents), (
            "Fireworks does not support media for embeddings"
         )
-        response = await self._get_client().embeddings.create(
+        response = self._get_client().embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
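
Note (illustrative, not part of the patch): the streaming changes above wrap the Fireworks SDK's synchronous streaming iterator in an async generator so callers can consume chunks with `async for`. Below is a minimal standalone sketch of that pattern, assuming the fireworks-ai package's `Fireworks` client where `completion.create(..., stream=True)` returns a synchronous iterator of chunks; the API key and model name are placeholders.

import asyncio

from fireworks.client import Fireworks  # fireworks-ai SDK

client = Fireworks(api_key="YOUR_API_KEY")  # placeholder key


async def to_async_generator(params: dict):
    # The sync streaming call returns a plain iterator of chunks; re-yielding
    # each chunk from an async generator makes it usable with `async for`.
    stream = client.completion.create(**params)
    for chunk in stream:
        yield chunk


async def main():
    params = {
        "model": "accounts/fireworks/models/llama-v3p1-8b-instruct",  # placeholder model
        "prompt": "Hello",
        "stream": True,
    }
    async for chunk in to_async_generator(params):
        print(chunk)


asyncio.run(main())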