litellm-mirror/litellm/llms/watsonx/common_utils.py
Krish Dholakia b0f570ee16 Litellm dev 12 30 2024 p2 (#7495)
* test(azure_openai_o1.py): initial commit with testing for azure openai o1 preview model

* fix(base_llm_unit_tests.py): handle azure o1 preview response format tests

skip as o1 on azure doesn't support tool calling yet

* fix: initial commit of azure o1 handler using openai caller

simplifies calling + allows the fake streaming logic already implemented for openai to just work

* feat(azure/o1_handler.py): fake o1 streaming for azure o1 models

azure does not currently support streaming for o1

* feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info

enables users to toggle this on once azure allows o1 streaming, without needing to bump versions

* style(router.py): remove 'give feedback/get help' messaging when router is used

Prevents noisy messaging

Closes https://github.com/BerriAI/litellm/issues/5942

* fix(types/utils.py): handle none logprobs

Fixes https://github.com/BerriAI/litellm/issues/328

* fix(exception_mapping_utils.py): fix error str unbound error

* refactor(azure_ai/): move to openai_like chat completion handler

allows for easy swapping of api base urls (e.g. ai.services.com)

Fixes https://github.com/BerriAI/litellm/issues/7275

* refactor(azure_ai/): move to base llm http handler

* fix(azure_ai/): handle differing api endpoints

* fix(azure_ai/): make sure all unit tests are passing

* fix: fix linting errors

* fix: fix linting errors

* fix: fix linting error

* fix: fix linting errors

* fix(azure_ai/transformation.py): handle extra body param

* fix(azure_ai/transformation.py): fix max retries param handling

* fix: fix test

* test(test_azure_o1.py): fix test

* fix(llm_http_handler.py): support handling azure ai unprocessable entity error

* fix(llm_http_handler.py): handle sync invalid param error for azure ai

* fix(azure_ai/): streaming support with base_llm_http_handler

* fix(llm_http_handler.py): working sync stream calls with unprocessable entity handling for azure ai

* fix: fix linting errors

* fix(llm_http_handler.py): fix linting error

* fix(azure_ai/): handle cohere tool call invalid index param error
2025-01-01 18:57:29 -08:00


from typing import Dict, List, Optional, Union, cast
import httpx
import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials


class WatsonXAIError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[Dict, httpx.Headers]] = None,
):
super().__init__(status_code=status_code, message=message, headers=headers)


iam_token_cache = InMemoryCache()


def get_watsonx_iam_url():
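    """Return the IAM token endpoint: the WATSONX_IAM_URL env var if set, else IBM Cloud's public default."""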
return (
get_secret_str("WATSONX_IAM_URL") or "https://iam.cloud.ibm.com/identity/token"
)


def generate_iam_token(api_key=None, **params) -> str:
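    """
    Exchange an IBM Cloud API key for a short-lived IAM access token.

    Tokens are cached in-process, keyed by the API key, with a TTL just
    under the 'expires_in' returned by IBM, so repeated calls reuse the
    cached token instead of hitting `/identity/token` on every request.
    """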
result: Optional[str] = iam_token_cache.get_cache(api_key) # type: ignore
if result is None:
headers = {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
if api_key is None:
api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
if api_key is None:
raise ValueError("API key is required")
headers["Accept"] = "application/json"
data = {
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
"apikey": api_key,
}
iam_token_url = get_watsonx_iam_url()
verbose_logger.debug(
"calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
iam_token_url,
headers,
data,
)
response = litellm.module_level_client.post(
url=iam_token_url, data=data, headers=headers
)
response.raise_for_status()
json_data = response.json()
result = json_data["access_token"]
iam_token_cache.set_cache(
key=api_key,
value=result,
ttl=json_data["expires_in"] - 10, # leave some buffer
)
return cast(str, result)


def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str:
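    """Return the supplied bearer token if one was given; otherwise mint an IAM token from the API key."""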
if token is not None:
return token
token = generate_iam_token(api_key)
return token


def _get_api_params(
params: dict,
) -> WatsonXAPIParams:
"""
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
"""
# Load auth variables from params
project_id = params.pop(
"project_id", params.pop("watsonx_project", None)
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
space_id = params.pop("space_id", None) # watsonx.ai deployment space_id
region_name = params.pop("region_name", params.pop("region", None))
if region_name is None:
region_name = params.pop(
"watsonx_region_name", params.pop("watsonx_region", None)
) # consistent with how vertex ai + aws regions are accepted
# Load auth variables from environment variables
if project_id is None:
project_id = (
get_secret_str("WATSONX_PROJECT_ID")
or get_secret_str("WX_PROJECT_ID")
or get_secret_str("PROJECT_ID")
)
if region_name is None:
region_name = (
get_secret_str("WATSONX_REGION")
or get_secret_str("WX_REGION")
or get_secret_str("REGION")
)
if space_id is None:
space_id = (
get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
or get_secret_str("WATSONX_SPACE_ID")
or get_secret_str("WX_SPACE_ID")
or get_secret_str("SPACE_ID")
)
if project_id is None:
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
)
return WatsonXAPIParams(
project_id=project_id,
space_id=space_id,
region_name=region_name,
)


def convert_watsonx_messages_to_prompt(
model: str,
messages: List[AllMessageValues],
provider: str,
custom_prompt_dict: Dict,
) -> str:
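    """
    Render OpenAI-style messages into a single prompt string.

    A custom prompt template registered for the model takes precedence;
    otherwise 'ibm-mistralai' models get the Mistral instruct format, and
    everything else falls back to the generic watsonx prompt factory.
    """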
    if model in custom_prompt_dict:
        # use the custom prompt template registered for this model
model_prompt_dict = custom_prompt_dict[model]
prompt = ptf.custom_prompt(
messages=messages,
role_dict=model_prompt_dict.get(
"role_dict", model_prompt_dict.get("roles")
),
initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
bos_token=model_prompt_dict.get("bos_token", ""),
eos_token=model_prompt_dict.get("eos_token", ""),
)
return prompt
elif provider == "ibm-mistralai":
prompt = ptf.mistral_instruct_pt(messages=messages)
else:
prompt: str = ptf.prompt_factory( # type: ignore
model=model, messages=messages, custom_llm_provider="watsonx"
)
return prompt


# Mixin class for shared IBM Watson X functionality
class IBMWatsonXMixin:
def validate_environment(
self,
headers: Dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
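        """
        Build the auth + content headers for a watsonx.ai request.

        A caller-supplied 'Authorization' header always wins; otherwise a
        bearer token is taken from optional_params['token'] or generated
        from the API key via the IAM token flow.
        """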
default_headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
if "Authorization" in headers:
return {**default_headers, **headers}
token = cast(Optional[str], optional_params.get("token"))
if token:
headers["Authorization"] = f"Bearer {token}"
else:
token = _generate_watsonx_token(api_key=api_key, token=token)
# build auth headers
headers["Authorization"] = f"Bearer {token}"
return {**default_headers, **headers}

    def _get_base_url(self, api_base: Optional[str]) -> str:
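        """Resolve the watsonx.ai base URL from the argument or the supported env vars; raise if none is set."""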
url = (
api_base
or get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
if url is None:
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
)
return url

    def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str:
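        """Append the '?version=...' query parameter required by watsonx.ai endpoints, defaulting to litellm.WATSONX_DEFAULT_API_VERSION."""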
api_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION
url = url + f"?version={api_version}"
return url

    def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
) -> BaseLLMException:
return WatsonXAIError(
status_code=status_code, message=error_message, headers=headers
)

    @staticmethod
def get_watsonx_credentials(
optional_params: dict, api_key: Optional[str], api_base: Optional[str]
) -> WatsonXCredentials:
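        """
        Resolve the watsonx.ai API key, base URL, and optional bearer token.

        Resolution order for each value: explicit argument -> key in
        optional_params -> environment variables. A 'wx_credentials' /
        'watsonx_credentials' dict, if passed, overrides all of the above.
        """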
api_key = (
api_key
or optional_params.pop("apikey", None)
or get_secret_str("WATSONX_APIKEY")
or get_secret_str("WATSONX_API_KEY")
or get_secret_str("WX_API_KEY")
)
api_base = (
api_base
or optional_params.pop(
"url",
optional_params.pop("api_base", optional_params.pop("base_url", None)),
)
or get_secret_str("WATSONX_API_BASE")
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
wx_credentials = optional_params.pop(
"wx_credentials",
optional_params.pop(
"watsonx_credentials", None
), # follow {provider}_credentials, same as vertex ai
)
token: Optional[str] = None
if wx_credentials is not None:
api_base = wx_credentials.get("url", api_base)
api_key = wx_credentials.get(
"apikey", wx_credentials.get("api_key", api_key)
)
token = wx_credentials.get(
"token",
wx_credentials.get(
"watsonx_token", None
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
)
if api_key is None or not isinstance(api_key, str):
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.",
)
if api_base is None or not isinstance(api_base, str):
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
)
return WatsonXCredentials(
api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
)
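

# Illustrative usage sketch (commented out; a minimal example showing how the
# helpers above combine - it assumes WatsonXCredentials behaves like a
# TypedDict, so the subscript/.get access below is an assumption, not part of
# this module's tested API):
#
#     creds = IBMWatsonXMixin.get_watsonx_credentials(
#         optional_params={}, api_key=None, api_base=None
#     )  # falls back to WATSONX_API_KEY / WATSONX_API_BASE env vars
#     token = _generate_watsonx_token(
#         api_key=creds["api_key"], token=creds.get("token")
#     )  # reuses an explicit bearer token, else mints one via IAM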