diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py
index aa0cb32df1..28061919e6 100644
--- a/litellm/llms/watsonx.py
+++ b/litellm/llms/watsonx.py
@@ -13,7 +13,7 @@ from .prompt_templates import factory as ptf
 
 
 class WatsonXAIError(Exception):
-    def __init__(self, status_code, message, url: str = None):
+    def __init__(self, status_code, message, url: Optional[str] = None):
         self.status_code = status_code
         self.message = message
         url = url or "https://https://us-south.ml.cloud.ibm.com"
@@ -73,7 +73,6 @@ class IBMWatsonXAIConfig:
     repetition_penalty: Optional[float] = None
     truncate_input_tokens: Optional[int] = None
     include_stop_sequences: Optional[bool] = False
-    return_options: Optional[dict] = None
     return_options: Optional[Dict[str, bool]] = None
     random_seed: Optional[int] = None  # e.g 42
     moderations: Optional[dict] = None
@@ -161,6 +160,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     )
     return prompt
 
+
 class WatsonXAIEndpoint(str, Enum):
     TEXT_GENERATION = "/ml/v1/text/generation"
     TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
@@ -171,6 +171,7 @@ class WatsonXAIEndpoint(str, Enum):
     EMBEDDINGS = "/ml/v1/text/embeddings"
     PROMPTS = "/ml/v1/prompts"
 
+
 class IBMWatsonXAI(BaseLLM):
     """
     Class to interface with IBM Watsonx.ai API for text generation and embeddings.
@@ -180,7 +181,6 @@ class IBMWatsonXAI(BaseLLM):
 
     api_version = "2024-03-13"
 
-
     def __init__(self) -> None:
         super().__init__()
 
@@ -190,7 +190,7 @@ class IBMWatsonXAI(BaseLLM):
         prompt: str,
         stream: bool,
         optional_params: dict,
-        print_verbose: Callable = None,
+        print_verbose: Optional[Callable] = None,
     ) -> dict:
         """
         Get the request parameters for text generation.
@@ -224,9 +224,9 @@ class IBMWatsonXAI(BaseLLM):
             )
             deployment_id = "/".join(model_id.split("/")[1:])
             endpoint = (
-                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM
+                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
                 if stream
-                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION
+                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
             )
             endpoint = endpoint.format(deployment_id=deployment_id)
         else:
@@ -239,27 +239,40 @@ class IBMWatsonXAI(BaseLLM):
         )
         url = api_params["url"].rstrip("/") + endpoint
         return dict(
-            method="POST", url=url, headers=headers,
-            json=payload, params=request_params
+            method="POST", url=url, headers=headers, json=payload, params=request_params
         )
 
-    def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
+    def _get_api_params(
+        self, params: dict, print_verbose: Optional[Callable] = None
+    ) -> dict:
         """
         Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
         """
         # Load auth variables from params
-        url = params.pop("url", None)
+        url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
         api_key = params.pop("apikey", None)
         token = params.pop("token", None)
-        project_id = params.pop("project_id", None)  # watsonx.ai project_id
+        project_id = params.pop(
+            "project_id", params.pop("watsonx_project", None)
+        )  # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
         space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
         region_name = params.pop("region_name", params.pop("region", None))
-        wx_credentials = params.pop("wx_credentials", None)
+        if region_name is None:
+            region_name = params.pop(
+                "watsonx_region_name", params.pop("watsonx_region", None)
+            )  # consistent with how vertex ai + aws regions are accepted
+        wx_credentials = params.pop(
+            "wx_credentials",
+            params.pop(
+                "watsonx_credentials", None
+            ),  # follow {provider}_credentials, same as vertex ai
+        )
         api_version = params.pop("api_version", IBMWatsonXAI.api_version)
 
         # Load auth variables from environment variables
         if url is None:
             url = (
-                get_secret("WATSONX_URL")
+                get_secret("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
+                or get_secret("WATSONX_URL")
                 or get_secret("WX_URL")
                 or get_secret("WML_URL")
             )
@@ -297,7 +310,12 @@ class IBMWatsonXAI(BaseLLM):
             api_key = wx_credentials.get(
                 "apikey", wx_credentials.get("api_key", api_key)
             )
-            token = wx_credentials.get("token", token)
+            token = wx_credentials.get(
+                "token",
+                wx_credentials.get(
+                    "watsonx_token", token
+                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
+            )
 
         # verify that all required credentials are present
         if url is None:
@@ -342,10 +360,10 @@ class IBMWatsonXAI(BaseLLM):
         print_verbose: Callable,
         encoding,
         logging_obj,
-        optional_params: Optional[dict] = None,
+        optional_params: dict,
         litellm_params: Optional[dict] = None,
         logger_fn=None,
-        timeout: float = None,
+        timeout: Optional[float] = None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -379,10 +397,14 @@ class IBMWatsonXAI(BaseLLM):
         model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
         model_response["created"] = int(time.time())
         model_response["model"] = model
-        model_response.usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
+        setattr(
+            model_response,
+            "usage",
+            Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            ),
         )
         return model_response
 
@@ -525,7 +547,7 @@ class IBMWatsonXAI(BaseLLM):
         logging_obj: Any,
         stream: bool = False,
         input: Optional[Any] = None,
-        timeout: float = None,
+        timeout: Optional[float] = None,
     ):
         request_str = (
             f"response = {request_params['method']}(\n"
@@ -535,14 +557,14 @@ class IBMWatsonXAI(BaseLLM):
         )
         logging_obj.pre_call(
             input=input,
-            api_key=request_params['headers'].get("Authorization"),
+            api_key=request_params["headers"].get("Authorization"),
             additional_args={
-                "complete_input_dict": request_params['json'],
+                "complete_input_dict": request_params["json"],
                 "request_str": request_str,
             },
         )
         if timeout:
-            request_params['timeout'] = timeout
+            request_params["timeout"] = timeout
         try:
             if stream:
                 resp = requests.request(
@@ -560,10 +582,10 @@ class IBMWatsonXAI(BaseLLM):
         if not stream:
             logging_obj.post_call(
                 input=input,
-                api_key=request_params['headers'].get("Authorization"),
+                api_key=request_params["headers"].get("Authorization"),
                 original_response=json.dumps(resp.json()),
                 additional_args={
                     "status_code": resp.status_code,
-                    "complete_input_dict": request_params['json'],
+                    "complete_input_dict": request_params["json"],
                 },
             )
diff --git a/litellm/main.py b/litellm/main.py
index 41794ccd50..454f7f7169 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1872,7 +1872,7 @@ def completion(
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=optional_params,
-                litellm_params=litellm_params,
+                litellm_params=litellm_params,  # type: ignore
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
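
Usage sketch (not part of the patch): the aliases added above are intended to be passed straight through to `_get_api_params`, so a call could look like the following. This is a minimal, hypothetical example; the model id and credential values are placeholders, and it assumes extra kwargs are forwarded by `litellm.completion` to the watsonx handler as the diff suggests.

    import os
    import litellm

    # New env alias, checked before WATSONX_URL, mirroring AZURE_API_BASE.
    os.environ["WATSONX_API_BASE"] = "https://us-south.ml.cloud.ibm.com"

    response = litellm.completion(
        model="watsonx/ibm/granite-13b-chat-v2",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],
        apikey="...",           # popped via params.pop("apikey", None)
        watsonx_project="...",  # new alias for project_id (vertex-style naming)
    )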