mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
292 lines
10 KiB
Python
292 lines
10 KiB
Python
from typing import Dict, List, Optional, Union, cast
|
|
|
|
import httpx
|
|
|
|
import litellm
|
|
from litellm import verbose_logger
|
|
from litellm.caching import InMemoryCache
|
|
from litellm.litellm_core_utils.prompt_templates import factory as ptf
|
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
|
from litellm.secret_managers.main import get_secret_str
|
|
from litellm.types.llms.openai import AllMessageValues
|
|
from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials
|
|
|
|
|
|
class WatsonXAIError(BaseLLMException):
    """Exception raised by the watsonx.ai helpers in this module.

    The constructor forwards ``status_code``, ``message`` and ``headers``
    unchanged to ``BaseLLMException``.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[Dict, httpx.Headers]] = None,
    ):
        # Keyword arguments keep the call independent of the base class's
        # positional parameter order.
        super().__init__(
            headers=headers,
            message=message,
            status_code=status_code,
        )
|
|
|
|
|
|
# Module-level cache of IAM access tokens. `generate_iam_token` stores each
# token under the API key it was issued for, with a TTL derived from the
# `expires_in` value of the IAM response.
iam_token_cache = InMemoryCache()
|
|
|
|
|
|
def get_watsonx_iam_url():
    """Return the IAM token endpoint URL.

    Uses the ``WATSONX_IAM_URL`` secret when it is set to a non-empty value,
    otherwise the public IBM Cloud IAM token endpoint.
    """
    configured_url = get_secret_str("WATSONX_IAM_URL")
    if configured_url:
        return configured_url
    return "https://iam.cloud.ibm.com/identity/token"
|
|
|
|
|
|
def generate_iam_token(api_key=None, **params) -> str:
    """Return an IAM access token for ``api_key``, reusing a cached token when possible.

    Args:
        api_key: The IBM API key to exchange for a token. When ``None`` the key
            is read from the ``WX_API_KEY``, ``WATSONX_API_KEY`` or
            ``WATSONX_APIKEY`` secrets (first non-empty value wins).
        **params: Accepted for backward compatibility; not used.

    Returns:
        The ``access_token`` value from the IAM response (or from the cache).

    Raises:
        ValueError: If no API key was passed and none could be resolved.
        Exception: Whatever ``response.raise_for_status()`` raises when the IAM
            endpoint answers with an error status.
    """
    # Resolve the key BEFORE touching the cache: tokens are stored under the
    # resolved key, so looking up with an unresolved ``None`` would miss the
    # cache on every call and trigger a fresh IAM request each time.
    if api_key is None:
        api_key = (
            get_secret_str("WX_API_KEY")
            or get_secret_str("WATSONX_API_KEY")
            or get_secret_str("WATSONX_APIKEY")
        )
    if api_key is None:
        raise ValueError("API key is required")

    cached_token: Optional[str] = iam_token_cache.get_cache(api_key)  # type: ignore
    if cached_token is not None:
        return cast(str, cached_token)

    headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Accept": "application/json",
    }
    data = {
        "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
        "apikey": api_key,
    }
    iam_token_url = get_watsonx_iam_url()
    # Never write the plaintext API key to the logs; log a redacted copy of
    # the request body instead.
    verbose_logger.debug(
        "calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
        iam_token_url,
        headers,
        {**data, "apikey": "***REDACTED***"},
    )
    response = litellm.module_level_client.post(
        url=iam_token_url, data=data, headers=headers
    )
    response.raise_for_status()
    json_data = response.json()

    result = json_data["access_token"]
    iam_token_cache.set_cache(
        key=api_key,
        value=result,
        ttl=json_data["expires_in"] - 10,  # leave some buffer
    )

    return cast(str, result)
|
|
|
|
|
|
def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str:
|
|
if token is not None:
|
|
return token
|
|
token = generate_iam_token(api_key)
|
|
return token
|
|
|
|
|
|
def _get_api_params(
    params: dict,
) -> WatsonXAPIParams:
    """
    Resolve the watsonx.ai project id, deployment space id and region name.

    Values are taken from ``params`` first (the recognised keys are popped from
    the dict, so ``params`` is mutated) and then from secrets / environment
    variables for anything still unset.

    Raises:
        WatsonXAIError: (status 401) when no project id can be found.
    """

    def _first_secret(*names):
        # Mirrors an ``a or b or c`` chain: stop at the first truthy secret,
        # otherwise hand back whatever the last lookup returned.
        found = None
        for secret_name in names:
            found = get_secret_str(secret_name)
            if found:
                break
        return found

    # --- values passed explicitly by the caller ---
    # 'watsonx_project' is accepted to be consistent with how the vertex
    # project implementation works -> fewer provider-specific params.
    # Alias keys are always popped, even when the primary key is present.
    project_alias = params.pop("watsonx_project", None)
    project_id = params.pop("project_id", project_alias)

    space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id

    region_alias = params.pop("region", None)
    region_name = params.pop("region_name", region_alias)
    if region_name is None:
        # consistent with how vertex ai + aws regions are accepted
        watsonx_region_alias = params.pop("watsonx_region", None)
        region_name = params.pop("watsonx_region_name", watsonx_region_alias)

    # --- fall back to secrets / environment variables ---
    if project_id is None:
        project_id = _first_secret(
            "WATSONX_PROJECT_ID", "WX_PROJECT_ID", "PROJECT_ID"
        )
    if region_name is None:
        region_name = _first_secret("WATSONX_REGION", "WX_REGION", "REGION")
    if space_id is None:
        space_id = _first_secret(
            "WATSONX_DEPLOYMENT_SPACE_ID",
            "WATSONX_SPACE_ID",
            "WX_SPACE_ID",
            "SPACE_ID",
        )

    if project_id is None:
        raise WatsonXAIError(
            status_code=401,
            message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
        )

    return WatsonXAPIParams(
        project_id=project_id,
        space_id=space_id,
        region_name=region_name,
    )
|
|
|
|
|
|
def convert_watsonx_messages_to_prompt(
    model: str,
    messages: List[AllMessageValues],
    provider: str,
    custom_prompt_dict: Dict,
) -> str:
    """Render chat ``messages`` into a single prompt string.

    Template selection, in priority order:
      1. a custom prompt template registered for ``model`` in
         ``custom_prompt_dict``;
      2. the mistral instruct template when ``provider`` is ``"ibm-mistralai"``;
      3. the generic prompt factory with ``custom_llm_provider="watsonx"``.
    """
    if model in custom_prompt_dict:
        template = custom_prompt_dict[model]
        # "roles" is accepted as an alternative key for "role_dict".
        roles = template.get("role_dict", template.get("roles"))
        return ptf.custom_prompt(
            messages=messages,
            role_dict=roles,
            initial_prompt_value=template.get("initial_prompt_value", ""),
            final_prompt_value=template.get("final_prompt_value", ""),
            bos_token=template.get("bos_token", ""),
            eos_token=template.get("eos_token", ""),
        )

    if provider == "ibm-mistralai":
        return ptf.mistral_instruct_pt(messages=messages)

    return ptf.prompt_factory(  # type: ignore
        model=model, messages=messages, custom_llm_provider="watsonx"
    )
|
|
|
|
|
|
# Mixin class for shared IBM Watson X functionality
|
|
class IBMWatsonXMixin:
    """Helpers shared by the IBM watsonx.ai handlers: auth headers, URL
    construction, credential resolution and request payload scaffolding."""

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        """Build the request headers, adding an ``Authorization`` header if missing.

        Authorization is resolved in this order:
          1. an ``Authorization`` entry already present in ``headers``;
          2. a bearer token from ``optional_params["token"]`` or the
             ``WATSONX_TOKEN`` secret;
          3. a ``ZenApiKey`` from the ``WATSONX_ZENAPIKEY`` secret;
          4. a bearer token generated from ``api_key`` via IAM.

        Note: when an ``Authorization`` header has to be added it is written
        into the caller's ``headers`` dict as well as the returned dict.

        Returns:
            JSON ``Content-Type``/``Accept`` defaults merged with ``headers``
            (entries in ``headers`` win on conflict).
        """
        default_headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        if "Authorization" in headers:
            return {**default_headers, **headers}

        token = cast(
            Optional[str],
            optional_params.get("token") or get_secret_str("WATSONX_TOKEN"),
        )
        if token:
            headers["Authorization"] = f"Bearer {token}"
        elif zen_api_key := get_secret_str("WATSONX_ZENAPIKEY"):
            headers["Authorization"] = f"ZenApiKey {zen_api_key}"
        else:
            # ``token`` is falsy here but may be an empty string rather than
            # None. Pass None explicitly so a token is always generated from
            # the API key instead of sending an empty "Bearer " header.
            token = _generate_watsonx_token(api_key=api_key, token=None)
            headers["Authorization"] = f"Bearer {token}"
        return {**default_headers, **headers}

    def _get_base_url(self, api_base: Optional[str]) -> str:
        """Return the watsonx base URL from ``api_base`` or the first set secret.

        Secrets are checked in the order ``WATSONX_API_BASE``, ``WATSONX_URL``,
        ``WX_URL``, ``WML_URL``.

        Raises:
            WatsonXAIError: (status 401) when no URL is configured.
        """
        url = (
            api_base
            or get_secret_str("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
            or get_secret_str("WATSONX_URL")
            or get_secret_str("WX_URL")
            or get_secret_str("WML_URL")
        )

        if url is None:
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
            )
        return url

    def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str:
        """Append the ``version`` query parameter to ``url``.

        Falls back to ``litellm.WATSONX_DEFAULT_API_VERSION`` when
        ``api_version`` is empty. If ``url`` already carries a query string the
        parameter is joined with ``&`` rather than a second ``?``.
        """
        api_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION

        if "?" not in url:
            separator = "?"
        elif url.endswith(("?", "&")):
            # Query string already open and awaiting the next parameter.
            separator = ""
        else:
            separator = "&"

        return f"{url}{separator}version={api_version}"

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
        """Wrap an error message, status code and headers in a ``WatsonXAIError``."""
        return WatsonXAIError(
            status_code=status_code, message=error_message, headers=headers
        )

    @staticmethod
    def get_watsonx_credentials(
        optional_params: dict, api_key: Optional[str], api_base: Optional[str]
    ) -> WatsonXCredentials:
        """Resolve the API key, API base URL and optional token.

        Sources, in order: the explicit arguments, keys in ``optional_params``
        (``apikey``; ``url`` / ``api_base`` / ``base_url``), then secrets. A
        ``wx_credentials`` / ``watsonx_credentials`` dict in ``optional_params``
        overrides the key and URL for the entries it contains and may supply a
        ``token`` / ``watsonx_token``.

        ``optional_params`` is mutated: the credentials dict keys are always
        popped; the ``apikey`` and URL keys are popped only when the
        corresponding argument is empty (the ``or`` chain short-circuits).

        Raises:
            WatsonXAIError: (status 401) when the API key or API base is
                missing or is not a string.
        """
        api_key = (
            api_key
            or optional_params.pop("apikey", None)
            or get_secret_str("WATSONX_APIKEY")
            or get_secret_str("WATSONX_API_KEY")
            or get_secret_str("WX_API_KEY")
        )

        api_base = (
            api_base
            or optional_params.pop(
                "url",
                optional_params.pop("api_base", optional_params.pop("base_url", None)),
            )
            or get_secret_str("WATSONX_API_BASE")
            or get_secret_str("WATSONX_URL")
            or get_secret_str("WX_URL")
            or get_secret_str("WML_URL")
        )

        wx_credentials = optional_params.pop(
            "wx_credentials",
            optional_params.pop(
                "watsonx_credentials", None
            ),  # follow {provider}_credentials, same as vertex ai
        )

        token: Optional[str] = None

        if wx_credentials is not None:
            api_base = wx_credentials.get("url", api_base)
            api_key = wx_credentials.get(
                "apikey", wx_credentials.get("api_key", api_key)
            )
            token = wx_credentials.get(
                "token",
                wx_credentials.get(
                    "watsonx_token", None
                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
            )

        # isinstance() already rejects None, so no separate None check is needed.
        if not isinstance(api_key, str):
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.",
            )
        if not isinstance(api_base, str):
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
            )
        return WatsonXCredentials(
            api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
        )

    def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
        """Return the identifying fields of the request payload for ``model``.

        Models named ``deployment/...`` are addressed by ``space_id`` only;
        all other models are addressed by ``model_id`` and ``project_id``.

        Raises:
            WatsonXAIError: (status 401) when a ``deployment/`` model is used
                without a ``space_id``.
        """
        payload: dict = {}
        if model.startswith("deployment/"):
            if api_params["space_id"] is None:
                raise WatsonXAIError(
                    status_code=401,
                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
                )
            payload["space_id"] = api_params["space_id"]
            return payload
        payload["model_id"] = model
        payload["project_id"] = api_params["project_id"]
        return payload
|