Swapna Lekkala 2025-09-17 18:24:34 -07:00
parent 9aee115abe
commit dfbc61fb67


@@ -130,13 +130,19 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self._get_client().completions.create(**params)
+        r = await self._get_client().completion.acreate(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        stream = await self._get_client().completions.create(**params)
+        # Wrapper: expose the synchronous stream as an async generator
+        async def _to_async_generator():
+            stream = self._get_client().completion.create(**params)
+            for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_completion_stream_response(stream):
             yield chunk
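The wrapper introduced here re-yields items from a blocking stream so callers can consume it with `async for`; the same pattern reappears in the chat path below. A minimal, self-contained sketch of the idea, with a hypothetical `sync_chunks` standing in for the stream the client returns:

import asyncio
from collections.abc import AsyncGenerator, Iterable

def sync_chunks() -> Iterable[str]:
    # Hypothetical stand-in for the blocking stream returned by the client.
    yield from ["He", "llo", "!"]

async def to_async_generator(chunks: Iterable[str]) -> AsyncGenerator[str, None]:
    # Re-yield each item so the stream can be consumed with `async for`.
    # Note: the underlying iterator still blocks the event loop per chunk.
    for chunk in chunks:
        yield chunk

async def main() -> None:
    async for chunk in to_async_generator(sync_chunks()):
        print(chunk, end="")
    print()

asyncio.run(main())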
@@ -205,18 +211,23 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self._get_client().chat.completions.create(**params)
+            r = await self._get_client().chat.completions.acreate(**params)
         else:
-            r = await self._get_client().completions.create(**params)
+            r = await self._get_client().completion.acreate(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        if "messages" in params:
-            stream = await self._get_client().chat.completions.create(**params)
-        else:
-            stream = await self._get_client().completions.create(**params)
+        async def _to_async_generator():
+            if "messages" in params:
+                stream = self._get_client().chat.completions.acreate(**params)
+            else:
+                stream = self._get_client().completion.acreate(**params)
+            async for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
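As the hunk implies (not verified against the Fireworks SDK docs here), acreate() is awaited for the non-streaming call but used directly as an async iterator when streaming. A sketch of that dual behavior with a hypothetical stand-in client:

import asyncio
from collections.abc import AsyncGenerator
from typing import Any

class FakeChatCompletions:
    # Hypothetical stand-in mirroring how the diff uses acreate():
    # awaited when stream is off, iterated with `async for` when stream=True.
    def acreate(self, **params: Any):
        if params.get("stream"):
            return self._stream(**params)
        return self._single(**params)

    async def _single(self, **params: Any) -> str:
        return "full response"

    async def _stream(self, **params: Any) -> AsyncGenerator[str, None]:
        for chunk in ("chunk-1", "chunk-2"):
            yield chunk

async def main() -> None:
    api = FakeChatCompletions()
    print(await api.acreate(model="m"))  # non-streaming: await the coroutine
    async for chunk in api.acreate(model="m", stream=True):
        print(chunk)  # streaming: iterate the async generator

asyncio.run(main())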
@@ -239,7 +250,8 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
-            input_dict["prompt"] = self._preprocess_prompt_for_fireworks(input_dict["prompt"])
+            if input_dict["prompt"].startswith("<|begin_of_text|>"):
+                input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
 
         params = {
             "model": request.model,
@@ -267,7 +279,7 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
         assert all(not content_has_media(content) for content in contents), (
             "Fireworks does not support media for embeddings"
         )
-        response = await self._get_client().embeddings.create(
+        response = self._get_client().embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
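The dropped `await` suggests the native client's embeddings.create() is a plain synchronous call; awaiting its return value (a regular object, not a coroutine) would raise TypeError. A sketch under that assumption, with a hypothetical stand-in client:

import asyncio
from typing import Any

class FakeEmbeddings:
    # Hypothetical stand-in: create() is synchronous and returns the
    # response directly, so it must not be awaited.
    def create(self, **params: Any) -> dict:
        return {"data": [{"embedding": [0.1, 0.2, 0.3]}]}

async def main() -> None:
    client = FakeEmbeddings()
    response = client.create(model="emb-model", input=["hello"])  # no await
    print(response["data"][0]["embedding"])

asyncio.run(main())

One trade-off worth noting: a blocking call made inside a coroutine holds up the event loop for the duration of the request.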