Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
refactor(fireworks_ai/): inherit from openai like base config (#7146)
* refactor(fireworks_ai/): inherit from openai like base config. Refactors fireworks ai to use a common config.
* test: fix import in test
* refactor(watsonx/): refactor watsonx to use llm base config. Refactors chat + completion routes to base config path.
* fix: fix linting error
* test: fix test
* fix: fix test

parent 6a9225fac2
commit 4eeaaeeacd

15 changed files with 449 additions and 307 deletions
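The commit message above describes the pattern used throughout this change: provider configs stop re-implementing shared OpenAI-compatible behaviour and instead subclass a common base, overriding only what differs. A minimal standalone sketch of that shape (not the actual litellm classes; the method body and param names are assumptions for illustration):

# Illustrative sketch only, not the real litellm implementation.
class OpenAIGPTConfigSketch:
    """Stand-in for the shared OpenAI-style base config."""

    def get_supported_openai_params(self, model: str) -> list:
        return ["stream", "max_tokens", "tools"]


class FireworksAIConfigSketch(OpenAIGPTConfigSketch):
    """A provider config only overrides what actually differs from the base."""

    def get_supported_openai_params(self, model: str) -> list:
        return super().get_supported_openai_params(model) + ["tool_choice", "response_format"]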
@@ -1163,10 +1163,11 @@ nvidiaNimEmbeddingConfig = NvidiaNimEmbeddingConfig()
 from .llms.cerebras.chat import CerebrasConfig
 from .llms.sambanova.chat import SambanovaConfig
 from .llms.ai21.chat import AI21ChatConfig
-from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
+from .llms.fireworks_ai.chat.transformation import FireworksAIConfig
 from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
     FireworksAIEmbeddingConfig,
 )
+from .llms.friendliai.chat.transformation import FriendliaiChatConfig
 from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
 from .llms.xai.chat.transformation import XAIChatConfig
 from .llms.volcengine import VolcEngineConfig
@@ -1183,7 +1184,7 @@ from .llms.lm_studio.chat.transformation import LMStudioChatConfig
 from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig
 from .llms.perplexity.chat.transformation import PerplexityChatConfig
 from .llms.azure.chat.o1_transformation import AzureOpenAIO1Config
-from .llms.watsonx.completion.handler import IBMWatsonXAIConfig
+from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
 from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
 from .main import *  # type: ignore
 from .integrations import *
@@ -3,55 +3,54 @@ DEFAULT_BATCH_SIZE = 512
 DEFAULT_FLUSH_INTERVAL_SECONDS = 5
 DEFAULT_MAX_RETRIES = 2
 LITELLM_CHAT_PROVIDERS = [
-    "openai",
+    # "openai",
-    "openai_like",
+    # "openai_like",
-    "xai",
+    # "xai",
-    "custom_openai",
+    # "custom_openai",
-    "text-completion-openai",
+    # "text-completion-openai",
-    "cohere",
+    # "cohere",
-    "cohere_chat",
+    # "cohere_chat",
-    "clarifai",
+    # "clarifai",
-    "anthropic",
+    # "anthropic",
-    "anthropic_text",
+    # "anthropic_text",
-    "replicate",
+    # "replicate",
-    "huggingface",
+    # "huggingface",
-    "together_ai",
+    # "together_ai",
-    "openrouter",
+    # "openrouter",
-    "vertex_ai",
+    # "vertex_ai",
-    "vertex_ai_beta",
+    # "vertex_ai_beta",
-    "palm",
+    # "palm",
-    "gemini",
+    # "gemini",
-    "ai21",
+    # "ai21",
-    "baseten",
+    # "baseten",
-    "azure",
+    # "azure",
-    "azure_text",
+    # "azure_text",
-    "azure_ai",
+    # "azure_ai",
-    "sagemaker",
+    # "sagemaker",
-    "sagemaker_chat",
+    # "sagemaker_chat",
-    "bedrock",
+    # "bedrock",
-    "vllm",
+    # "vllm",
-    "nlp_cloud",
+    # "nlp_cloud",
-    "petals",
+    # "petals",
-    "oobabooga",
+    # "oobabooga",
-    "ollama",
+    # "ollama",
-    "ollama_chat",
+    # "ollama_chat",
-    "deepinfra",
+    # "deepinfra",
-    "perplexity",
+    # "perplexity",
-    "anyscale",
+    # "anyscale",
-    "mistral",
+    # "mistral",
-    "groq",
+    # "groq",
-    "nvidia_nim",
+    # "nvidia_nim",
-    "cerebras",
+    # "cerebras",
-    "ai21_chat",
+    # "ai21_chat",
-    "volcengine",
+    # "volcengine",
-    "codestral",
+    # "codestral",
-    "text-completion-codestral",
+    # "text-completion-codestral",
-    "deepseek",
+    # "deepseek",
-    "sambanova",
+    # "sambanova",
-    "maritalk",
+    # "maritalk",
-    "voyage",
+    # "voyage",
-    "cloudflare",
+    # "cloudflare",
-    "xinference",
     "fireworks_ai",
     "friendliai",
     "watsonx",
@@ -495,7 +495,7 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
             api_base,
             dynamic_api_key,
         ) = litellm.FireworksAIConfig()._get_openai_compatible_provider_info(
-            model, api_base, api_key
+            model=model, api_base=api_base, api_key=api_key
         )
     elif custom_llm_provider == "azure_ai":
         (
@@ -40,7 +40,7 @@ def get_supported_openai_params(  # noqa: PLR0915
                 model=model
             )
         else:
-            return litellm.FireworksAIConfig().get_supported_openai_params()
+            return litellm.FireworksAIConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "nvidia_nim":
        if request_type == "chat_completion":
            return litellm.nvidiaNimConfig.get_supported_openai_params(model=model)
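With the change above, listing the supported OpenAI params for Fireworks AI requires the model name. A small usage sketch; the model string mirrors the test further down in this diff and is only an example:

import litellm

params = litellm.FireworksAIConfig().get_supported_openai_params(model="llama3")
print("max_completion_tokens" in params)  # True, per the updated test at the bottom of this diff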
@@ -9,6 +9,7 @@ from typing import (
     Any,
     AsyncIterator,
     Callable,
+    Dict,
     Iterator,
     List,
     Optional,
@@ -33,7 +34,7 @@ class BaseLLMException(Exception):
         self,
         status_code: int,
         message: str,
-        headers: Optional[httpx.Headers] = None,
+        headers: Optional[Union[Dict, httpx.Headers]] = None,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
     ):
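The widened `headers` annotation means callers of `BaseLLMException` can now pass a plain dict instead of an `httpx.Headers` object. A short sketch (the header value is made up for illustration):

from litellm.llms.base_llm.transformation import BaseLLMException

err = BaseLLMException(
    status_code=429,
    message="rate limited",
    headers={"retry-after": "2"},  # a plain dict is now accepted alongside httpx.Headers
)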
@@ -3,10 +3,11 @@ from typing import Literal, Optional, Tuple, Union

 from litellm.secret_managers.main import get_secret_str

+from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
 from ..embed.fireworks_ai_transformation import FireworksAIEmbeddingConfig


-class FireworksAIConfig:
+class FireworksAIConfig(OpenAIGPTConfig):
     """
     Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions

@@ -56,23 +57,9 @@ class FireworksAIConfig:

     @classmethod
     def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+        return super().get_config()

-    def get_supported_openai_params(self):
+    def get_supported_openai_params(self, model: str):
         return [
             "stream",
             "tools",
@@ -98,8 +85,10 @@ class FireworksAIConfig:
         non_default_params: dict,
         optional_params: dict,
         model: str,
+        drop_params: bool,
     ) -> dict:
-        supported_openai_params = self.get_supported_openai_params()
+
+        supported_openai_params = self.get_supported_openai_params(model=model)
         for param, value in non_default_params.items():
             if param == "tool_choice":
                 if value == "required":
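`map_openai_params` now also takes `drop_params`, so existing callers add the extra argument. A usage sketch taken from the updated test later in this diff:

from litellm.llms.fireworks_ai.chat.transformation import FireworksAIConfig

fireworks = FireworksAIConfig()
result = fireworks.map_openai_params(
    {"tool_choice": "required"}, {}, "some_model", drop_params=False
)
assert result == {"tool_choice": "any"}  # "required" is translated to Fireworks' "any"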
litellm/llms/friendliai/chat/transformation.py (new file, 24 lines)

"""
Translate from OpenAI's `/v1/chat/completions` to Friendliai's `/v1/chat/completions`
"""

import json
import types
from typing import List, Optional, Tuple, Union

from pydantic import BaseModel

import litellm
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionAssistantMessage,
    ChatCompletionToolParam,
    ChatCompletionToolParamFunctionChunk,
)

from ...openai_like.chat.handler import OpenAILikeChatConfig


class FriendliaiChatConfig(OpenAILikeChatConfig):
    pass
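Because the new Friendliai config is an empty subclass, it inherits all of its behaviour (param mapping, provider info resolution) from `OpenAILikeChatConfig`. A quick check of that relationship:

from litellm.llms.friendliai.chat.transformation import FriendliaiChatConfig
from litellm.llms.openai_like.chat.handler import OpenAILikeChatConfig

assert issubclass(FriendliaiChatConfig, OpenAILikeChatConfig)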
@@ -19,7 +19,10 @@ from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig

 class OpenAILikeChatConfig(OpenAIGPTConfig):
     def _get_openai_compatible_provider_info(
-        self, api_base: Optional[str], api_key: Optional[str]
+        self,
+        api_base: Optional[str],
+        api_key: Optional[str],
+        model: Optional[str] = None,
     ) -> Tuple[Optional[str], Optional[str]]:
         api_base = api_base or get_secret_str("OPENAI_LIKE_API_BASE")  # type: ignore
         dynamic_api_key = (
@@ -21,7 +21,6 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
         if api_params.get("space_id") is None:
             raise WatsonXAIError(
                 status_code=401,
-                url=api_params["url"],
                 message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
             )
         deployment_id = "/".join(model.split("/")[1:])
@@ -1,24 +1,23 @@
-from typing import Callable, Optional, cast
+from typing import Callable, Dict, Optional, Union, cast

 import httpx

 import litellm
 from litellm import verbose_logger
 from litellm.caching import InMemoryCache
+from litellm.llms.base_llm.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.watsonx import WatsonXAPIParams


-class WatsonXAIError(Exception):
-    def __init__(self, status_code, message, url: Optional[str] = None):
-        self.status_code = status_code
-        self.message = message
-        url = url or "https://https://us-south.ml.cloud.ibm.com"
-        self.request = httpx.Request(method="POST", url=url)
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
+class WatsonXAIError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[Union[Dict, httpx.Headers]] = None,
+    ):
+        super().__init__(status_code=status_code, message=message, headers=headers)


 iam_token_cache = InMemoryCache()
@@ -151,13 +150,11 @@ def _get_api_params(
     elif token is None and api_key is None:
         raise WatsonXAIError(
             status_code=401,
-            url=url,
             message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
         )
     if project_id is None:
         raise WatsonXAIError(
             status_code=401,
-            url=url,
             message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
         )
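With `WatsonXAIError` now built on `BaseLLMException`, the old `url` argument disappears and the exception is constructed from a status code, a message, and optional headers only. A short sketch; the import path is inferred from the `..common_utils` relative imports elsewhere in this diff:

from litellm.llms.watsonx.common_utils import WatsonXAIError

err = WatsonXAIError(
    status_code=401,
    message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
)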
@@ -29,216 +29,14 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
 )
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
 from litellm.types.llms.watsonx import WatsonXAIEndpoint
 from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason

 from ...base import BaseLLM
 from ...prompt_templates import factory as ptf
 from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token
+from .transformation import IBMWatsonXAIConfig
-
-class IBMWatsonXAIConfig:
-    """
-    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
-    (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)
-    ...
-    """
[the remaining removed lines are the rest of the old IBMWatsonXAIConfig body: the generation-param class attributes, __init__, get_config, is_watsonx_text_param, get_supported_openai_params, map_openai_params, get_mapped_special_auth_params, map_special_auth_params, get_eu_regions and get_us_regions; the same class, adapted to the base-config interface, is re-added in litellm/llms/watsonx/completion/transformation.py below]

 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str:
@@ -281,6 +79,7 @@ class IBMWatsonXAI(BaseLLM):
     def _prepare_text_generation_req(
         self,
         model_id: str,
+        messages: List[AllMessageValues],
         prompt: str,
         stream: bool,
         optional_params: dict,
@@ -293,11 +92,13 @@ class IBMWatsonXAI(BaseLLM):
         # build auth headers
         api_token = api_params.get("token")
         self.token = api_token
-        headers = {
-            "Authorization": f"Bearer {api_token}",
-            "Content-Type": "application/json",
-            "Accept": "application/json",
-        }
+        headers = IBMWatsonXAIConfig().validate_environment(
+            headers={},
+            model=model_id,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_token,
+        )
         extra_body_params = optional_params.pop("extra_body", {})
         optional_params.update(extra_body_params)
         # init the payload to the text generation call
@@ -313,7 +114,6 @@ class IBMWatsonXAI(BaseLLM):
         if api_params.get("space_id") is None:
             raise WatsonXAIError(
                 status_code=401,
-                url=api_params["url"],
                 message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
             )
         deployment_id = "/".join(model_id.split("/")[1:])
@@ -466,6 +266,7 @@ class IBMWatsonXAI(BaseLLM):
         req_params = self._prepare_text_generation_req(
             model_id=model,
             prompt=prompt,
+            messages=messages,
             stream=stream,
             optional_params=optional_params,
             print_verbose=print_verbose,
litellm/llms/watsonx/completion/transformation.py (new file, 299 lines)

import asyncio
import json  # noqa: E401
import time
import types
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncContextManager,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    ContextManager,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Union,
)

import httpx

import litellm
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason

from ...base import BaseLLM
from ...base_llm.transformation import BaseConfig
from ...prompt_templates import factory as ptf
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class IBMWatsonXAIConfig(BaseConfig):
    """
    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
    (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)

    Supported params for all available watsonx.ai foundational models.

    - `decoding_method` (str): One of "greedy" or "sample"

    - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'.

    - `max_new_tokens` (integer): Maximum length of the generated tokens.

    - `min_new_tokens` (integer): Maximum length of input tokens. Any more than this will be truncated.

    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".

    - `stop_sequences` (string[]): list of strings to use as stop sequences.

    - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.

    - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'.

    - `repetition_penalty` (float): token repetition penalty during text generation.

    - `truncate_input_tokens` (integer): Truncate input tokens to this length.

    - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match.

    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean.

    - `random_seed` (integer): Random seed for text generation.

    - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering.

    - `stream` (bool): If True, the model will return a stream of responses.
    """

    decoding_method: Optional[str] = "sample"
    temperature: Optional[float] = None
    max_new_tokens: Optional[int] = None  # litellm.max_tokens
    min_new_tokens: Optional[int] = None
    length_penalty: Optional[dict] = None  # e.g {"decay_factor": 2.5, "start_index": 5}
    stop_sequences: Optional[List[str]] = None  # e.g ["}", ")", "."]
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    truncate_input_tokens: Optional[int] = None
    include_stop_sequences: Optional[bool] = False
    return_options: Optional[Dict[str, bool]] = None
    random_seed: Optional[int] = None  # e.g 42
    moderations: Optional[dict] = None
    stream: Optional[bool] = False

    def __init__(
        self,
        decoding_method: Optional[str] = None,
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        length_penalty: Optional[dict] = None,
        stop_sequences: Optional[List[str]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        truncate_input_tokens: Optional[int] = None,
        include_stop_sequences: Optional[bool] = None,
        return_options: Optional[dict] = None,
        random_seed: Optional[int] = None,
        moderations: Optional[dict] = None,
        stream: Optional[bool] = None,
        **kwargs,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def is_watsonx_text_param(self, param: str) -> bool:
        """
        Determine if user passed in a watsonx.ai text generation param
        """
        text_generation_params = [
            "decoding_method",
            "max_new_tokens",
            "min_new_tokens",
            "length_penalty",
            "stop_sequences",
            "top_k",
            "repetition_penalty",
            "truncate_input_tokens",
            "include_stop_sequences",
            "return_options",
            "random_seed",
            "moderations",
            "decoding_method",
            "min_tokens",
        ]

        return param in text_generation_params

    def get_supported_openai_params(self, model: str):
        return [
            "temperature",  # equivalent to temperature
            "max_tokens",  # equivalent to max_new_tokens
            "top_p",  # equivalent to top_p
            "frequency_penalty",  # equivalent to repetition_penalty
            "stop",  # equivalent to stop_sequences
            "seed",  # equivalent to random_seed
            "stream",  # equivalent to stream
        ]

    def map_openai_params(
        self,
        non_default_params: Dict,
        optional_params: Dict,
        model: str,
        drop_params: bool,
    ) -> Dict:
        extra_body = {}
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_new_tokens"] = v
            elif k == "stream":
                optional_params["stream"] = v
            elif k == "temperature":
                optional_params["temperature"] = v
            elif k == "top_p":
                optional_params["top_p"] = v
            elif k == "frequency_penalty":
                optional_params["repetition_penalty"] = v
            elif k == "seed":
                optional_params["random_seed"] = v
            elif k == "stop":
                optional_params["stop_sequences"] = v
            elif k == "decoding_method":
                extra_body["decoding_method"] = v
            elif k == "min_tokens":
                extra_body["min_new_tokens"] = v
            elif k == "top_k":
                extra_body["top_k"] = v
            elif k == "truncate_input_tokens":
                extra_body["truncate_input_tokens"] = v
            elif k == "length_penalty":
                extra_body["length_penalty"] = v
            elif k == "time_limit":
                extra_body["time_limit"] = v
            elif k == "return_options":
                extra_body["return_options"] = v

        if extra_body:
            optional_params["extra_body"] = extra_body
        return optional_params

    def get_mapped_special_auth_params(self) -> dict:
        """
        Common auth params across bedrock/vertex_ai/azure/watsonx
        """
        return {
            "project": "watsonx_project",
            "region_name": "watsonx_region_name",
            "token": "watsonx_token",
        }

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        mapped_params = self.get_mapped_special_auth_params()

        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
        """
        return [
            "eu-de",
            "eu-gb",
        ]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
        """
        return [
            "us-south",
        ]

    def _transform_messages(
        self,
        messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        return messages

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
        return WatsonXAIError(
            status_code=status_code, message=error_message, headers=headers
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        headers: Dict,
    ) -> Dict:
        raise NotImplementedError(
            "transform_request not implemented. Done in watsonx/completion handler.py"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError(
            "transform_response not implemented. Done in watsonx/completion handler.py"
        )

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        api_key: Optional[str] = None,
    ) -> Dict:
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers
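The `map_openai_params` above translates OpenAI-style params to watsonx names and pushes watsonx-only params into `extra_body`. A small sketch of that mapping (the model id is an arbitrary example, not taken from this diff):

from litellm.llms.watsonx.completion.transformation import IBMWatsonXAIConfig

mapped = IBMWatsonXAIConfig().map_openai_params(
    non_default_params={"max_tokens": 100, "seed": 42, "top_k": 10},
    optional_params={},
    model="ibm/granite-13b-chat-v2",  # example model id
    drop_params=False,
)
# -> {"max_new_tokens": 100, "random_seed": 42, "extra_body": {"top_k": 10}}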
@@ -3515,6 +3515,11 @@ def get_optional_params(  # noqa: PLR0915
             non_default_params=non_default_params,
             optional_params=optional_params,
             model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "volcengine":
         supported_params = get_supported_openai_params(
@@ -3658,6 +3663,12 @@ def get_optional_params(  # noqa: PLR0915
         optional_params = litellm.IBMWatsonXAIConfig().map_openai_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "openai":
         supported_params = get_supported_openai_params(
@@ -6284,7 +6295,14 @@ class ProviderConfigManager:
             return litellm.VertexAIAnthropicConfig()
         elif litellm.LlmProviders.CLOUDFLARE == provider:
             return litellm.CloudflareChatConfig()
+        elif litellm.LlmProviders.FIREWORKS_AI == provider:
+            return litellm.FireworksAIConfig()
+        elif litellm.LlmProviders.FRIENDLIAI == provider:
+            return litellm.FriendliaiChatConfig()
+        elif litellm.LlmProviders.WATSONX == provider:
+            return litellm.IBMWatsonXChatConfig()
+        elif litellm.LlmProviders.WATSONX_TEXT == provider:
+            return litellm.IBMWatsonXAIConfig()
         return litellm.OpenAIGPTConfig()
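The new `ProviderConfigManager` branches mean these providers now resolve to their own config objects instead of falling through to `OpenAIGPTConfig`. Roughly, the added mapping amounts to the following (written directly against the enum, since the manager's method name sits outside this hunk):

import litellm

provider_config = {
    litellm.LlmProviders.FIREWORKS_AI: litellm.FireworksAIConfig(),
    litellm.LlmProviders.FRIENDLIAI: litellm.FriendliaiChatConfig(),
    litellm.LlmProviders.WATSONX: litellm.IBMWatsonXChatConfig(),
    litellm.LlmProviders.WATSONX_TEXT: litellm.IBMWatsonXAIConfig(),
}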
@@ -7,7 +7,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path

-from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
+from litellm.llms.fireworks_ai.chat.transformation import FireworksAIConfig
 from base_llm_unit_tests import BaseLLMChatTest

 fireworks = FireworksAIConfig()
@@ -15,21 +15,27 @@ fireworks = FireworksAIConfig()

 def test_map_openai_params_tool_choice():
     # Test case 1: tool_choice is "required"
-    result = fireworks.map_openai_params({"tool_choice": "required"}, {}, "some_model")
+    result = fireworks.map_openai_params(
+        {"tool_choice": "required"}, {}, "some_model", drop_params=False
+    )
     assert result == {"tool_choice": "any"}

     # Test case 2: tool_choice is "auto"
-    result = fireworks.map_openai_params({"tool_choice": "auto"}, {}, "some_model")
+    result = fireworks.map_openai_params(
+        {"tool_choice": "auto"}, {}, "some_model", drop_params=False
+    )
     assert result == {"tool_choice": "auto"}

     # Test case 3: tool_choice is not present
     result = fireworks.map_openai_params(
-        {"some_other_param": "value"}, {}, "some_model"
+        {"some_other_param": "value"}, {}, "some_model", drop_params=False
     )
     assert result == {}

     # Test case 4: tool_choice is None
-    result = fireworks.map_openai_params({"tool_choice": None}, {}, "some_model")
+    result = fireworks.map_openai_params(
+        {"tool_choice": None}, {}, "some_model", drop_params=False
+    )
     assert result == {"tool_choice": None}
@@ -55,7 +61,7 @@ def test_map_response_format():
         },
     }
     result = fireworks.map_openai_params(
-        {"response_format": response_format}, {}, "some_model"
+        {"response_format": response_format}, {}, "some_model", drop_params=False
     )
     assert result == {
         "response_format": {
@@ -154,13 +154,18 @@ def test_all_model_configs():
         {"max_completion_tokens": 10}, {}, "llama3", drop_params=False
     ) == {"max_tokens": 10}

-    from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import (
+    from litellm.llms.fireworks_ai.chat.transformation import (
         FireworksAIConfig,
     )

-    assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params()
+    assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params(
+        model="llama3"
+    )
     assert FireworksAIConfig().map_openai_params(
-        {"max_completion_tokens": 10}, {}, "llama3"
+        model="llama3",
+        non_default_params={"max_completion_tokens": 10},
+        optional_params={},
+        drop_params=False,
     ) == {"max_tokens": 10}

     from litellm.llms.huggingface_restapi import HuggingfaceConfig