from typing import Dict, List, Optional, Union, cast

import httpx

import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials


class WatsonXAIError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[Dict, httpx.Headers]] = None,
    ):
        super().__init__(status_code=status_code, message=message, headers=headers)


iam_token_cache = InMemoryCache()


def get_watsonx_iam_url():
    return (
        get_secret_str("WATSONX_IAM_URL")
        or "https://iam.cloud.ibm.com/identity/token"
    )


def generate_iam_token(api_key=None, **params) -> str:
    """Exchange an IBM Cloud API key for an IAM access token, caching it until shortly before expiry."""
    result: Optional[str] = iam_token_cache.get_cache(api_key)  # type: ignore

    if result is None:
        headers = {}
        headers["Content-Type"] = "application/x-www-form-urlencoded"
        if api_key is None:
            api_key = (
                get_secret_str("WX_API_KEY")
                or get_secret_str("WATSONX_API_KEY")
                or get_secret_str("WATSONX_APIKEY")
            )
        if api_key is None:
            raise ValueError("API key is required")
        headers["Accept"] = "application/json"
        data = {
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": api_key,
        }
        iam_token_url = get_watsonx_iam_url()
        verbose_logger.debug(
            "calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
            iam_token_url,
            headers,
            data,
        )
        response = litellm.module_level_client.post(
            url=iam_token_url, data=data, headers=headers
        )
        response.raise_for_status()
        json_data = response.json()

        result = json_data["access_token"]
        iam_token_cache.set_cache(
            key=api_key,
            value=result,
            ttl=json_data["expires_in"] - 10,  # leave some buffer
        )

    return cast(str, result)


def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str:
    """Return the explicit token if provided, otherwise generate a fresh IAM token from the API key."""
    if token is not None:
        return token
    token = generate_iam_token(api_key)
    return token


def _get_api_params(
    params: dict,
) -> WatsonXAPIParams:
    """
    Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
    """
    # Load auth variables from params
    project_id = params.pop(
        "project_id", params.pop("watsonx_project", None)
    )  # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
    space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
    region_name = params.pop("region_name", params.pop("region", None))
    if region_name is None:
        region_name = params.pop(
            "watsonx_region_name", params.pop("watsonx_region", None)
        )  # consistent with how vertex ai + aws regions are accepted

    # Load auth variables from environment variables
    if project_id is None:
        project_id = (
            get_secret_str("WATSONX_PROJECT_ID")
            or get_secret_str("WX_PROJECT_ID")
            or get_secret_str("PROJECT_ID")
        )
    if region_name is None:
        region_name = (
            get_secret_str("WATSONX_REGION")
            or get_secret_str("WX_REGION")
            or get_secret_str("REGION")
        )
    if space_id is None:
        space_id = (
            get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
            or get_secret_str("WATSONX_SPACE_ID")
            or get_secret_str("WX_SPACE_ID")
            or get_secret_str("SPACE_ID")
        )

    if project_id is None and space_id is None:
        raise WatsonXAIError(
            status_code=401,
            message="Error: At least one of project_id or space_id must be set. Set WX_PROJECT_ID or WX_SPACE_ID in environment variables or pass in as a parameter.",
        )

    return WatsonXAPIParams(
        project_id=project_id,
        space_id=space_id,
        region_name=region_name,
    )


def convert_watsonx_messages_to_prompt(
    model: str,
    messages: List[AllMessageValues],
    provider: str,
    custom_prompt_dict: Dict,
) -> str:
    # handle anthropic prompts and amazon titan prompts
    if model in custom_prompt_dict:
        # check if the model has a registered custom prompt
        model_prompt_dict = custom_prompt_dict[model]
        prompt = ptf.custom_prompt(
            messages=messages,
            role_dict=model_prompt_dict.get(
                "role_dict", model_prompt_dict.get("roles")
            ),
            initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
            final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
            bos_token=model_prompt_dict.get("bos_token", ""),
            eos_token=model_prompt_dict.get("eos_token", ""),
        )
        return prompt
    elif provider == "ibm-mistralai":
        prompt = ptf.mistral_instruct_pt(messages=messages)
    else:
        prompt: str = ptf.prompt_factory(  # type: ignore
            model=model, messages=messages, custom_llm_provider="watsonx"
        )
    return prompt


# Mixin class for shared IBM Watson X functionality
class IBMWatsonXMixin:
    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        """Build auth headers from an explicit token, a Zen API key, or a freshly generated IAM token."""
        default_headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        if "Authorization" in headers:
            return {**default_headers, **headers}

        token = cast(
            Optional[str],
            optional_params.get("token") or get_secret_str("WATSONX_TOKEN"),
        )
        if token:
            headers["Authorization"] = f"Bearer {token}"
        elif zen_api_key := get_secret_str("WATSONX_ZENAPIKEY"):
            headers["Authorization"] = f"ZenApiKey {zen_api_key}"
        else:
            token = _generate_watsonx_token(api_key=api_key, token=token)
            # build auth headers
            headers["Authorization"] = f"Bearer {token}"
        return {**default_headers, **headers}

    def _get_base_url(self, api_base: Optional[str]) -> str:
        url = (
            api_base
            or get_secret_str("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
            or get_secret_str("WATSONX_URL")
            or get_secret_str("WX_URL")
            or get_secret_str("WML_URL")
        )

        if url is None:
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
            )

        return url

    def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str:
        api_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION
        url = url + f"?version={api_version}"

        return url

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
        return WatsonXAIError(
            status_code=status_code, message=error_message, headers=headers
        )

    @staticmethod
    def get_watsonx_credentials(
        optional_params: dict, api_key: Optional[str], api_base: Optional[str]
    ) -> WatsonXCredentials:
        """Resolve the watsonx API key, base URL, and optional token from params or environment variables."""
        api_key = (
            api_key
            or optional_params.pop("apikey", None)
            or get_secret_str("WATSONX_APIKEY")
            or get_secret_str("WATSONX_API_KEY")
            or get_secret_str("WX_API_KEY")
        )

        api_base = (
            api_base
            or optional_params.pop(
                "url",
                optional_params.pop("api_base", optional_params.pop("base_url", None)),
            )
            or get_secret_str("WATSONX_API_BASE")
            or get_secret_str("WATSONX_URL")
            or get_secret_str("WX_URL")
            or get_secret_str("WML_URL")
        )

        wx_credentials = optional_params.pop(
            "wx_credentials",
            optional_params.pop(
                "watsonx_credentials", None
            ),  # follow {provider}_credentials, same as vertex ai
        )

        token: Optional[str] = None
        if wx_credentials is not None:
            api_base = wx_credentials.get("url", api_base)
            api_key = wx_credentials.get(
                "apikey", wx_credentials.get("api_key", api_key)
            )
            token = wx_credentials.get(
                "token",
                wx_credentials.get(
                    "watsonx_token", None
                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
            )

        if api_key is None or not isinstance(api_key, str):
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.",
            )
        if api_base is None or not isinstance(api_base, str):
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
            )

        return WatsonXCredentials(
            api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
        )

    def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
        """Return the model/deployment identifiers to include in the request payload."""
        payload: dict = {}
        payload["space_id"] = api_params["space_id"]
        if model.startswith("deployment/"):
            if api_params["space_id"] is None:
                raise WatsonXAIError(
                    status_code=401,
                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
                )
            return payload
        payload["model_id"] = model
        payload["project_id"] = api_params["project_id"]
        return payload
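

if __name__ == "__main__":
    # Minimal manual smoke test (illustrative sketch only, not part of the
    # module's public surface). It assumes WATSONX_API_KEY, WATSONX_API_BASE,
    # and WATSONX_PROJECT_ID (or a deployment space id) are already set in the
    # environment and that the IBM IAM endpoint is reachable.
    IBMWatsonXMixin.get_watsonx_credentials(
        optional_params={}, api_key=None, api_base=None
    )  # raises WatsonXAIError if no API key / base URL can be resolved
    api_params = _get_api_params(params={})  # falls back to env vars for project/space/region
    token = generate_iam_token()  # exchanges the API key for a cached IAM token
    print("project_id:", api_params["project_id"])
    print("space_id:", api_params["space_id"])
    print("IAM token acquired:", bool(token))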