refactor(fireworks_ai/): inherit from openai like base config (#7146)

* refactor(fireworks_ai/): inherit from openai like base config

refactors fireworks ai to use a common config

* test: fix import in test

* refactor(watsonx/): refactor watsonx to use llm base config

refactors chat + completion routes to base config path

* fix: fix linting error

* test: fix test

* fix: fix test
Authored by Krish Dholakia on 2024-12-10 16:15:19 -08:00; committed by GitHub
parent 6a9225fac2
commit 4eeaaeeacd
15 changed files with 449 additions and 307 deletions
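The pattern this commit applies everywhere, as a minimal sketch (the base class and its module path are taken from the diff below; the provider subclass is hypothetical):

    from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig

    class MyProviderConfig(OpenAIGPTConfig):
        # Hypothetical provider config, for illustration only. Shared logic
        # (get_config, default param handling) now comes from the OpenAI-like
        # base; a provider overrides only where its API deviates.
        def get_supported_openai_params(self, model: str) -> list:
            return super().get_supported_openai_params(model=model)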


@@ -1163,10 +1163,11 @@ nvidiaNimEmbeddingConfig = NvidiaNimEmbeddingConfig()
from .llms.cerebras.chat import CerebrasConfig
from .llms.sambanova.chat import SambanovaConfig
from .llms.ai21.chat import AI21ChatConfig
from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
from .llms.fireworks_ai.chat.transformation import FireworksAIConfig
from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
FireworksAIEmbeddingConfig,
)
from .llms.friendliai.chat.transformation import FriendliaiChatConfig
from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
from .llms.xai.chat.transformation import XAIChatConfig
from .llms.volcengine import VolcEngineConfig
@@ -1183,7 +1184,7 @@ from .llms.lm_studio.chat.transformation import LMStudioChatConfig
from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig
from .llms.perplexity.chat.transformation import PerplexityChatConfig
from .llms.azure.chat.o1_transformation import AzureOpenAIO1Config
from .llms.watsonx.completion.handler import IBMWatsonXAIConfig
from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
from .main import * # type: ignore
from .integrations import *


@@ -3,55 +3,54 @@ DEFAULT_BATCH_SIZE = 512
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
DEFAULT_MAX_RETRIES = 2
LITELLM_CHAT_PROVIDERS = [
"openai",
"openai_like",
"xai",
"custom_openai",
"text-completion-openai",
"cohere",
"cohere_chat",
"clarifai",
"anthropic",
"anthropic_text",
"replicate",
"huggingface",
"together_ai",
"openrouter",
"vertex_ai",
"vertex_ai_beta",
"palm",
"gemini",
"ai21",
"baseten",
"azure",
"azure_text",
"azure_ai",
"sagemaker",
"sagemaker_chat",
"bedrock",
"vllm",
"nlp_cloud",
"petals",
"oobabooga",
"ollama",
"ollama_chat",
"deepinfra",
"perplexity",
"anyscale",
"mistral",
"groq",
"nvidia_nim",
"cerebras",
"ai21_chat",
"volcengine",
"codestral",
"text-completion-codestral",
"deepseek",
"sambanova",
"maritalk",
"voyage",
"cloudflare",
"xinference",
# "openai",
# "openai_like",
# "xai",
# "custom_openai",
# "text-completion-openai",
# "cohere",
# "cohere_chat",
# "clarifai",
# "anthropic",
# "anthropic_text",
# "replicate",
# "huggingface",
# "together_ai",
# "openrouter",
# "vertex_ai",
# "vertex_ai_beta",
# "palm",
# "gemini",
# "ai21",
# "baseten",
# "azure",
# "azure_text",
# "azure_ai",
# "sagemaker",
# "sagemaker_chat",
# "bedrock",
# "vllm",
# "nlp_cloud",
# "petals",
# "oobabooga",
# "ollama",
# "ollama_chat",
# "deepinfra",
# "perplexity",
# "anyscale",
# "mistral",
# "groq",
# "nvidia_nim",
# "cerebras",
# "ai21_chat",
# "volcengine",
# "codestral",
# "text-completion-codestral",
# "deepseek",
# "sambanova",
# "maritalk",
# "voyage",
# "cloudflare",
"fireworks_ai",
"friendliai",
"watsonx",


@@ -495,7 +495,7 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
api_base,
dynamic_api_key,
) = litellm.FireworksAIConfig()._get_openai_compatible_provider_info(
model, api_base, api_key
model=model, api_base=api_base, api_key=api_key
)
elif custom_llm_provider == "azure_ai":
(

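A hedged sketch of the call above in isolation (the model name is illustrative; how the base URL and key are resolved when None is provider-specific and not shown in this hunk):

    import litellm

    provider_info = litellm.FireworksAIConfig()._get_openai_compatible_provider_info(
        model="some-fireworks-model",  # illustrative
        api_base=None,  # resolved from provider defaults / environment when unset
        api_key=None,
    )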

@@ -40,7 +40,7 @@ def get_supported_openai_params( # noqa: PLR0915
model=model
)
else:
return litellm.FireworksAIConfig().get_supported_openai_params()
return litellm.FireworksAIConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "nvidia_nim":
if request_type == "chat_completion":
return litellm.nvidiaNimConfig.get_supported_openai_params(model=model)
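The supported-params lookup is now model-aware; a short sketch (the assertion mirrors test_all_model_configs at the bottom of this commit):

    import litellm

    params = litellm.FireworksAIConfig().get_supported_openai_params(model="llama3")
    assert "max_completion_tokens" in params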


@@ -9,6 +9,7 @@ from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
@@ -33,7 +34,7 @@ class BaseLLMException(Exception):
self,
status_code: int,
message: str,
headers: Optional[httpx.Headers] = None,
headers: Optional[Union[Dict, httpx.Headers]] = None,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
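With the widened annotation, callers can pass a plain dict as well as httpx.Headers; a minimal sketch:

    import httpx
    from litellm.llms.base_llm.transformation import BaseLLMException

    # Both forms are now accepted; previously only httpx.Headers type-checked.
    BaseLLMException(status_code=429, message="rate limited", headers={"retry-after": "1"})
    BaseLLMException(status_code=429, message="rate limited", headers=httpx.Headers({"retry-after": "1"}))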


@@ -3,10 +3,11 @@ from typing import Literal, Optional, Tuple, Union
from litellm.secret_managers.main import get_secret_str
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
from ..embed.fireworks_ai_transformation import FireworksAIEmbeddingConfig
class FireworksAIConfig:
class FireworksAIConfig(OpenAIGPTConfig):
"""
Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions
@@ -56,23 +57,9 @@ class FireworksAIConfig:
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
return super().get_config()
def get_supported_openai_params(self):
def get_supported_openai_params(self, model: str):
return [
"stream",
"tools",
@@ -98,8 +85,10 @@ class FireworksAIConfig:
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_openai_params = self.get_supported_openai_params()
supported_openai_params = self.get_supported_openai_params(model=model)
for param, value in non_default_params.items():
if param == "tool_choice":
if value == "required":
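The hunk ends inside the tool_choice branch; its effect is clearest as a call (this mirrors the updated unit tests at the bottom of this commit):

    result = FireworksAIConfig().map_openai_params(
        {"tool_choice": "required"}, {}, "some_model", drop_params=False
    )
    assert result == {"tool_choice": "any"}  # OpenAI's "required" becomes Fireworks' "any"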


@@ -0,0 +1,24 @@
"""
Translate from OpenAI's `/v1/chat/completions` to Friendliai's `/v1/chat/completions`
"""
import json
import types
from typing import List, Optional, Tuple, Union
from pydantic import BaseModel
import litellm
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from ...openai_like.chat.handler import OpenAILikeChatConfig
class FriendliaiChatConfig(OpenAILikeChatConfig):
pass
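Because the class body is just `pass`, Friendliai inherits the whole OpenAI-like chain (OpenAILikeChatConfig -> OpenAIGPTConfig); a hedged sketch with an illustrative model name:

    import litellm

    config = litellm.FriendliaiChatConfig()
    params = config.get_supported_openai_params(model="some-friendliai-model")  # inherited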


@@ -19,7 +19,10 @@ from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
class OpenAILikeChatConfig(OpenAIGPTConfig):
def _get_openai_compatible_provider_info(
self, api_base: Optional[str], api_key: Optional[str]
self,
api_base: Optional[str],
api_key: Optional[str],
model: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
api_base = api_base or get_secret_str("OPENAI_LIKE_API_BASE") # type: ignore
dynamic_api_key = (


@@ -21,7 +21,6 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
if api_params.get("space_id") is None:
raise WatsonXAIError(
status_code=401,
url=api_params["url"],
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
)
deployment_id = "/".join(model.split("/")[1:])


@@ -1,24 +1,23 @@
from typing import Callable, Optional, cast
from typing import Callable, Dict, Optional, Union, cast
import httpx
import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAPIParams
class WatsonXAIError(Exception):
def __init__(self, status_code, message, url: Optional[str] = None):
self.status_code = status_code
self.message = message
url = url or "https://https://us-south.ml.cloud.ibm.com"
self.request = httpx.Request(method="POST", url=url)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class WatsonXAIError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[Dict, httpx.Headers]] = None,
):
super().__init__(status_code=status_code, message=message, headers=headers)
iam_token_cache = InMemoryCache()
@@ -151,13 +150,11 @@ def _get_api_params(
elif token is None and api_key is None:
raise WatsonXAIError(
status_code=401,
url=url,
message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
)
if project_id is None:
raise WatsonXAIError(
status_code=401,
url=url,
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
)
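A minimal sketch of the refactored error in use (the module path is inferred from the import in the handler diff below; the message is illustrative):

    from litellm.llms.base_llm.transformation import BaseLLMException
    from litellm.llms.watsonx.common_utils import WatsonXAIError

    try:
        raise WatsonXAIError(status_code=401, message="missing watsonx credentials")
    except BaseLLMException as e:  # WatsonXAIError now flows through shared error handling
        print(e.status_code, e.message)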


@@ -29,216 +29,14 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
from ...base import BaseLLM
from ...prompt_templates import factory as ptf
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token
class IBMWatsonXAIConfig:
"""
Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
(See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)
Supported params for all available watsonx.ai foundational models.
- `decoding_method` (str): One of "greedy" or "sample"
- `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'.
- `max_new_tokens` (integer): Maximum length of the generated tokens.
- `min_new_tokens` (integer): Minimum number of new tokens to be generated.
- `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".
- `stop_sequences` (string[]): list of strings to use as stop sequences.
- `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.
- `top_p` (float): top p for sampling - not available when decoding_method='greedy'.
- `repetition_penalty` (float): token repetition penalty during text generation.
- `truncate_input_tokens` (integer): Truncate input tokens to this length.
- `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match.
- `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean.
- `random_seed` (integer): Random seed for text generation.
- `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering.
- `stream` (bool): If True, the model will return a stream of responses.
"""
decoding_method: Optional[str] = "sample"
temperature: Optional[float] = None
max_new_tokens: Optional[int] = None # litellm.max_tokens
min_new_tokens: Optional[int] = None
length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5}
stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."]
top_k: Optional[int] = None
top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
truncate_input_tokens: Optional[int] = None
include_stop_sequences: Optional[bool] = False
return_options: Optional[Dict[str, bool]] = None
random_seed: Optional[int] = None # e.g 42
moderations: Optional[dict] = None
stream: Optional[bool] = False
def __init__(
self,
decoding_method: Optional[str] = None,
temperature: Optional[float] = None,
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
length_penalty: Optional[dict] = None,
stop_sequences: Optional[List[str]] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
truncate_input_tokens: Optional[int] = None,
include_stop_sequences: Optional[bool] = None,
return_options: Optional[dict] = None,
random_seed: Optional[int] = None,
moderations: Optional[dict] = None,
stream: Optional[bool] = None,
**kwargs,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def is_watsonx_text_param(self, param: str) -> bool:
"""
Determine if user passed in a watsonx.ai text generation param
"""
text_generation_params = [
"decoding_method",
"max_new_tokens",
"min_new_tokens",
"length_penalty",
"stop_sequences",
"top_k",
"repetition_penalty",
"truncate_input_tokens",
"include_stop_sequences",
"return_options",
"random_seed",
"moderations",
"decoding_method",
"min_tokens",
]
return param in text_generation_params
def get_supported_openai_params(self):
return [
"temperature", # equivalent to temperature
"max_tokens", # equivalent to max_new_tokens
"top_p", # equivalent to top_p
"frequency_penalty", # equivalent to repetition_penalty
"stop", # equivalent to stop_sequences
"seed", # equivalent to random_seed
"stream", # equivalent to stream
]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
extra_body = {}
for k, v in non_default_params.items():
if k == "max_tokens":
optional_params["max_new_tokens"] = v
elif k == "stream":
optional_params["stream"] = v
elif k == "temperature":
optional_params["temperature"] = v
elif k == "top_p":
optional_params["top_p"] = v
elif k == "frequency_penalty":
optional_params["repetition_penalty"] = v
elif k == "seed":
optional_params["random_seed"] = v
elif k == "stop":
optional_params["stop_sequences"] = v
elif k == "decoding_method":
extra_body["decoding_method"] = v
elif k == "min_tokens":
extra_body["min_new_tokens"] = v
elif k == "top_k":
extra_body["top_k"] = v
elif k == "truncate_input_tokens":
extra_body["truncate_input_tokens"] = v
elif k == "length_penalty":
extra_body["length_penalty"] = v
elif k == "time_limit":
extra_body["time_limit"] = v
elif k == "return_options":
extra_body["return_options"] = v
if extra_body:
optional_params["extra_body"] = extra_body
return optional_params
def get_mapped_special_auth_params(self) -> dict:
"""
Common auth params across bedrock/vertex_ai/azure/watsonx
"""
return {
"project": "watsonx_project",
"region_name": "watsonx_region_name",
"token": "watsonx_token",
}
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
mapped_params = self.get_mapped_special_auth_params()
for param, value in non_default_params.items():
if param in mapped_params:
optional_params[mapped_params[param]] = value
return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
"""
return [
"eu-de",
"eu-gb",
]
def get_us_regions(self) -> List[str]:
"""
Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
"""
return [
"us-south",
]
from .transformation import IBMWatsonXAIConfig
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str:
@@ -281,6 +79,7 @@ class IBMWatsonXAI(BaseLLM):
def _prepare_text_generation_req(
self,
model_id: str,
messages: List[AllMessageValues],
prompt: str,
stream: bool,
optional_params: dict,
@@ -293,11 +92,13 @@ class IBMWatsonXAI(BaseLLM):
# build auth headers
api_token = api_params.get("token")
self.token = api_token
headers = {
"Authorization": f"Bearer {api_token}",
"Content-Type": "application/json",
"Accept": "application/json",
}
headers = IBMWatsonXAIConfig().validate_environment(
headers={},
model=model_id,
messages=messages,
optional_params=optional_params,
api_key=api_token,
)
extra_body_params = optional_params.pop("extra_body", {})
optional_params.update(extra_body_params)
# init the payload to the text generation call
@@ -313,7 +114,6 @@ class IBMWatsonXAI(BaseLLM):
if api_params.get("space_id") is None:
raise WatsonXAIError(
status_code=401,
url=api_params["url"],
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
)
deployment_id = "/".join(model_id.split("/")[1:])
@@ -466,6 +266,7 @@ class IBMWatsonXAI(BaseLLM):
req_params = self._prepare_text_generation_req(
model_id=model,
prompt=prompt,
messages=messages,
stream=stream,
optional_params=optional_params,
print_verbose=print_verbose,
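Header construction now goes through the config's validate_environment (added in the new transformation module below); per that implementation, the call above produces:

    headers = IBMWatsonXAIConfig().validate_environment(
        headers={},
        model="ibm/granite-13b-chat-v2",  # illustrative model id
        messages=[],
        optional_params={},
        api_key=api_token,
    )
    # -> {"Content-Type": "application/json", "Accept": "application/json",
    #     "Authorization": f"Bearer {api_token}"}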


@@ -0,0 +1,299 @@
import asyncio
import json # noqa: E401
import time
import types
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
AsyncGenerator,
AsyncIterator,
Callable,
ContextManager,
Dict,
Generator,
Iterator,
List,
Optional,
Union,
)
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
from ...base import BaseLLM
from ...base_llm.transformation import BaseConfig
from ...prompt_templates import factory as ptf
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class IBMWatsonXAIConfig(BaseConfig):
"""
Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
(See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)
Supported params for all available watsonx.ai foundational models.
- `decoding_method` (str): One of "greedy" or "sample"
- `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'.
- `max_new_tokens` (integer): Maximum length of the generated tokens.
- `min_new_tokens` (integer): Minimum number of new tokens to be generated.
- `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".
- `stop_sequences` (string[]): list of strings to use as stop sequences.
- `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.
- `top_p` (float): top p for sampling - not available when decoding_method='greedy'.
- `repetition_penalty` (float): token repetition penalty during text generation.
- `truncate_input_tokens` (integer): Truncate input tokens to this length.
- `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match.
- `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean.
- `random_seed` (integer): Random seed for text generation.
- `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering.
- `stream` (bool): If True, the model will return a stream of responses.
"""
decoding_method: Optional[str] = "sample"
temperature: Optional[float] = None
max_new_tokens: Optional[int] = None # litellm.max_tokens
min_new_tokens: Optional[int] = None
length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5}
stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."]
top_k: Optional[int] = None
top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
truncate_input_tokens: Optional[int] = None
include_stop_sequences: Optional[bool] = False
return_options: Optional[Dict[str, bool]] = None
random_seed: Optional[int] = None # e.g 42
moderations: Optional[dict] = None
stream: Optional[bool] = False
def __init__(
self,
decoding_method: Optional[str] = None,
temperature: Optional[float] = None,
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
length_penalty: Optional[dict] = None,
stop_sequences: Optional[List[str]] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
truncate_input_tokens: Optional[int] = None,
include_stop_sequences: Optional[bool] = None,
return_options: Optional[dict] = None,
random_seed: Optional[int] = None,
moderations: Optional[dict] = None,
stream: Optional[bool] = None,
**kwargs,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def is_watsonx_text_param(self, param: str) -> bool:
"""
Determine if user passed in a watsonx.ai text generation param
"""
text_generation_params = [
"decoding_method",
"max_new_tokens",
"min_new_tokens",
"length_penalty",
"stop_sequences",
"top_k",
"repetition_penalty",
"truncate_input_tokens",
"include_stop_sequences",
"return_options",
"random_seed",
"moderations",
"decoding_method",
"min_tokens",
]
return param in text_generation_params
def get_supported_openai_params(self, model: str):
return [
"temperature", # equivalent to temperature
"max_tokens", # equivalent to max_new_tokens
"top_p", # equivalent to top_p
"frequency_penalty", # equivalent to repetition_penalty
"stop", # equivalent to stop_sequences
"seed", # equivalent to random_seed
"stream", # equivalent to stream
]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str,
drop_params: bool,
) -> Dict:
extra_body = {}
for k, v in non_default_params.items():
if k == "max_tokens":
optional_params["max_new_tokens"] = v
elif k == "stream":
optional_params["stream"] = v
elif k == "temperature":
optional_params["temperature"] = v
elif k == "top_p":
optional_params["top_p"] = v
elif k == "frequency_penalty":
optional_params["repetition_penalty"] = v
elif k == "seed":
optional_params["random_seed"] = v
elif k == "stop":
optional_params["stop_sequences"] = v
elif k == "decoding_method":
extra_body["decoding_method"] = v
elif k == "min_tokens":
extra_body["min_new_tokens"] = v
elif k == "top_k":
extra_body["top_k"] = v
elif k == "truncate_input_tokens":
extra_body["truncate_input_tokens"] = v
elif k == "length_penalty":
extra_body["length_penalty"] = v
elif k == "time_limit":
extra_body["time_limit"] = v
elif k == "return_options":
extra_body["return_options"] = v
if extra_body:
optional_params["extra_body"] = extra_body
return optional_params
def get_mapped_special_auth_params(self) -> dict:
"""
Common auth params across bedrock/vertex_ai/azure/watsonx
"""
return {
"project": "watsonx_project",
"region_name": "watsonx_region_name",
"token": "watsonx_token",
}
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
mapped_params = self.get_mapped_special_auth_params()
for param, value in non_default_params.items():
if param in mapped_params:
optional_params[mapped_params[param]] = value
return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
"""
return [
"eu-de",
"eu-gb",
]
def get_us_regions(self) -> List[str]:
"""
Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
"""
return [
"us-south",
]
def _transform_messages(
self,
messages: List[AllMessageValues],
) -> List[AllMessageValues]:
return messages
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
) -> BaseLLMException:
return WatsonXAIError(
status_code=status_code, message=error_message, headers=headers
)
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
headers: Dict,
) -> Dict:
raise NotImplementedError(
"transform_request not implemented. Done in watsonx/completion handler.py"
)
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
raise NotImplementedError(
"transform_response not implemented. Done in watsonx/completion handler.py"
)
def validate_environment(
self,
headers: Dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
) -> Dict:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
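A worked example of map_openai_params above, chosen to exercise both the direct mappings and the extra_body path:

    IBMWatsonXAIConfig().map_openai_params(
        non_default_params={"max_tokens": 50, "seed": 42, "top_k": 10},
        optional_params={},
        model="ibm/granite-13b-chat-v2",  # illustrative
        drop_params=False,
    )
    # -> {"max_new_tokens": 50, "random_seed": 42, "extra_body": {"top_k": 10}}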


@@ -3515,6 +3515,11 @@ def get_optional_params( # noqa: PLR0915
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "volcengine":
supported_params = get_supported_openai_params(
@@ -3658,6 +3663,12 @@ def get_optional_params( # noqa: PLR0915
optional_params = litellm.IBMWatsonXAIConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "openai":
supported_params = get_supported_openai_params(
@@ -6284,7 +6295,14 @@ class ProviderConfigManager:
return litellm.VertexAIAnthropicConfig()
elif litellm.LlmProviders.CLOUDFLARE == provider:
return litellm.CloudflareChatConfig()
elif litellm.LlmProviders.FIREWORKS_AI == provider:
return litellm.FireworksAIConfig()
elif litellm.LlmProviders.FRIENDLIAI == provider:
return litellm.FriendliaiChatConfig()
elif litellm.LlmProviders.WATSONX == provider:
return litellm.IBMWatsonXChatConfig()
elif litellm.LlmProviders.WATSONX_TEXT == provider:
return litellm.IBMWatsonXAIConfig()
return litellm.OpenAIGPTConfig()
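A hedged sketch of the dispatch this enables (the manager's method name is an assumption; only the class name and the returned configs appear in this hunk):

    import litellm
    from litellm.utils import ProviderConfigManager

    config = ProviderConfigManager.get_provider_chat_config(  # method name assumed
        model="some-model", provider=litellm.LlmProviders.FIREWORKS_AI
    )
    assert isinstance(config, litellm.FireworksAIConfig)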


@@ -7,7 +7,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
from litellm.llms.fireworks_ai.chat.transformation import FireworksAIConfig
from base_llm_unit_tests import BaseLLMChatTest
fireworks = FireworksAIConfig()
@@ -15,21 +15,27 @@ fireworks = FireworksAIConfig()
def test_map_openai_params_tool_choice():
# Test case 1: tool_choice is "required"
result = fireworks.map_openai_params({"tool_choice": "required"}, {}, "some_model")
result = fireworks.map_openai_params(
{"tool_choice": "required"}, {}, "some_model", drop_params=False
)
assert result == {"tool_choice": "any"}
# Test case 2: tool_choice is "auto"
result = fireworks.map_openai_params({"tool_choice": "auto"}, {}, "some_model")
result = fireworks.map_openai_params(
{"tool_choice": "auto"}, {}, "some_model", drop_params=False
)
assert result == {"tool_choice": "auto"}
# Test case 3: tool_choice is not present
result = fireworks.map_openai_params(
{"some_other_param": "value"}, {}, "some_model"
{"some_other_param": "value"}, {}, "some_model", drop_params=False
)
assert result == {}
# Test case 4: tool_choice is None
result = fireworks.map_openai_params({"tool_choice": None}, {}, "some_model")
result = fireworks.map_openai_params(
{"tool_choice": None}, {}, "some_model", drop_params=False
)
assert result == {"tool_choice": None}
@@ -55,7 +61,7 @@ def test_map_response_format():
},
}
result = fireworks.map_openai_params(
{"response_format": response_format}, {}, "some_model"
{"response_format": response_format}, {}, "some_model", drop_params=False
)
assert result == {
"response_format": {


@@ -154,13 +154,18 @@ def test_all_model_configs():
{"max_completion_tokens": 10}, {}, "llama3", drop_params=False
) == {"max_tokens": 10}
from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import (
from litellm.llms.fireworks_ai.chat.transformation import (
FireworksAIConfig,
)
assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params()
assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params(
model="llama3"
)
assert FireworksAIConfig().map_openai_params(
{"max_completion_tokens": 10}, {}, "llama3"
model="llama3",
non_default_params={"max_completion_tokens": 10},
optional_params={},
drop_params=False,
) == {"max_tokens": 10}
from litellm.llms.huggingface_restapi import HuggingfaceConfig