diff --git a/litellm/__init__.py b/litellm/__init__.py index 87be1d002f..f119dde2a8 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1163,10 +1163,11 @@ nvidiaNimEmbeddingConfig = NvidiaNimEmbeddingConfig() from .llms.cerebras.chat import CerebrasConfig from .llms.sambanova.chat import SambanovaConfig from .llms.ai21.chat import AI21ChatConfig -from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig +from .llms.fireworks_ai.chat.transformation import FireworksAIConfig from .llms.fireworks_ai.embed.fireworks_ai_transformation import ( FireworksAIEmbeddingConfig, ) +from .llms.friendliai.chat.transformation import FriendliaiChatConfig from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig from .llms.xai.chat.transformation import XAIChatConfig from .llms.volcengine import VolcEngineConfig @@ -1183,7 +1184,7 @@ from .llms.lm_studio.chat.transformation import LMStudioChatConfig from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig from .llms.perplexity.chat.transformation import PerplexityChatConfig from .llms.azure.chat.o1_transformation import AzureOpenAIO1Config -from .llms.watsonx.completion.handler import IBMWatsonXAIConfig +from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig from .main import * # type: ignore from .integrations import * diff --git a/litellm/constants.py b/litellm/constants.py index 1fb97e07fc..b60492fb88 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -3,55 +3,54 @@ DEFAULT_BATCH_SIZE = 512 DEFAULT_FLUSH_INTERVAL_SECONDS = 5 DEFAULT_MAX_RETRIES = 2 LITELLM_CHAT_PROVIDERS = [ - "openai", - "openai_like", - "xai", - "custom_openai", - "text-completion-openai", - "cohere", - "cohere_chat", - "clarifai", - "anthropic", - "anthropic_text", - "replicate", - "huggingface", - "together_ai", - "openrouter", - "vertex_ai", - "vertex_ai_beta", - "palm", - "gemini", - "ai21", - 
"baseten", - "azure", - "azure_text", - "azure_ai", - "sagemaker", - "sagemaker_chat", - "bedrock", - "vllm", - "nlp_cloud", - "petals", - "oobabooga", - "ollama", - "ollama_chat", - "deepinfra", - "perplexity", - "anyscale", - "mistral", - "groq", - "nvidia_nim", - "cerebras", - "ai21_chat", - "volcengine", - "codestral", - "text-completion-codestral", - "deepseek", - "sambanova", - "maritalk", - "voyage", - "cloudflare", - "xinference", + # "openai", + # "openai_like", + # "xai", + # "custom_openai", + # "text-completion-openai", + # "cohere", + # "cohere_chat", + # "clarifai", + # "anthropic", + # "anthropic_text", + # "replicate", + # "huggingface", + # "together_ai", + # "openrouter", + # "vertex_ai", + # "vertex_ai_beta", + # "palm", + # "gemini", + # "ai21", + # "baseten", + # "azure", + # "azure_text", + # "azure_ai", + # "sagemaker", + # "sagemaker_chat", + # "bedrock", + # "vllm", + # "nlp_cloud", + # "petals", + # "oobabooga", + # "ollama", + # "ollama_chat", + # "deepinfra", + # "perplexity", + # "anyscale", + # "mistral", + # "groq", + # "nvidia_nim", + # "cerebras", + # "ai21_chat", + # "volcengine", + # "codestral", + # "text-completion-codestral", + # "deepseek", + # "sambanova", + # "maritalk", + # "voyage", + # "cloudflare", "fireworks_ai", "friendliai", "watsonx", diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py index 68992361ad..522068d571 100644 --- a/litellm/litellm_core_utils/get_llm_provider_logic.py +++ b/litellm/litellm_core_utils/get_llm_provider_logic.py @@ -495,7 +495,7 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915 api_base, dynamic_api_key, ) = litellm.FireworksAIConfig()._get_openai_compatible_provider_info( - model, api_base, api_key + model=model, api_base=api_base, api_key=api_key ) elif custom_llm_provider == "azure_ai": ( diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py 
b/litellm/litellm_core_utils/get_supported_openai_params.py index 383c2490c0..f6a024aa7d 100644 --- a/litellm/litellm_core_utils/get_supported_openai_params.py +++ b/litellm/litellm_core_utils/get_supported_openai_params.py @@ -40,7 +40,7 @@ def get_supported_openai_params( # noqa: PLR0915 model=model ) else: - return litellm.FireworksAIConfig().get_supported_openai_params() + return litellm.FireworksAIConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "nvidia_nim": if request_type == "chat_completion": return litellm.nvidiaNimConfig.get_supported_openai_params(model=model) diff --git a/litellm/llms/base_llm/transformation.py b/litellm/llms/base_llm/transformation.py index 06d392e0b0..2110e60b56 100644 --- a/litellm/llms/base_llm/transformation.py +++ b/litellm/llms/base_llm/transformation.py @@ -9,6 +9,7 @@ from typing import ( Any, AsyncIterator, Callable, + Dict, Iterator, List, Optional, @@ -33,7 +34,7 @@ class BaseLLMException(Exception): self, status_code: int, message: str, - headers: Optional[httpx.Headers] = None, + headers: Optional[Union[Dict, httpx.Headers]] = None, request: Optional[httpx.Request] = None, response: Optional[httpx.Response] = None, ): diff --git a/litellm/llms/fireworks_ai/chat/fireworks_ai_transformation.py b/litellm/llms/fireworks_ai/chat/transformation.py similarity index 90% rename from litellm/llms/fireworks_ai/chat/fireworks_ai_transformation.py rename to litellm/llms/fireworks_ai/chat/transformation.py index f1acca6084..d1a49b605f 100644 --- a/litellm/llms/fireworks_ai/chat/fireworks_ai_transformation.py +++ b/litellm/llms/fireworks_ai/chat/transformation.py @@ -3,10 +3,11 @@ from typing import Literal, Optional, Tuple, Union from litellm.secret_managers.main import get_secret_str +from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig from ..embed.fireworks_ai_transformation import FireworksAIEmbeddingConfig -class FireworksAIConfig: +class FireworksAIConfig(OpenAIGPTConfig): """ Reference: 
https://docs.fireworks.ai/api-reference/post-chatcompletions @@ -56,23 +57,9 @@ class FireworksAIConfig: @classmethod def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } + return super().get_config() - def get_supported_openai_params(self): + def get_supported_openai_params(self, model: str): return [ "stream", "tools", @@ -98,8 +85,10 @@ class FireworksAIConfig: non_default_params: dict, optional_params: dict, model: str, + drop_params: bool, ) -> dict: - supported_openai_params = self.get_supported_openai_params() + + supported_openai_params = self.get_supported_openai_params(model=model) for param, value in non_default_params.items(): if param == "tool_choice": if value == "required": diff --git a/litellm/llms/friendliai/chat/transformation.py b/litellm/llms/friendliai/chat/transformation.py new file mode 100644 index 0000000000..02bb4c7f29 --- /dev/null +++ b/litellm/llms/friendliai/chat/transformation.py @@ -0,0 +1,24 @@ +""" +Translate from OpenAI's `/v1/chat/completions` to Friendliai's `/v1/chat/completions` +""" + +import json +import types +from typing import List, Optional, Tuple, Union + +from pydantic import BaseModel + +import litellm +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import ( + AllMessageValues, + ChatCompletionAssistantMessage, + ChatCompletionToolParam, + ChatCompletionToolParamFunctionChunk, +) + +from ...openai_like.chat.handler import OpenAILikeChatConfig + + +class FriendliaiChatConfig(OpenAILikeChatConfig): + pass diff --git a/litellm/llms/openai_like/chat/transformation.py b/litellm/llms/openai_like/chat/transformation.py index c355cf3303..2be71596aa 100644 --- a/litellm/llms/openai_like/chat/transformation.py +++ b/litellm/llms/openai_like/chat/transformation.py @@ -19,7 +19,10 @@ from 
...OpenAI.chat.gpt_transformation import OpenAIGPTConfig class OpenAILikeChatConfig(OpenAIGPTConfig): def _get_openai_compatible_provider_info( - self, api_base: Optional[str], api_key: Optional[str] + self, + api_base: Optional[str], + api_key: Optional[str], + model: Optional[str] = None, ) -> Tuple[Optional[str], Optional[str]]: api_base = api_base or get_secret_str("OPENAI_LIKE_API_BASE") # type: ignore dynamic_api_key = ( diff --git a/litellm/llms/watsonx/chat/handler.py b/litellm/llms/watsonx/chat/handler.py index 932946d3cb..7753b929a3 100644 --- a/litellm/llms/watsonx/chat/handler.py +++ b/litellm/llms/watsonx/chat/handler.py @@ -21,7 +21,6 @@ class WatsonXChatHandler(OpenAILikeChatHandler): if api_params.get("space_id") is None: raise WatsonXAIError( status_code=401, - url=api_params["url"], message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.", ) deployment_id = "/".join(model.split("/")[1:]) diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py index 976b8e6dd1..e8ddc5f328 100644 --- a/litellm/llms/watsonx/common_utils.py +++ b/litellm/llms/watsonx/common_utils.py @@ -1,24 +1,23 @@ -from typing import Callable, Optional, cast +from typing import Callable, Dict, Optional, Union, cast import httpx import litellm from litellm import verbose_logger from litellm.caching import InMemoryCache +from litellm.llms.base_llm.transformation import BaseLLMException from litellm.secret_managers.main import get_secret_str from litellm.types.llms.watsonx import WatsonXAPIParams -class WatsonXAIError(Exception): - def __init__(self, status_code, message, url: Optional[str] = None): - self.status_code = status_code - self.message = message - url = url or "https://https://us-south.ml.cloud.ibm.com" - self.request = httpx.Request(method="POST", url=url) - self.response = httpx.Response(status_code=status_code, 
request=self.request) - super().__init__( - self.message - ) # Call the base class constructor with the parameters it needs +class WatsonXAIError(BaseLLMException): + def __init__( + self, + status_code: int, + message: str, + headers: Optional[Union[Dict, httpx.Headers]] = None, + ): + super().__init__(status_code=status_code, message=message, headers=headers) iam_token_cache = InMemoryCache() @@ -151,13 +150,11 @@ def _get_api_params( elif token is None and api_key is None: raise WatsonXAIError( status_code=401, - url=url, message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.", ) if project_id is None: raise WatsonXAIError( status_code=401, - url=url, message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.", ) diff --git a/litellm/llms/watsonx/completion/handler.py b/litellm/llms/watsonx/completion/handler.py index 9618f6342b..9cd884e3f6 100644 --- a/litellm/llms/watsonx/completion/handler.py +++ b/litellm/llms/watsonx/completion/handler.py @@ -29,216 +29,14 @@ from litellm.llms.custom_httpx.http_handler import ( get_async_httpx_client, ) from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import AllMessageValues from litellm.types.llms.watsonx import WatsonXAIEndpoint from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason from ...base import BaseLLM from ...prompt_templates import factory as ptf from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token - - -class IBMWatsonXAIConfig: - """ - Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation - (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params) - - Supported params for all available watsonx.ai foundational models. 
- - - `decoding_method` (str): One of "greedy" or "sample" - - - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'. - - - `max_new_tokens` (integer): Maximum length of the generated tokens. - - - `min_new_tokens` (integer): Maximum length of input tokens. Any more than this will be truncated. - - - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index". - - - `stop_sequences` (string[]): list of strings to use as stop sequences. - - - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'. - - - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'. - - - `repetition_penalty` (float): token repetition penalty during text generation. - - - `truncate_input_tokens` (integer): Truncate input tokens to this length. - - - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match. - - - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean. - - - `random_seed` (integer): Random seed for text generation. - - - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering. - - - `stream` (bool): If True, the model will return a stream of responses. 
- """ - - decoding_method: Optional[str] = "sample" - temperature: Optional[float] = None - max_new_tokens: Optional[int] = None # litellm.max_tokens - min_new_tokens: Optional[int] = None - length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5} - stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."] - top_k: Optional[int] = None - top_p: Optional[float] = None - repetition_penalty: Optional[float] = None - truncate_input_tokens: Optional[int] = None - include_stop_sequences: Optional[bool] = False - return_options: Optional[Dict[str, bool]] = None - random_seed: Optional[int] = None # e.g 42 - moderations: Optional[dict] = None - stream: Optional[bool] = False - - def __init__( - self, - decoding_method: Optional[str] = None, - temperature: Optional[float] = None, - max_new_tokens: Optional[int] = None, - min_new_tokens: Optional[int] = None, - length_penalty: Optional[dict] = None, - stop_sequences: Optional[List[str]] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - repetition_penalty: Optional[float] = None, - truncate_input_tokens: Optional[int] = None, - include_stop_sequences: Optional[bool] = None, - return_options: Optional[dict] = None, - random_seed: Optional[int] = None, - moderations: Optional[dict] = None, - stream: Optional[bool] = None, - **kwargs, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def is_watsonx_text_param(self, param: str) -> bool: - """ - Determine if user passed in a watsonx.ai text generation param - """ - text_generation_params = [ - "decoding_method", - "max_new_tokens", - "min_new_tokens", - 
"length_penalty", - "stop_sequences", - "top_k", - "repetition_penalty", - "truncate_input_tokens", - "include_stop_sequences", - "return_options", - "random_seed", - "moderations", - "decoding_method", - "min_tokens", - ] - - return param in text_generation_params - - def get_supported_openai_params(self): - return [ - "temperature", # equivalent to temperature - "max_tokens", # equivalent to max_new_tokens - "top_p", # equivalent to top_p - "frequency_penalty", # equivalent to repetition_penalty - "stop", # equivalent to stop_sequences - "seed", # equivalent to random_seed - "stream", # equivalent to stream - ] - - def map_openai_params( - self, non_default_params: dict, optional_params: dict - ) -> dict: - extra_body = {} - for k, v in non_default_params.items(): - if k == "max_tokens": - optional_params["max_new_tokens"] = v - elif k == "stream": - optional_params["stream"] = v - elif k == "temperature": - optional_params["temperature"] = v - elif k == "top_p": - optional_params["top_p"] = v - elif k == "frequency_penalty": - optional_params["repetition_penalty"] = v - elif k == "seed": - optional_params["random_seed"] = v - elif k == "stop": - optional_params["stop_sequences"] = v - elif k == "decoding_method": - extra_body["decoding_method"] = v - elif k == "min_tokens": - extra_body["min_new_tokens"] = v - elif k == "top_k": - extra_body["top_k"] = v - elif k == "truncate_input_tokens": - extra_body["truncate_input_tokens"] = v - elif k == "length_penalty": - extra_body["length_penalty"] = v - elif k == "time_limit": - extra_body["time_limit"] = v - elif k == "return_options": - extra_body["return_options"] = v - - if extra_body: - optional_params["extra_body"] = extra_body - return optional_params - - def get_mapped_special_auth_params(self) -> dict: - """ - Common auth params across bedrock/vertex_ai/azure/watsonx - """ - return { - "project": "watsonx_project", - "region_name": "watsonx_region_name", - "token": "watsonx_token", - } - - def 
map_special_auth_params(self, non_default_params: dict, optional_params: dict): - mapped_params = self.get_mapped_special_auth_params() - - for param, value in non_default_params.items(): - if param in mapped_params: - optional_params[mapped_params[param]] = value - return optional_params - - def get_eu_regions(self) -> List[str]: - """ - Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability - """ - return [ - "eu-de", - "eu-gb", - ] - - def get_us_regions(self) -> List[str]: - """ - Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability - """ - return [ - "us-south", - ] +from .transformation import IBMWatsonXAIConfig def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str: @@ -281,6 +79,7 @@ class IBMWatsonXAI(BaseLLM): def _prepare_text_generation_req( self, model_id: str, + messages: List[AllMessageValues], prompt: str, stream: bool, optional_params: dict, @@ -293,11 +92,13 @@ class IBMWatsonXAI(BaseLLM): # build auth headers api_token = api_params.get("token") self.token = api_token - headers = { - "Authorization": f"Bearer {api_token}", - "Content-Type": "application/json", - "Accept": "application/json", - } + headers = IBMWatsonXAIConfig().validate_environment( + headers={}, + model=model_id, + messages=messages, + optional_params=optional_params, + api_key=api_token, + ) extra_body_params = optional_params.pop("extra_body", {}) optional_params.update(extra_body_params) # init the payload to the text generation call @@ -313,7 +114,6 @@ class IBMWatsonXAI(BaseLLM): if api_params.get("space_id") is None: raise WatsonXAIError( status_code=401, - url=api_params["url"], message="Error: space_id is required for models called using the 'deployment/' endpoint. 
Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.", ) deployment_id = "/".join(model_id.split("/")[1:]) @@ -466,6 +266,7 @@ class IBMWatsonXAI(BaseLLM): req_params = self._prepare_text_generation_req( model_id=model, prompt=prompt, + messages=messages, stream=stream, optional_params=optional_params, print_verbose=print_verbose, diff --git a/litellm/llms/watsonx/completion/transformation.py b/litellm/llms/watsonx/completion/transformation.py new file mode 100644 index 0000000000..ab26890e00 --- /dev/null +++ b/litellm/llms/watsonx/completion/transformation.py @@ -0,0 +1,299 @@ +import asyncio +import json # noqa: E401 +import time +import types +from contextlib import asynccontextmanager, contextmanager +from datetime import datetime +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + AsyncContextManager, + AsyncGenerator, + AsyncIterator, + Callable, + ContextManager, + Dict, + Generator, + Iterator, + List, + Optional, + Union, +) + +import httpx + +import litellm +from litellm.llms.base_llm.transformation import BaseLLMException +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + get_async_httpx_client, +) +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import AllMessageValues +from litellm.types.llms.watsonx import WatsonXAIEndpoint +from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason + +from ...base import BaseLLM +from ...base_llm.transformation import BaseConfig +from ...prompt_templates import factory as ptf +from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class IBMWatsonXAIConfig(BaseConfig): + """ + Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation + (See 
ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params) + + Supported params for all available watsonx.ai foundational models. + + - `decoding_method` (str): One of "greedy" or "sample" + + - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'. + + - `max_new_tokens` (integer): Maximum length of the generated tokens. + + - `min_new_tokens` (integer): Minimum number of new tokens to generate. + + - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index". + + - `stop_sequences` (string[]): list of strings to use as stop sequences. + + - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'. + + - `top_p` (float): top p for sampling - not available when decoding_method='greedy'. + + - `repetition_penalty` (float): token repetition penalty during text generation. + + - `truncate_input_tokens` (integer): Truncate input tokens to this length. + + - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match. + + - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean. + + - `random_seed` (integer): Random seed for text generation. + + - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering. + + - `stream` (bool): If True, the model will return a stream of responses. 
+ """ + + decoding_method: Optional[str] = "sample" + temperature: Optional[float] = None + max_new_tokens: Optional[int] = None # litellm.max_tokens + min_new_tokens: Optional[int] = None + length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5} + stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."] + top_k: Optional[int] = None + top_p: Optional[float] = None + repetition_penalty: Optional[float] = None + truncate_input_tokens: Optional[int] = None + include_stop_sequences: Optional[bool] = False + return_options: Optional[Dict[str, bool]] = None + random_seed: Optional[int] = None # e.g 42 + moderations: Optional[dict] = None + stream: Optional[bool] = False + + def __init__( + self, + decoding_method: Optional[str] = None, + temperature: Optional[float] = None, + max_new_tokens: Optional[int] = None, + min_new_tokens: Optional[int] = None, + length_penalty: Optional[dict] = None, + stop_sequences: Optional[List[str]] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + truncate_input_tokens: Optional[int] = None, + include_stop_sequences: Optional[bool] = None, + return_options: Optional[dict] = None, + random_seed: Optional[int] = None, + moderations: Optional[dict] = None, + stream: Optional[bool] = None, + **kwargs, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return super().get_config() + + def is_watsonx_text_param(self, param: str) -> bool: + """ + Determine if user passed in a watsonx.ai text generation param + """ + text_generation_params = [ + "decoding_method", + "max_new_tokens", + "min_new_tokens", + "length_penalty", + "stop_sequences", + "top_k", + "repetition_penalty", + "truncate_input_tokens", + "include_stop_sequences", + "return_options", + "random_seed", + "moderations", + "decoding_method", 
+ "min_tokens", + ] + + return param in text_generation_params + + def get_supported_openai_params(self, model: str): + return [ + "temperature", # equivalent to temperature + "max_tokens", # equivalent to max_new_tokens + "top_p", # equivalent to top_p + "frequency_penalty", # equivalent to repetition_penalty + "stop", # equivalent to stop_sequences + "seed", # equivalent to random_seed + "stream", # equivalent to stream + ] + + def map_openai_params( + self, + non_default_params: Dict, + optional_params: Dict, + model: str, + drop_params: bool, + ) -> Dict: + extra_body = {} + for k, v in non_default_params.items(): + if k == "max_tokens": + optional_params["max_new_tokens"] = v + elif k == "stream": + optional_params["stream"] = v + elif k == "temperature": + optional_params["temperature"] = v + elif k == "top_p": + optional_params["top_p"] = v + elif k == "frequency_penalty": + optional_params["repetition_penalty"] = v + elif k == "seed": + optional_params["random_seed"] = v + elif k == "stop": + optional_params["stop_sequences"] = v + elif k == "decoding_method": + extra_body["decoding_method"] = v + elif k == "min_tokens": + extra_body["min_new_tokens"] = v + elif k == "top_k": + extra_body["top_k"] = v + elif k == "truncate_input_tokens": + extra_body["truncate_input_tokens"] = v + elif k == "length_penalty": + extra_body["length_penalty"] = v + elif k == "time_limit": + extra_body["time_limit"] = v + elif k == "return_options": + extra_body["return_options"] = v + + if extra_body: + optional_params["extra_body"] = extra_body + return optional_params + + def get_mapped_special_auth_params(self) -> dict: + """ + Common auth params across bedrock/vertex_ai/azure/watsonx + """ + return { + "project": "watsonx_project", + "region_name": "watsonx_region_name", + "token": "watsonx_token", + } + + def map_special_auth_params(self, non_default_params: dict, optional_params: dict): + mapped_params = self.get_mapped_special_auth_params() + + for param, value in 
non_default_params.items(): + if param in mapped_params: + optional_params[mapped_params[param]] = value + return optional_params + + def get_eu_regions(self) -> List[str]: + """ + Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability + """ + return [ + "eu-de", + "eu-gb", + ] + + def get_us_regions(self) -> List[str]: + """ + Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability + """ + return [ + "us-south", + ] + + def _transform_messages( + self, + messages: List[AllMessageValues], + ) -> List[AllMessageValues]: + return messages + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers] + ) -> BaseLLMException: + return WatsonXAIError( + status_code=status_code, message=error_message, headers=headers + ) + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: Dict, + litellm_params: Dict, + headers: Dict, + ) -> Dict: + raise NotImplementedError( + "transform_request not implemented. Done in watsonx/completion handler.py" + ) + + def transform_response( + self, + model: str, + raw_response: httpx.Response, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: Dict, + messages: List[AllMessageValues], + optional_params: Dict, + encoding: str, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + raise NotImplementedError( + "transform_response not implemented. 
Done in watsonx/completion handler.py" + ) + + def validate_environment( + self, + headers: Dict, + model: str, + messages: List[AllMessageValues], + optional_params: Dict, + api_key: Optional[str] = None, + ) -> Dict: + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers diff --git a/litellm/utils.py b/litellm/utils.py index 5321357a87..94f9a41276 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3515,6 +3515,11 @@ def get_optional_params( # noqa: PLR0915 non_default_params=non_default_params, optional_params=optional_params, model=model, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), ) elif custom_llm_provider == "volcengine": supported_params = get_supported_openai_params( @@ -3658,6 +3663,12 @@ def get_optional_params( # noqa: PLR0915 optional_params = litellm.IBMWatsonXAIConfig().map_openai_params( non_default_params=non_default_params, optional_params=optional_params, + model=model, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), ) elif custom_llm_provider == "openai": supported_params = get_supported_openai_params( @@ -6284,7 +6295,14 @@ class ProviderConfigManager: return litellm.VertexAIAnthropicConfig() elif litellm.LlmProviders.CLOUDFLARE == provider: return litellm.CloudflareChatConfig() - + elif litellm.LlmProviders.FIREWORKS_AI == provider: + return litellm.FireworksAIConfig() + elif litellm.LlmProviders.FRIENDLIAI == provider: + return litellm.FriendliaiChatConfig() + elif litellm.LlmProviders.WATSONX == provider: + return litellm.IBMWatsonXChatConfig() + elif litellm.LlmProviders.WATSONX_TEXT == provider: + return litellm.IBMWatsonXAIConfig() return litellm.OpenAIGPTConfig() diff --git a/tests/llm_translation/test_fireworks_ai_translation.py b/tests/llm_translation/test_fireworks_ai_translation.py index 
1edc79568c..660c96cf15 100644 --- a/tests/llm_translation/test_fireworks_ai_translation.py +++ b/tests/llm_translation/test_fireworks_ai_translation.py @@ -7,7 +7,7 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig +from litellm.llms.fireworks_ai.chat.transformation import FireworksAIConfig from base_llm_unit_tests import BaseLLMChatTest fireworks = FireworksAIConfig() @@ -15,21 +15,27 @@ fireworks = FireworksAIConfig() def test_map_openai_params_tool_choice(): # Test case 1: tool_choice is "required" - result = fireworks.map_openai_params({"tool_choice": "required"}, {}, "some_model") + result = fireworks.map_openai_params( + {"tool_choice": "required"}, {}, "some_model", drop_params=False + ) assert result == {"tool_choice": "any"} # Test case 2: tool_choice is "auto" - result = fireworks.map_openai_params({"tool_choice": "auto"}, {}, "some_model") + result = fireworks.map_openai_params( + {"tool_choice": "auto"}, {}, "some_model", drop_params=False + ) assert result == {"tool_choice": "auto"} # Test case 3: tool_choice is not present result = fireworks.map_openai_params( - {"some_other_param": "value"}, {}, "some_model" + {"some_other_param": "value"}, {}, "some_model", drop_params=False ) assert result == {} # Test case 4: tool_choice is None - result = fireworks.map_openai_params({"tool_choice": None}, {}, "some_model") + result = fireworks.map_openai_params( + {"tool_choice": None}, {}, "some_model", drop_params=False + ) assert result == {"tool_choice": None} @@ -55,7 +61,7 @@ def test_map_response_format(): }, } result = fireworks.map_openai_params( - {"response_format": response_format}, {}, "some_model" + {"response_format": response_format}, {}, "some_model", drop_params=False ) assert result == { "response_format": { diff --git a/tests/llm_translation/test_max_completion_tokens.py 
b/tests/llm_translation/test_max_completion_tokens.py index 0b1e9b71a0..4452bd0fc9 100644 --- a/tests/llm_translation/test_max_completion_tokens.py +++ b/tests/llm_translation/test_max_completion_tokens.py @@ -154,13 +154,18 @@ def test_all_model_configs(): {"max_completion_tokens": 10}, {}, "llama3", drop_params=False ) == {"max_tokens": 10} - from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import ( + from litellm.llms.fireworks_ai.chat.transformation import ( FireworksAIConfig, ) - assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params() + assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params( + model="llama3" + ) assert FireworksAIConfig().map_openai_params( - {"max_completion_tokens": 10}, {}, "llama3" + model="llama3", + non_default_params={"max_completion_tokens": 10}, + optional_params={}, + drop_params=False, ) == {"max_tokens": 10} from litellm.llms.huggingface_restapi import HuggingfaceConfig