mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
revert client change to non openai methods

parent d60514b57b
commit 9aee115abe

1 changed file with 13 additions and 7 deletions
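In brief: the revert makes the non-OpenAI code paths (the legacy completion, chat-completion, and embeddings methods below) build their own Fireworks SDK client via _get_client() again, rather than reusing the shared self.client, which, presumably given the OpenAIMixin base class, is reserved for the OpenAI-compatible endpoints.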
@@ -7,6 +7,8 @@
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
+from fireworks.client import Fireworks
+
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -91,6 +93,10 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
     def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"
 
+    def _get_client(self) -> Fireworks:
+        fireworks_api_key = self.get_api_key()
+        return Fireworks(api_key=fireworks_api_key)
+
     def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
         """Remove BOS token as Fireworks automatically prepends it"""
         if prompt.startswith("<|begin_of_text|>"):
@@ -124,13 +130,13 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self.client.completions.create(**params)
+        r = await self._get_client().completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        stream = await self.client.completions.create(**params)
+        stream = await self._get_client().completions.create(**params)
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
@@ -199,18 +205,18 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self.client.chat.completions.create(**params)
+            r = await self._get_client().chat.completions.create(**params)
         else:
-            r = await self.client.completions.create(**params)
+            r = await self._get_client().completions.create(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         if "messages" in params:
-            stream = await self.client.chat.completions.create(**params)
+            stream = await self._get_client().chat.completions.create(**params)
         else:
-            stream = await self.client.completions.create(**params)
+            stream = await self._get_client().completions.create(**params)
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
@@ -261,7 +267,7 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
         assert all(not content_has_media(content) for content in contents), (
             "Fireworks does not support media for embeddings"
         )
-        response = await self.client.embeddings.create(
+        response = await self._get_client().embeddings.create(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
            **kwargs,
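For context, a minimal sketch of the per-call client pattern this commit restores. The class name and key storage here are hypothetical; only Fireworks, get_api_key(), and _get_client() come from the diff itself:

    from fireworks.client import Fireworks


    class AdapterSketch:
        # Hypothetical stand-in for the real adapter class.
        def __init__(self, api_key: str) -> None:
            self._api_key = api_key

        def get_api_key(self) -> str:
            # In the real adapter this is assumed to resolve the key from
            # static config or per-request provider data.
            return self._api_key

        def _get_client(self) -> Fireworks:
            # Building a fresh client per call means each request picks up
            # whatever key get_api_key() returns at that moment, rather than
            # whatever a long-lived shared client was constructed with.
            return Fireworks(api_key=self.get_api_key())


    client = AdapterSketch(api_key="fw-...")._get_client()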