mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-10 21:34:36 +00:00
Merge branch 'main' into use-openai-for-databricks
This commit is contained in:
commit
46ae101ca1
13 changed files with 815 additions and 1140 deletions
|
@ -4,15 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
class AnthropicProviderDataValidator(BaseModel):
|
||||
anthropic_api_key: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
||||
from .anthropic import AnthropicInferenceAdapter
|
||||
|
||||
|
|
|
@ -4,11 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@ -24,12 +22,6 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
|
@ -45,15 +37,14 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
|
@ -68,7 +59,7 @@ from .models import MODEL_ENTRIES
|
|||
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
self.config = config
|
||||
|
@ -79,7 +70,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
def get_api_key(self) -> str:
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
return config_api_key
|
||||
|
@ -91,15 +82,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
)
|
||||
return provider_data.fireworks_api_key
|
||||
|
||||
def _get_base_url(self) -> str:
|
||||
def get_base_url(self) -> str:
|
||||
return "https://api.fireworks.ai/inference/v1"
|
||||
|
||||
def _get_client(self) -> Fireworks:
|
||||
fireworks_api_key = self._get_api_key()
|
||||
fireworks_api_key = self.get_api_key()
|
||||
return Fireworks(api_key=fireworks_api_key)
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
|
||||
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
|
||||
"""Remove BOS token as Fireworks automatically prepends it"""
|
||||
if prompt.startswith("<|begin_of_text|>"):
|
||||
return prompt[len("<|begin_of_text|>") :]
|
||||
return prompt
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
@ -285,153 +279,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
|
||||
prompt = prompt[len("<|begin_of_text|>") :]
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return await self._get_openai_client().completions.create(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Divert Llama Models through Llama Stack inference APIs because
|
||||
# Fireworks chat completions OpenAI-compatible API does not support
|
||||
# tool calls properly.
|
||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||
|
||||
if llama_model:
|
||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
|
||||
self,
|
||||
model=model,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
logger.debug(f"fireworks params: {params}")
|
||||
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
|
||||
|
|
|
@ -4,15 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
||||
|
||||
class GeminiProviderDataValidator(BaseModel):
|
||||
gemini_api_key: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GeminiConfig, _deps):
|
||||
from .gemini import GeminiInferenceAdapter
|
||||
|
||||
|
|
|
@ -4,15 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import OpenAIConfig
|
||||
|
||||
|
||||
class OpenAIProviderDataValidator(BaseModel):
|
||||
openai_api_key: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OpenAIConfig, _deps):
|
||||
from .openai import OpenAIInferenceAdapter
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue