Revert client change to non-OpenAI methods

This commit is contained in:
Swapna Lekkala 2025-09-17 18:19:15 -07:00
parent d60514b57b
commit 9aee115abe

View file

@ -7,6 +7,8 @@
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any from typing import Any
from fireworks.client import Fireworks
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
InterleavedContent, InterleavedContent,
InterleavedContentItem, InterleavedContentItem,
@ -91,6 +93,10 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
def get_base_url(self) -> str: def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1" return "https://api.fireworks.ai/inference/v1"
def _get_client(self) -> Fireworks:
    """Build a Fireworks SDK client authenticated with this adapter's API key.

    A fresh client is constructed per call; the key is obtained via
    ``self.get_api_key()`` (provided by the surrounding adapter/mixin).
    """
    return Fireworks(api_key=self.get_api_key())
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str: def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
"""Remove BOS token as Fireworks automatically prepends it""" """Remove BOS token as Fireworks automatically prepends it"""
if prompt.startswith("<|begin_of_text|>"): if prompt.startswith("<|begin_of_text|>"):
@ -124,13 +130,13 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse: async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request) params = await self._get_params(request)
r = await self.client.completions.create(**params) r = await self._get_client().completions.create(**params)
return process_completion_response(r) return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request) params = await self._get_params(request)
stream = await self.client.completions.create(**params) stream = await self._get_client().completions.create(**params)
async for chunk in process_completion_stream_response(stream): async for chunk in process_completion_stream_response(stream):
yield chunk yield chunk
@ -199,18 +205,18 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request) params = await self._get_params(request)
if "messages" in params: if "messages" in params:
r = await self.client.chat.completions.create(**params) r = await self._get_client().chat.completions.create(**params)
else: else:
r = await self.client.completions.create(**params) r = await self._get_client().completions.create(**params)
return process_chat_completion_response(r, request) return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request) params = await self._get_params(request)
if "messages" in params: if "messages" in params:
stream = await self.client.chat.completions.create(**params) stream = await self._get_client().chat.completions.create(**params)
else: else:
stream = await self.client.completions.create(**params) stream = await self._get_client().completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, request): async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk yield chunk
@ -261,7 +267,7 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
assert all(not content_has_media(content) for content in contents), ( assert all(not content_has_media(content) for content in contents), (
"Fireworks does not support media for embeddings" "Fireworks does not support media for embeddings"
) )
response = await self.client.embeddings.create( response = await self._get_client().embeddings.create(
model=model.provider_resource_id, model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents], input=[interleaved_content_as_str(content) for content in contents],
**kwargs, **kwargs,