* test(test_watsonx.py): e2e unit test for watsonx custom header. Covers https://github.com/BerriAI/litellm/issues/7408
* fix(common_utils.py): handle auth token already present in headers (watsonx + openai-like base handler). Fixes https://github.com/BerriAI/litellm/issues/7408
* fix(watsonx/chat): fix chat route. Fixes https://github.com/BerriAI/litellm/issues/7408
* fix(huggingface/chat/handler.py): fix huggingface async completion calls
* Correct handling of max_retries=0 to disable AzureOpenAI retries (#7379)
* test: fix test

Co-authored-by: Minh Duc <phamminhduc0711@gmail.com>
123 lines
4.6 KiB
Python
"""
|
|
Translation from OpenAI's `/chat/completions` endpoint to IBM WatsonX's `/text/chat` endpoint.
|
|
|
|
Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
|
|
"""
from typing import List, Optional, Tuple, Union

from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams

from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import IBMWatsonXMixin, WatsonXAIError


class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):

    def get_supported_openai_params(self, model: str) -> List:
        return [
            "temperature",  # equivalent to temperature
            "max_tokens",  # equivalent to max_new_tokens
            "top_p",  # equivalent to top_p
            "frequency_penalty",  # equivalent to repetition_penalty
            "stop",  # equivalent to stop_sequences
            "seed",  # equivalent to random_seed
            "stream",  # equivalent to stream
            "tools",
            "tool_choice",  # equivalent to tool_choice + tool_choice_options
            "logprobs",
            "top_logprobs",
            "n",
            "presence_penalty",
            "response_format",
        ]

    def is_tool_choice_option(self, tool_choice: Optional[Union[str, dict]]) -> bool:
        if tool_choice is None:
            return False
        if isinstance(tool_choice, str):
            return tool_choice in ["auto", "none", "required"]
        return False
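
    # Sketch of the intended behavior (examples assumed, not from the source):
    #   is_tool_choice_option("auto") -> True   (maps to watsonx `tool_choice_options`)
    #   is_tool_choice_option({"type": "function", ...}) -> False  (a specific tool)
    #   is_tool_choice_option(None) -> False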

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        ## TOOLS ##
        _tools = non_default_params.pop("tools", None)
        if _tools is not None:
            # remove 'additionalProperties' from tools
            _tools = _remove_additional_properties(_tools)
            # remove 'strict' from tools
            _tools = _remove_strict_from_schema(_tools)
        if _tools is not None:
            non_default_params["tools"] = _tools

        ## TOOL CHOICE ##

        _tool_choice = non_default_params.pop("tool_choice", None)
        if self.is_tool_choice_option(_tool_choice):
            optional_params["tool_choice_options"] = _tool_choice
        elif _tool_choice is not None:
            optional_params["tool_choice"] = _tool_choice
        return super().map_openai_params(
            non_default_params, optional_params, model, drop_params
        )
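
    # Illustrative mapping outcome (values assumed): a string option lands in
    # `tool_choice_options`, while a specific tool dict passes through as
    # `tool_choice`.
    #   tool_choice="auto"
    #       -> optional_params["tool_choice_options"] = "auto"
    #   tool_choice={"type": "function", "function": {"name": "get_weather"}}
    #       -> optional_params["tool_choice"] = {"type": "function", ...}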

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE")  # type: ignore
        dynamic_api_key = (
            api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
        )  # vllm does not require an api key
        return api_base, dynamic_api_key

    def get_complete_url(
        self,
        api_base: str,
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        url = self._get_base_url(api_base=api_base)
        if model.startswith("deployment/"):
            # deployment models are passed in as 'deployment/<deployment_id>'
            if optional_params.get("space_id") is None:
                raise WatsonXAIError(
                    status_code=401,
                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
                )
            deployment_id = "/".join(model.split("/")[1:])
            endpoint = (
                WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
                if stream
                else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
            )
            endpoint = endpoint.format(deployment_id=deployment_id)
        else:
            endpoint = (
                WatsonXAIEndpoint.CHAT_STREAM.value
                if stream
                else WatsonXAIEndpoint.CHAT.value
            )
        url = url.rstrip("/") + endpoint

        ## add api version
        url = self._add_api_version_to_url(
            url=url, api_version=optional_params.pop("api_version", None)
        )
        return url
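
    # Example of the resulting URL shape (endpoint path and version value are
    # assumptions based on the watsonx docs linked above):
    #   get_complete_url(
    #       api_base="https://us-south.ml.cloud.ibm.com",
    #       model="ibm/granite-13b-chat-v2",
    #       optional_params={"api_version": "2024-03-13"},
    #   )
    #   -> "https://us-south.ml.cloud.ibm.com/ml/v1/text/chat?version=2024-03-13"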

    def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
        payload: dict = {}
        if model.startswith("deployment/"):
            return payload
        payload["model_id"] = model
        payload["project_id"] = api_params["project_id"]
        return payload
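
    # Illustrative results (values assumed; keys follow the code above):
    #   _prepare_payload("ibm/granite-13b-chat-v2", {"project_id": "abc-123"})
    #   -> {"model_id": "ibm/granite-13b-chat-v2", "project_id": "abc-123"}
    #   _prepare_payload("deployment/<deployment_id>", {...}) -> {}
    #   (deployment models identify the model via the URL, not the payload)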