litellm-mirror/litellm/llms/watsonx/common_utils.py
Krish Dholakia 311432ca17
refactor(fireworks_ai/): inherit from openai like base config (#7146)
* refactor(fireworks_ai/): inherit from openai like base config

refactors fireworks ai to use a common config

* test: fix import in test

* refactor(watsonx/): refactor watsonx to use llm base config

refactors chat + completion routes to base config path

* fix: fix linting error

* test: fix test

* fix: fix test
2024-12-10 16:15:19 -08:00

169 lines
6 KiB
Python

from typing import Callable, Dict, Optional, Union, cast
import httpx
import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAPIParams
class WatsonXAIError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[Dict, httpx.Headers]] = None,
):
super().__init__(status_code=status_code, message=message, headers=headers)
iam_token_cache = InMemoryCache()
def generate_iam_token(api_key=None, **params) -> str:
result: Optional[str] = iam_token_cache.get_cache(api_key) # type: ignore
if result is None:
headers = {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
if api_key is None:
api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
if api_key is None:
raise ValueError("API key is required")
headers["Accept"] = "application/json"
data = {
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
"apikey": api_key,
}
verbose_logger.debug(
"calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
"https://iam.cloud.ibm.com/identity/token",
headers,
data,
)
response = httpx.post(
"https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
)
response.raise_for_status()
json_data = response.json()
result = json_data["access_token"]
iam_token_cache.set_cache(
key=api_key,
value=result,
ttl=json_data["expires_in"] - 10, # leave some buffer
)
return cast(str, result)
def _get_api_params(
params: dict,
print_verbose: Optional[Callable] = None,
generate_token: Optional[bool] = True,
) -> WatsonXAPIParams:
"""
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
"""
# Load auth variables from params
url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
api_key = params.pop("apikey", None)
token = params.pop("token", None)
project_id = params.pop(
"project_id", params.pop("watsonx_project", None)
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
space_id = params.pop("space_id", None) # watsonx.ai deployment space_id
region_name = params.pop("region_name", params.pop("region", None))
if region_name is None:
region_name = params.pop(
"watsonx_region_name", params.pop("watsonx_region", None)
) # consistent with how vertex ai + aws regions are accepted
wx_credentials = params.pop(
"wx_credentials",
params.pop(
"watsonx_credentials", None
), # follow {provider}_credentials, same as vertex ai
)
api_version = params.pop("api_version", litellm.WATSONX_DEFAULT_API_VERSION)
# Load auth variables from environment variables
if url is None:
url = (
get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
if api_key is None:
api_key = (
get_secret_str("WATSONX_APIKEY")
or get_secret_str("WATSONX_API_KEY")
or get_secret_str("WX_API_KEY")
)
if token is None:
token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
if project_id is None:
project_id = (
get_secret_str("WATSONX_PROJECT_ID")
or get_secret_str("WX_PROJECT_ID")
or get_secret_str("PROJECT_ID")
)
if region_name is None:
region_name = (
get_secret_str("WATSONX_REGION")
or get_secret_str("WX_REGION")
or get_secret_str("REGION")
)
if space_id is None:
space_id = (
get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
or get_secret_str("WATSONX_SPACE_ID")
or get_secret_str("WX_SPACE_ID")
or get_secret_str("SPACE_ID")
)
# credentials parsing
if wx_credentials is not None:
url = wx_credentials.get("url", url)
api_key = wx_credentials.get("apikey", wx_credentials.get("api_key", api_key))
token = wx_credentials.get(
"token",
wx_credentials.get(
"watsonx_token", token
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
)
# verify that all required credentials are present
if url is None:
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
)
if token is None and api_key is not None and generate_token:
# generate the auth token
if print_verbose is not None:
print_verbose("Generating IAM token for Watsonx.ai")
token = generate_iam_token(api_key)
elif token is None and api_key is None:
raise WatsonXAIError(
status_code=401,
message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
)
if project_id is None:
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
)
return WatsonXAPIParams(
url=url,
api_key=api_key,
token=cast(str, token),
project_id=project_id,
space_id=space_id,
region_name=region_name,
api_version=api_version,
)