Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
refactor(fireworks_ai/): inherit from openai like base config (#7146)
* refactor(fireworks_ai/): inherit from openai like base config. Refactors fireworks ai to use a common config.
* test: fix import in test
* refactor(watsonx/): refactor watsonx to use llm base config. Refactors chat + completion routes to base config path.
* fix: fix linting error
* test: fix test
* fix: fix test
This commit is contained in:
parent 2fb2801eb4
commit 311432ca17
15 changed files with 449 additions and 307 deletions
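The diff below follows one pattern throughout: each provider config class stops re-implementing its own get_config() and param plumbing and inherits it from an OpenAI-compatible base class. A minimal sketch of that shape (the base-class internals here are simplified assumptions, not the actual litellm implementation):

    from typing import Optional


    class OpenAIGPTConfig:
        """Stand-in for the shared OpenAI-compatible base config."""

        @classmethod
        def get_config(cls):
            # Collect non-dunder, non-callable class attributes as the config dict.
            return {
                k: v
                for k, v in cls.__dict__.items()
                if not k.startswith("__")
                and not isinstance(v, (classmethod, staticmethod))
                and not callable(v)
                and v is not None
            }


    class FireworksAIConfig(OpenAIGPTConfig):
        # Provider-specific defaults live here; config collection and shared
        # param handling are inherited from the base class.
        tool_choice: Optional[str] = None

        @classmethod
        def get_config(cls):
            return super().get_config()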
@@ -1163,10 +1163,11 @@ nvidiaNimEmbeddingConfig = NvidiaNimEmbeddingConfig()
from .llms.cerebras.chat import CerebrasConfig
from .llms.sambanova.chat import SambanovaConfig
from .llms.ai21.chat import AI21ChatConfig
from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
from .llms.fireworks_ai.chat.transformation import FireworksAIConfig
from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
    FireworksAIEmbeddingConfig,
)
from .llms.friendliai.chat.transformation import FriendliaiChatConfig
from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
from .llms.xai.chat.transformation import XAIChatConfig
from .llms.volcengine import VolcEngineConfig
@@ -1183,7 +1184,7 @@ from .llms.lm_studio.chat.transformation import LMStudioChatConfig
from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig
from .llms.perplexity.chat.transformation import PerplexityChatConfig
from .llms.azure.chat.o1_transformation import AzureOpenAIO1Config
from .llms.watsonx.completion.handler import IBMWatsonXAIConfig
from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
from .main import *  # type: ignore
from .integrations import *
@@ -3,55 +3,54 @@ DEFAULT_BATCH_SIZE = 512
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
DEFAULT_MAX_RETRIES = 2
LITELLM_CHAT_PROVIDERS = [
    "openai",
    "openai_like",
    "xai",
    "custom_openai",
    "text-completion-openai",
    "cohere",
    "cohere_chat",
    "clarifai",
    "anthropic",
    "anthropic_text",
    "replicate",
    "huggingface",
    "together_ai",
    "openrouter",
    "vertex_ai",
    "vertex_ai_beta",
    "palm",
    "gemini",
    "ai21",
    "baseten",
    "azure",
    "azure_text",
    "azure_ai",
    "sagemaker",
    "sagemaker_chat",
    "bedrock",
    "vllm",
    "nlp_cloud",
    "petals",
    "oobabooga",
    "ollama",
    "ollama_chat",
    "deepinfra",
    "perplexity",
    "anyscale",
    "mistral",
    "groq",
    "nvidia_nim",
    "cerebras",
    "ai21_chat",
    "volcengine",
    "codestral",
    "text-completion-codestral",
    "deepseek",
    "sambanova",
    "maritalk",
    "voyage",
    "cloudflare",
    "xinference",
    # "openai",
    # "openai_like",
    # "xai",
    # "custom_openai",
    # "text-completion-openai",
    # "cohere",
    # "cohere_chat",
    # "clarifai",
    # "anthropic",
    # "anthropic_text",
    # "replicate",
    # "huggingface",
    # "together_ai",
    # "openrouter",
    # "vertex_ai",
    # "vertex_ai_beta",
    # "palm",
    # "gemini",
    # "ai21",
    # "baseten",
    # "azure",
    # "azure_text",
    # "azure_ai",
    # "sagemaker",
    # "sagemaker_chat",
    # "bedrock",
    # "vllm",
    # "nlp_cloud",
    # "petals",
    # "oobabooga",
    # "ollama",
    # "ollama_chat",
    # "deepinfra",
    # "perplexity",
    # "anyscale",
    # "mistral",
    # "groq",
    # "nvidia_nim",
    # "cerebras",
    # "ai21_chat",
    # "volcengine",
    # "codestral",
    # "text-completion-codestral",
    # "deepseek",
    # "sambanova",
    # "maritalk",
    # "voyage",
    # "cloudflare",
    "fireworks_ai",
    "friendliai",
    "watsonx",
@@ -495,7 +495,7 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
                api_base,
                dynamic_api_key,
            ) = litellm.FireworksAIConfig()._get_openai_compatible_provider_info(
                model, api_base, api_key
                model=model, api_base=api_base, api_key=api_key
            )
        elif custom_llm_provider == "azure_ai":
            (
@@ -40,7 +40,7 @@ def get_supported_openai_params(  # noqa: PLR0915
                model=model
            )
        else:
            return litellm.FireworksAIConfig().get_supported_openai_params()
            return litellm.FireworksAIConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "nvidia_nim":
        if request_type == "chat_completion":
            return litellm.nvidiaNimConfig.get_supported_openai_params(model=model)
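With this change the Fireworks AI supported-params lookup is model-aware. A small usage sketch (the model name is an illustrative placeholder):

    import litellm

    params = litellm.FireworksAIConfig().get_supported_openai_params(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct"
    )
    # Per the updated tests further down, the list includes entries such as
    # "tools" and "max_completion_tokens".
    assert "max_completion_tokens" in params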
@@ -9,6 +9,7 @@ from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
@@ -33,7 +34,7 @@ class BaseLLMException(Exception):
        self,
        status_code: int,
        message: str,
        headers: Optional[httpx.Headers] = None,
        headers: Optional[Union[Dict, httpx.Headers]] = None,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
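The widened headers type means provider errors can hand back a plain dict of response headers rather than an httpx.Headers object. A brief sketch (the header key and value are made-up examples):

    from litellm.llms.base_llm.transformation import BaseLLMException

    err = BaseLLMException(
        status_code=401,
        message="Error: API key or token not found.",
        headers={"x-request-id": "abc123"},  # a plain dict is now accepted
    )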
@@ -3,10 +3,11 @@ from typing import Literal, Optional, Tuple, Union

from litellm.secret_managers.main import get_secret_str

from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
from ..embed.fireworks_ai_transformation import FireworksAIEmbeddingConfig


class FireworksAIConfig:
class FireworksAIConfig(OpenAIGPTConfig):
    """
    Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions
@@ -56,23 +57,9 @@ class FireworksAIConfig:

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
        return super().get_config()

    def get_supported_openai_params(self):
    def get_supported_openai_params(self, model: str):
        return [
            "stream",
            "tools",
@@ -98,8 +85,10 @@ class FireworksAIConfig:
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params()

        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
            if param == "tool_choice":
                if value == "required":
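A usage sketch mirroring the updated unit tests further down: map_openai_params now takes model and drop_params, and a tool_choice of "required" is translated to Fireworks AI's "any".

    from litellm.llms.fireworks_ai.chat.transformation import FireworksAIConfig

    mapped = FireworksAIConfig().map_openai_params(
        non_default_params={"tool_choice": "required"},
        optional_params={},
        model="some_model",
        drop_params=False,
    )
    assert mapped == {"tool_choice": "any"}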
litellm/llms/friendliai/chat/transformation.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""
Translate from OpenAI's `/v1/chat/completions` to Friendliai's `/v1/chat/completions`
"""

import json
import types
from typing import List, Optional, Tuple, Union

from pydantic import BaseModel

import litellm
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionAssistantMessage,
    ChatCompletionToolParam,
    ChatCompletionToolParamFunctionChunk,
)

from ...openai_like.chat.handler import OpenAILikeChatConfig


class FriendliaiChatConfig(OpenAILikeChatConfig):
    pass
@@ -19,7 +19,10 @@ from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig

class OpenAILikeChatConfig(OpenAIGPTConfig):
    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: Optional[str] = None,
    ) -> Tuple[Optional[str], Optional[str]]:
        api_base = api_base or get_secret_str("OPENAI_LIKE_API_BASE")  # type: ignore
        dynamic_api_key = (
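Sketch of the fallback behaviour for OpenAI-like providers. The return ordering of api_base and the dynamic key follows the unpacking shown in the get_llm_provider hunk above; treat it as an assumption for OpenAILikeChatConfig itself.

    import os
    from litellm.llms.openai_like.chat.handler import OpenAILikeChatConfig

    # When api_base/api_key are not passed explicitly, they are resolved from
    # environment variables such as OPENAI_LIKE_API_BASE.
    os.environ["OPENAI_LIKE_API_BASE"] = "http://localhost:8000/v1"
    api_base, dynamic_api_key = OpenAILikeChatConfig()._get_openai_compatible_provider_info(
        api_base=None, api_key=None, model=None
    )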
@@ -21,7 +21,6 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
        if api_params.get("space_id") is None:
            raise WatsonXAIError(
                status_code=401,
                url=api_params["url"],
                message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
            )
        deployment_id = "/".join(model.split("/")[1:])
@@ -1,24 +1,23 @@
from typing import Callable, Optional, cast
from typing import Callable, Dict, Optional, Union, cast

import httpx

import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAPIParams


class WatsonXAIError(Exception):
    def __init__(self, status_code, message, url: Optional[str] = None):
        self.status_code = status_code
        self.message = message
        url = url or "https://https://us-south.ml.cloud.ibm.com"
        self.request = httpx.Request(method="POST", url=url)
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
class WatsonXAIError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[Dict, httpx.Headers]] = None,
    ):
        super().__init__(status_code=status_code, message=message, headers=headers)


iam_token_cache = InMemoryCache()
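With WatsonXAIError now a BaseLLMException subclass, the url argument disappears and callers construct it with just a status code, a message, and optional headers, as in this sketch (header value is illustrative):

    from litellm.llms.watsonx.common_utils import WatsonXAIError

    err = WatsonXAIError(
        status_code=401,
        message="Error: Watsonx project_id not set.",
        headers={"content-type": "application/json"},  # optional, dict or httpx.Headers
    )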
@@ -151,13 +150,11 @@ def _get_api_params(
    elif token is None and api_key is None:
        raise WatsonXAIError(
            status_code=401,
            url=url,
            message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
        )
    if project_id is None:
        raise WatsonXAIError(
            status_code=401,
            url=url,
            message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
        )
@@ -29,216 +29,14 @@ from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason

from ...base import BaseLLM
from ...prompt_templates import factory as ptf
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token


class IBMWatsonXAIConfig:
    """
    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
    (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)

    Supported params for all available watsonx.ai foundational models.

    - `decoding_method` (str): One of "greedy" or "sample"

    - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'.

    - `max_new_tokens` (integer): Maximum length of the generated tokens.

    - `min_new_tokens` (integer): Maximum length of input tokens. Any more than this will be truncated.

    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".

    - `stop_sequences` (string[]): list of strings to use as stop sequences.

    - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.

    - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'.

    - `repetition_penalty` (float): token repetition penalty during text generation.

    - `truncate_input_tokens` (integer): Truncate input tokens to this length.

    - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match.

    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean.

    - `random_seed` (integer): Random seed for text generation.

    - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering.

    - `stream` (bool): If True, the model will return a stream of responses.
    """

    decoding_method: Optional[str] = "sample"
    temperature: Optional[float] = None
    max_new_tokens: Optional[int] = None  # litellm.max_tokens
    min_new_tokens: Optional[int] = None
    length_penalty: Optional[dict] = None  # e.g {"decay_factor": 2.5, "start_index": 5}
    stop_sequences: Optional[List[str]] = None  # e.g ["}", ")", "."]
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    truncate_input_tokens: Optional[int] = None
    include_stop_sequences: Optional[bool] = False
    return_options: Optional[Dict[str, bool]] = None
    random_seed: Optional[int] = None  # e.g 42
    moderations: Optional[dict] = None
    stream: Optional[bool] = False

    def __init__(
        self,
        decoding_method: Optional[str] = None,
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        length_penalty: Optional[dict] = None,
        stop_sequences: Optional[List[str]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        truncate_input_tokens: Optional[int] = None,
        include_stop_sequences: Optional[bool] = None,
        return_options: Optional[dict] = None,
        random_seed: Optional[int] = None,
        moderations: Optional[dict] = None,
        stream: Optional[bool] = None,
        **kwargs,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def is_watsonx_text_param(self, param: str) -> bool:
        """
        Determine if user passed in a watsonx.ai text generation param
        """
        text_generation_params = [
            "decoding_method",
            "max_new_tokens",
            "min_new_tokens",
            "length_penalty",
            "stop_sequences",
            "top_k",
            "repetition_penalty",
            "truncate_input_tokens",
            "include_stop_sequences",
            "return_options",
            "random_seed",
            "moderations",
            "decoding_method",
            "min_tokens",
        ]

        return param in text_generation_params

    def get_supported_openai_params(self):
        return [
            "temperature",  # equivalent to temperature
            "max_tokens",  # equivalent to max_new_tokens
            "top_p",  # equivalent to top_p
            "frequency_penalty",  # equivalent to repetition_penalty
            "stop",  # equivalent to stop_sequences
            "seed",  # equivalent to random_seed
            "stream",  # equivalent to stream
        ]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        extra_body = {}
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_new_tokens"] = v
            elif k == "stream":
                optional_params["stream"] = v
            elif k == "temperature":
                optional_params["temperature"] = v
            elif k == "top_p":
                optional_params["top_p"] = v
            elif k == "frequency_penalty":
                optional_params["repetition_penalty"] = v
            elif k == "seed":
                optional_params["random_seed"] = v
            elif k == "stop":
                optional_params["stop_sequences"] = v
            elif k == "decoding_method":
                extra_body["decoding_method"] = v
            elif k == "min_tokens":
                extra_body["min_new_tokens"] = v
            elif k == "top_k":
                extra_body["top_k"] = v
            elif k == "truncate_input_tokens":
                extra_body["truncate_input_tokens"] = v
            elif k == "length_penalty":
                extra_body["length_penalty"] = v
            elif k == "time_limit":
                extra_body["time_limit"] = v
            elif k == "return_options":
                extra_body["return_options"] = v

        if extra_body:
            optional_params["extra_body"] = extra_body
        return optional_params

    def get_mapped_special_auth_params(self) -> dict:
        """
        Common auth params across bedrock/vertex_ai/azure/watsonx
        """
        return {
            "project": "watsonx_project",
            "region_name": "watsonx_region_name",
            "token": "watsonx_token",
        }

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        mapped_params = self.get_mapped_special_auth_params()

        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
        """
        return [
            "eu-de",
            "eu-gb",
        ]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
        """
        return [
            "us-south",
        ]
from .transformation import IBMWatsonXAIConfig


def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str:
@@ -281,6 +79,7 @@ class IBMWatsonXAI(BaseLLM):
    def _prepare_text_generation_req(
        self,
        model_id: str,
        messages: List[AllMessageValues],
        prompt: str,
        stream: bool,
        optional_params: dict,
@@ -293,11 +92,13 @@ class IBMWatsonXAI(BaseLLM):
        # build auth headers
        api_token = api_params.get("token")
        self.token = api_token
        headers = {
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        headers = IBMWatsonXAIConfig().validate_environment(
            headers={},
            model=model_id,
            messages=messages,
            optional_params=optional_params,
            api_key=api_token,
        )
        extra_body_params = optional_params.pop("extra_body", {})
        optional_params.update(extra_body_params)
        # init the payload to the text generation call
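The handler no longer assembles auth headers inline; it delegates to the config's validate_environment(), sketched below (model name and token are placeholders):

    from litellm.llms.watsonx.completion.transformation import IBMWatsonXAIConfig

    headers = IBMWatsonXAIConfig().validate_environment(
        headers={},
        model="ibm/granite-13b-chat-v2",
        messages=[{"role": "user", "content": "hi"}],
        optional_params={},
        api_key="<iam-token>",
    )
    # -> {"Content-Type": "application/json", "Accept": "application/json",
    #     "Authorization": "Bearer <iam-token>"}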
@@ -313,7 +114,6 @@ class IBMWatsonXAI(BaseLLM):
        if api_params.get("space_id") is None:
            raise WatsonXAIError(
                status_code=401,
                url=api_params["url"],
                message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
            )
        deployment_id = "/".join(model_id.split("/")[1:])
@@ -466,6 +266,7 @@ class IBMWatsonXAI(BaseLLM):
        req_params = self._prepare_text_generation_req(
            model_id=model,
            prompt=prompt,
            messages=messages,
            stream=stream,
            optional_params=optional_params,
            print_verbose=print_verbose,
litellm/llms/watsonx/completion/transformation.py (new file, 299 lines)
@@ -0,0 +1,299 @@
import asyncio
import json  # noqa: E401
import time
import types
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncContextManager,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    ContextManager,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Union,
)

import httpx

import litellm
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason

from ...base import BaseLLM
from ...base_llm.transformation import BaseConfig
from ...prompt_templates import factory as ptf
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class IBMWatsonXAIConfig(BaseConfig):
    """
    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
    (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)

    Supported params for all available watsonx.ai foundational models.

    - `decoding_method` (str): One of "greedy" or "sample"

    - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'.

    - `max_new_tokens` (integer): Maximum length of the generated tokens.

    - `min_new_tokens` (integer): Maximum length of input tokens. Any more than this will be truncated.

    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".

    - `stop_sequences` (string[]): list of strings to use as stop sequences.

    - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.

    - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'.

    - `repetition_penalty` (float): token repetition penalty during text generation.

    - `truncate_input_tokens` (integer): Truncate input tokens to this length.

    - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match.

    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean.

    - `random_seed` (integer): Random seed for text generation.

    - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering.

    - `stream` (bool): If True, the model will return a stream of responses.
    """

    decoding_method: Optional[str] = "sample"
    temperature: Optional[float] = None
    max_new_tokens: Optional[int] = None  # litellm.max_tokens
    min_new_tokens: Optional[int] = None
    length_penalty: Optional[dict] = None  # e.g {"decay_factor": 2.5, "start_index": 5}
    stop_sequences: Optional[List[str]] = None  # e.g ["}", ")", "."]
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    truncate_input_tokens: Optional[int] = None
    include_stop_sequences: Optional[bool] = False
    return_options: Optional[Dict[str, bool]] = None
    random_seed: Optional[int] = None  # e.g 42
    moderations: Optional[dict] = None
    stream: Optional[bool] = False

    def __init__(
        self,
        decoding_method: Optional[str] = None,
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        length_penalty: Optional[dict] = None,
        stop_sequences: Optional[List[str]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        truncate_input_tokens: Optional[int] = None,
        include_stop_sequences: Optional[bool] = None,
        return_options: Optional[dict] = None,
        random_seed: Optional[int] = None,
        moderations: Optional[dict] = None,
        stream: Optional[bool] = None,
        **kwargs,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def is_watsonx_text_param(self, param: str) -> bool:
        """
        Determine if user passed in a watsonx.ai text generation param
        """
        text_generation_params = [
            "decoding_method",
            "max_new_tokens",
            "min_new_tokens",
            "length_penalty",
            "stop_sequences",
            "top_k",
            "repetition_penalty",
            "truncate_input_tokens",
            "include_stop_sequences",
            "return_options",
            "random_seed",
            "moderations",
            "decoding_method",
            "min_tokens",
        ]

        return param in text_generation_params

    def get_supported_openai_params(self, model: str):
        return [
            "temperature",  # equivalent to temperature
            "max_tokens",  # equivalent to max_new_tokens
            "top_p",  # equivalent to top_p
            "frequency_penalty",  # equivalent to repetition_penalty
            "stop",  # equivalent to stop_sequences
            "seed",  # equivalent to random_seed
            "stream",  # equivalent to stream
        ]

    def map_openai_params(
        self,
        non_default_params: Dict,
        optional_params: Dict,
        model: str,
        drop_params: bool,
    ) -> Dict:
        extra_body = {}
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_new_tokens"] = v
            elif k == "stream":
                optional_params["stream"] = v
            elif k == "temperature":
                optional_params["temperature"] = v
            elif k == "top_p":
                optional_params["top_p"] = v
            elif k == "frequency_penalty":
                optional_params["repetition_penalty"] = v
            elif k == "seed":
                optional_params["random_seed"] = v
            elif k == "stop":
                optional_params["stop_sequences"] = v
            elif k == "decoding_method":
                extra_body["decoding_method"] = v
            elif k == "min_tokens":
                extra_body["min_new_tokens"] = v
            elif k == "top_k":
                extra_body["top_k"] = v
            elif k == "truncate_input_tokens":
                extra_body["truncate_input_tokens"] = v
            elif k == "length_penalty":
                extra_body["length_penalty"] = v
            elif k == "time_limit":
                extra_body["time_limit"] = v
            elif k == "return_options":
                extra_body["return_options"] = v

        if extra_body:
            optional_params["extra_body"] = extra_body
        return optional_params

    def get_mapped_special_auth_params(self) -> dict:
        """
        Common auth params across bedrock/vertex_ai/azure/watsonx
        """
        return {
            "project": "watsonx_project",
            "region_name": "watsonx_region_name",
            "token": "watsonx_token",
        }

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        mapped_params = self.get_mapped_special_auth_params()

        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
        """
        return [
            "eu-de",
            "eu-gb",
        ]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
        """
        return [
            "us-south",
        ]

    def _transform_messages(
        self,
        messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        return messages

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
        return WatsonXAIError(
            status_code=status_code, message=error_message, headers=headers
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        headers: Dict,
    ) -> Dict:
        raise NotImplementedError(
            "transform_request not implemented. Done in watsonx/completion handler.py"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError(
            "transform_response not implemented. Done in watsonx/completion handler.py"
        )

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        api_key: Optional[str] = None,
    ) -> Dict:
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers
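Usage sketch for the new BaseConfig-based watsonx config (the model name is a placeholder): OpenAI params are renamed to their watsonx.ai equivalents, and non-OpenAI extras are routed into extra_body.

    from litellm.llms.watsonx.completion.transformation import IBMWatsonXAIConfig

    mapped = IBMWatsonXAIConfig().map_openai_params(
        non_default_params={"max_tokens": 100, "seed": 42, "top_k": 10},
        optional_params={},
        model="ibm/granite-13b-chat-v2",
        drop_params=False,
    )
    # -> {"max_new_tokens": 100, "random_seed": 42, "extra_body": {"top_k": 10}}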
@@ -3515,6 +3515,11 @@ def get_optional_params(  # noqa: PLR0915
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=(
                drop_params
                if drop_params is not None and isinstance(drop_params, bool)
                else False
            ),
        )
    elif custom_llm_provider == "volcengine":
        supported_params = get_supported_openai_params(
@@ -3658,6 +3663,12 @@ def get_optional_params(  # noqa: PLR0915
        optional_params = litellm.IBMWatsonXAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=(
                drop_params
                if drop_params is not None and isinstance(drop_params, bool)
                else False
            ),
        )
    elif custom_llm_provider == "openai":
        supported_params = get_supported_openai_params(
@@ -6284,7 +6295,14 @@ class ProviderConfigManager:
            return litellm.VertexAIAnthropicConfig()
        elif litellm.LlmProviders.CLOUDFLARE == provider:
            return litellm.CloudflareChatConfig()

        elif litellm.LlmProviders.FIREWORKS_AI == provider:
            return litellm.FireworksAIConfig()
        elif litellm.LlmProviders.FRIENDLIAI == provider:
            return litellm.FriendliaiChatConfig()
        elif litellm.LlmProviders.WATSONX == provider:
            return litellm.IBMWatsonXChatConfig()
        elif litellm.LlmProviders.WATSONX_TEXT == provider:
            return litellm.IBMWatsonXAIConfig()
        return litellm.OpenAIGPTConfig()
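Sketch of how the new branches get exercised; the lookup method name (get_provider_chat_config) is assumed here, since this hunk only shows the added elif branches.

    import litellm
    from litellm.utils import ProviderConfigManager

    config = ProviderConfigManager.get_provider_chat_config(  # assumed entry point
        model="fireworks_ai/accounts/fireworks/models/llama-v3p1-8b-instruct",
        provider=litellm.LlmProviders.FIREWORKS_AI,
    )
    assert isinstance(config, litellm.FireworksAIConfig)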
@@ -7,7 +7,7 @@ sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
from litellm.llms.fireworks_ai.chat.transformation import FireworksAIConfig
from base_llm_unit_tests import BaseLLMChatTest

fireworks = FireworksAIConfig()
@@ -15,21 +15,27 @@ fireworks = FireworksAIConfig()

def test_map_openai_params_tool_choice():
    # Test case 1: tool_choice is "required"
    result = fireworks.map_openai_params({"tool_choice": "required"}, {}, "some_model")
    result = fireworks.map_openai_params(
        {"tool_choice": "required"}, {}, "some_model", drop_params=False
    )
    assert result == {"tool_choice": "any"}

    # Test case 2: tool_choice is "auto"
    result = fireworks.map_openai_params({"tool_choice": "auto"}, {}, "some_model")
    result = fireworks.map_openai_params(
        {"tool_choice": "auto"}, {}, "some_model", drop_params=False
    )
    assert result == {"tool_choice": "auto"}

    # Test case 3: tool_choice is not present
    result = fireworks.map_openai_params(
        {"some_other_param": "value"}, {}, "some_model"
        {"some_other_param": "value"}, {}, "some_model", drop_params=False
    )
    assert result == {}

    # Test case 4: tool_choice is None
    result = fireworks.map_openai_params({"tool_choice": None}, {}, "some_model")
    result = fireworks.map_openai_params(
        {"tool_choice": None}, {}, "some_model", drop_params=False
    )
    assert result == {"tool_choice": None}
@@ -55,7 +61,7 @@ def test_map_response_format():
        },
    }
    result = fireworks.map_openai_params(
        {"response_format": response_format}, {}, "some_model"
        {"response_format": response_format}, {}, "some_model", drop_params=False
    )
    assert result == {
        "response_format": {
@@ -154,13 +154,18 @@ def test_all_model_configs():
        {"max_completion_tokens": 10}, {}, "llama3", drop_params=False
    ) == {"max_tokens": 10}

    from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import (
    from litellm.llms.fireworks_ai.chat.transformation import (
        FireworksAIConfig,
    )

    assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params()
    assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params(
        model="llama3"
    )
    assert FireworksAIConfig().map_openai_params(
        {"max_completion_tokens": 10}, {}, "llama3"
        model="llama3",
        non_default_params={"max_completion_tokens": 10},
        optional_params={},
        drop_params=False,
    ) == {"max_tokens": 10}

    from litellm.llms.huggingface_restapi import HuggingfaceConfig