mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 20:27:35 +00:00

chore: Refactor to use OpenAIMixin

parent e3fd70c321
commit e56a3f266c

1 changed file with 28 additions and 48 deletions
@@ -7,9 +7,6 @@
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
-from fireworks.client import Fireworks
-from openai import AsyncOpenAI
-
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -48,12 +45,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
     OpenAIChatCompletionToLlamaStackMixin,
     convert_message_to_openai_dict,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
@@ -68,7 +65,7 @@ from .models import MODEL_ENTRIES
 logger = get_logger(name=__name__, category="inference::fireworks")
 
 
-class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
         self.config = config
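Note: the class now inherits from OpenAIMixin, which replaces the hand-rolled Fireworks and AsyncOpenAI client plumbing removed elsewhere in this diff. The sketch below shows the general shape of such a mixin, assuming the adapter supplies get_api_key()/get_base_url() and the mixin exposes an OpenAI-compatible async client as self.client (as the call sites in this diff suggest); it is an illustration, not the actual llama_stack implementation.

# Minimal sketch of an OpenAIMixin-style base class (illustrative only; the real
# llama_stack.providers.utils.inference.openai_mixin.OpenAIMixin may differ).
from openai import AsyncOpenAI


class OpenAICompatMixinSketch:
    """Adapters override get_api_key()/get_base_url(); the mixin owns the client."""

    def get_api_key(self) -> str:
        # Overridden by the concrete adapter (e.g. from config or request headers).
        raise NotImplementedError

    def get_base_url(self) -> str:
        # Overridden by the concrete adapter with the provider's OpenAI-compatible endpoint.
        raise NotImplementedError

    @property
    def client(self) -> AsyncOpenAI:
        # Build an OpenAI-compatible async client from the adapter's hooks;
        # a real mixin might cache this instead of rebuilding it on each access.
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())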
@@ -79,7 +76,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def shutdown(self) -> None:
         pass
 
-    def _get_api_key(self) -> str:
+    def get_api_key(self) -> str:
         config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
             return config_api_key
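Note: the rename from _get_api_key to get_api_key (and _get_base_url to get_base_url in the next hunk) matches the hook names the mixin expects. The key-resolution order is unchanged: a key from the static config wins, otherwise the per-request provider data supplies it. A self-contained sketch of that precedence, with simplified placeholder config/provider-data types:

from dataclasses import dataclass


@dataclass
class FakeConfig:
    api_key: str | None = None


@dataclass
class FakeProviderData:
    fireworks_api_key: str | None = None


def resolve_api_key(config: FakeConfig, provider_data: FakeProviderData) -> str:
    # Static config takes precedence; otherwise fall back to request-scoped provider data.
    if config.api_key:
        return config.api_key
    if provider_data.fireworks_api_key:
        return provider_data.fireworks_api_key
    raise ValueError("fireworks_api_key is not set in config or request provider data")


assert resolve_api_key(FakeConfig("key-from-config"), FakeProviderData("key-from-request")) == "key-from-config"
assert resolve_api_key(FakeConfig(None), FakeProviderData("key-from-request")) == "key-from-request"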
@@ -91,15 +88,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             )
         return provider_data.fireworks_api_key
 
-    def _get_base_url(self) -> str:
+    def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"
 
-    def _get_client(self) -> Fireworks:
-        fireworks_api_key = self._get_api_key()
-        return Fireworks(api_key=fireworks_api_key)
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
+    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
+        """Remove BOS token as Fireworks automatically prepends it"""
+        if prompt.startswith("<|begin_of_text|>"):
+            return prompt[len("<|begin_of_text|>") :]
+        return prompt
 
     async def completion(
         self,
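Note: the new _preprocess_prompt_for_fireworks helper centralizes the BOS handling that was previously inlined in two places. Fireworks prepends the Llama BOS token itself, so a single leading <|begin_of_text|> is stripped and anything else passes through untouched. A standalone copy of the logic for illustration:

# Standalone illustration of the BOS-stripping helper added above.
def preprocess_prompt_for_fireworks(prompt: str) -> str:
    """Remove the leading BOS token because Fireworks automatically prepends one."""
    if prompt.startswith("<|begin_of_text|>"):
        return prompt[len("<|begin_of_text|>") :]
    return prompt


assert preprocess_prompt_for_fireworks("<|begin_of_text|>Hello") == "Hello"
assert preprocess_prompt_for_fireworks("Hello") == "Hello"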
@@ -128,19 +124,13 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self._get_client().completion.acreate(**params)
+        r = await self.client.completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        # Wrapper for async generator similar
-        async def _to_async_generator():
-            stream = self._get_client().completion.create(**params)
-            for chunk in stream:
-                yield chunk
-
-        stream = _to_async_generator()
+        stream = self.client.completions.create(**params)
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
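Note: the removed _to_async_generator wrapper existed only to adapt the synchronous Fireworks SDK stream to the async iteration that process_completion_stream_response expects; the OpenAI-compatible client yields an async stream directly, so the wrapper is no longer needed. A minimal sketch of the wrapper pattern being retired (fake_sync_stream is a stand-in, not Fireworks code):

import asyncio
from collections.abc import AsyncIterator, Iterator


def fake_sync_stream() -> Iterator[str]:
    # Stand-in for the old synchronous SDK stream; not Fireworks code.
    yield from ["chunk-1", "chunk-2", "chunk-3"]


async def to_async_generator(chunks: Iterator[str]) -> AsyncIterator[str]:
    # The pattern the removed _to_async_generator implemented: wrap a sync
    # iterator so downstream code can `async for` over it.
    for chunk in chunks:
        yield chunk


async def main() -> None:
    async for chunk in to_async_generator(fake_sync_stream()):
        print(chunk)


asyncio.run(main())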
@@ -209,23 +199,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self._get_client().chat.completions.acreate(**params)
+            r = await self.client.chat.completions.create(**params)
         else:
-            r = await self._get_client().completion.acreate(**params)
+            r = await self.client.completions.create(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        async def _to_async_generator():
-            if "messages" in params:
-                stream = self._get_client().chat.completions.acreate(**params)
-            else:
-                stream = self._get_client().completion.acreate(**params)
-            async for chunk in stream:
-                yield chunk
-
-        stream = _to_async_generator()
+        if "messages" in params:
+            stream = self.client.chat.completions.create(**params)
+        else:
+            stream = self.client.completions.create(**params)
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
@@ -248,8 +233,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
 
         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
-            if input_dict["prompt"].startswith("<|begin_of_text|>"):
-                input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
+            input_dict["prompt"] = self._preprocess_prompt_for_fireworks(input_dict["prompt"])
 
         params = {
             "model": request.model,
@@ -277,7 +261,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         assert all(not content_has_media(content) for content in contents), (
             "Fireworks does not support media for embeddings"
         )
-        response = self._get_client().embeddings.create(
+        response = await self.client.embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
@@ -319,14 +303,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         prompt_logprobs: int | None = None,
         suffix: str | None = None,
     ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
+        if isinstance(prompt, str):
+            prompt = self._preprocess_prompt_for_fireworks(prompt)
 
-        # Fireworks always prepends with BOS
-        if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
+        return await super().openai_completion(
+            model=model,
             prompt=prompt,
             best_of=best_of,
             echo=echo,
@@ -343,10 +324,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             temperature=temperature,
             top_p=top_p,
             user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
+            suffix=suffix,
         )
 
-        return await self._get_openai_client().completions.create(**params)
-
     async def openai_chat_completion(
         self,
         model: str,
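Note: openai_completion no longer assembles params and calls a private OpenAI client itself; it strips the BOS token and delegates to the mixin's generic implementation via super(), now also forwarding guided_choice, prompt_logprobs, and suffix. A simplified sketch of that delegation shape, with placeholder class names:

# Illustrative only: how a provider-specific tweak composes with a generic
# OpenAI-compatible base class via super() delegation. Class names are placeholders.
class GenericOpenAICompletions:
    async def openai_completion(self, model: str, prompt: str, **kwargs):
        # A real mixin would call self.client.completions.create(...) here.
        return {"model": model, "prompt": prompt, **kwargs}


class FireworksLikeAdapter(GenericOpenAICompletions):
    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
        if prompt.startswith("<|begin_of_text|>"):
            return prompt[len("<|begin_of_text|>") :]
        return prompt

    async def openai_completion(self, model: str, prompt: str, **kwargs):
        # Provider-specific fix-up, then defer to the shared implementation.
        if isinstance(prompt, str):
            prompt = self._preprocess_prompt_for_fireworks(prompt)
        return await super().openai_completion(model=model, prompt=prompt, **kwargs)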
@@ -408,7 +390,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             user=user,
         )
 
-        params = await prepare_openai_completion_params(
+        return await super().openai_chat_completion(
+            model=model,
             messages=messages,
             frequency_penalty=frequency_penalty,
             function_call=function_call,
@@ -432,6 +415,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             top_p=top_p,
             user=user,
         )
-
-        logger.debug(f"fireworks params: {params}")
-        return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)