diff --git a/litellm/__init__.py b/litellm/__init__.py index 95dd33f1c..a7c17d53c 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -656,7 +656,7 @@ from .llms.bedrock import ( ) from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig from .llms.azure import AzureOpenAIConfig, AzureOpenAIError -from .llms.watsonx import IBMWatsonXConfig +from .llms.watsonx import IBMWatsonXAIConfig from .main import * # type: ignore from .integrations import * from .exceptions import ( diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py index 7cb45730b..38837ddb2 100644 --- a/litellm/llms/watsonx.py +++ b/litellm/llms/watsonx.py @@ -1,27 +1,31 @@ -import json, types, time -from typing import Callable, Optional, Any, Union, List +import json, types, time # noqa: E401 +from contextlib import contextmanager +from typing import Callable, Dict, Optional, Any, Union, List import httpx +import requests import litellm -from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse +from litellm.utils import ModelResponse, get_secret, Usage +from .base import BaseLLM from .prompt_templates import factory as ptf -class WatsonxError(Exception): - def __init__(self, status_code, message): + +class WatsonXAIError(Exception): + def __init__(self, status_code, message, url: str = None): self.status_code = status_code self.message = message - self.request = httpx.Request( - method="POST", url="https://https://us-south.ml.cloud.ibm.com" - ) + url = url or "https://https://us-south.ml.cloud.ibm.com" + self.request = httpx.Request(method="POST", url=url) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class IBMWatsonXConfig: + +class IBMWatsonXAIConfig: """ - Reference: https://cloud.ibm.com/apidocs/watsonx-ai#deployments-text-generation + Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params) Supported params for all available watsonx.ai foundational models. @@ -34,96 +38,64 @@ class IBMWatsonXConfig: - `min_new_tokens` (integer): Maximum length of input tokens. Any more than this will be truncated. + - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index". + - `stop_sequences` (string[]): list of strings to use as stop sequences. - - `time_limit` (integer): time limit in milliseconds. If the generation is not completed within the time limit, the model will return the generated text up to that point. - - - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'. - - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'. + - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'. + - `repetition_penalty` (float): token repetition penalty during text generation. - - `stream` (bool): If True, the model will return a stream of responses. - - - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". - - `truncate_input_tokens` (integer): Truncate input tokens to this length. - - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index". + - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match. + + - `return_options` (dict): A dictionary of options to return. 
Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean. - `random_seed` (integer): Random seed for text generation. - - `guardrails` (bool): Enable guardrails for harmful content. + - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering. - - `guardrails_hap_params` (dict): Guardrails for harmful content. - - - `guardrails_pii_params` (dict): Guardrails for Personally Identifiable Information. - - - `concurrency_limit` (integer): Maximum number of concurrent requests. - - - `async_mode` (bool): Enable async mode. - - - `verify` (bool): Verify the SSL certificate of calls to the watsonx url. - - - `validate` (bool): Validate the model_id at initialization. - - - `model_inference` (ibm_watsonx_ai.ModelInference): An instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance. - - - `watsonx_client` (ibm_watsonx_ai.APIClient): An instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with. + - `stream` (bool): If True, the model will return a stream of responses. """ - decoding_method: Optional[str] = "sample" # 'sample' or 'greedy'. "sample" follows the default openai API behavior - temperature: Optional[float] = None # + + decoding_method: Optional[str] = "sample" + temperature: Optional[float] = None + max_new_tokens: Optional[int] = None # litellm.max_tokens min_new_tokens: Optional[int] = None - max_new_tokens: Optional[int] = litellm.max_tokens + length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5} + stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."] top_k: Optional[int] = None top_p: Optional[float] = None - random_seed: Optional[int] = None # e.g 42 repetition_penalty: Optional[float] = None - stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."] - time_limit: Optional[int] = None # e.g 10000 (timeout in milliseconds) - return_options: Optional[dict] = None # e.g {"input_text": True, "generated_tokens": True, "input_tokens": True, "token_ranks": False} - truncate_input_tokens: Optional[int] = None # e.g 512 - length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5} + truncate_input_tokens: Optional[int] = None + include_stop_sequences: Optional[bool] = False + return_options: Optional[dict] = None + return_options: Optional[Dict[str, bool]] = None + random_seed: Optional[int] = None # e.g 42 + moderations: Optional[dict] = None stream: Optional[bool] = False - # other inference params - guardrails: Optional[bool] = False # enable guardrails - guardrails_hap_params: Optional[dict] = None # guardrails for harmful content - guardrails_pii_params: Optional[dict] = None # guardrails for Personally Identifiable Information - concurrency_limit: Optional[int] = 10 # max number of concurrent requests - async_mode: Optional[bool] = False # enable async mode - verify: Optional[Union[bool,str]] = None # verify the SSL certificate of calls to the watsonx url - validate: Optional[bool] = False # validate the model_id at initialization - model_inference: Optional[object] = None # an instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance - watsonx_client: Optional[object] = None # an instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with def __init__( self, decoding_method: Optional[str] = None, temperature: Optional[float] = None, + max_new_tokens: 
Optional[int] = None, min_new_tokens: Optional[int] = None, - max_new_tokens: Optional[ - int - ] = litellm.max_tokens, # petals requires max tokens to be set + length_penalty: Optional[dict] = None, + stop_sequences: Optional[List[str]] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, - random_seed: Optional[int] = None, repetition_penalty: Optional[float] = None, - stop_sequences: Optional[List[str]] = None, - time_limit: Optional[int] = None, - return_options: Optional[dict] = None, truncate_input_tokens: Optional[int] = None, - length_penalty: Optional[dict] = None, - stream: Optional[bool] = False, - guardrails: Optional[bool] = False, - guardrails_hap_params: Optional[dict] = None, - guardrails_pii_params: Optional[dict] = None, - concurrency_limit: Optional[int] = 10, - async_mode: Optional[bool] = False, - verify: Optional[Union[bool,str]] = None, - validate: Optional[bool] = False, - model_inference: Optional[object] = None, - watsonx_client: Optional[object] = None, + include_stop_sequences: Optional[bool] = None, + return_options: Optional[dict] = None, + random_seed: Optional[int] = None, + moderations: Optional[dict] = None, + stream: Optional[bool] = None, + **kwargs, ) -> None: locals_ = locals() for key, value in locals_.items(): @@ -150,143 +122,16 @@ class IBMWatsonXConfig: def get_supported_openai_params(self): return [ - "temperature", # equivalent to temperature - "max_tokens", # equivalent to max_new_tokens - "top_p", # equivalent to top_p - "frequency_penalty", # equivalent to repetition_penalty - "stop", # equivalent to stop_sequences - "seed", # equivalent to random_seed - "stream", # equivalent to stream + "temperature", # equivalent to temperature + "max_tokens", # equivalent to max_new_tokens + "top_p", # equivalent to top_p + "frequency_penalty", # equivalent to repetition_penalty + "stop", # equivalent to stop_sequences + "seed", # equivalent to random_seed + "stream", # equivalent to stream ] -def init_watsonx_model( - model_id: str, - url: Optional[str] = None, - api_key: Optional[str] = None, - project_id: Optional[str] = None, - space_id: Optional[str] = None, - wx_credentials: Optional[dict] = None, - region_name: Optional[str] = None, - verify: Optional[Union[bool,str]] = None, - validate: Optional[bool] = False, - watsonx_client: Optional[object] = None, - model_params: Optional[dict] = None, -): - """ - Initialize a watsonx.ai model for inference. - - Args: - - model_id (str): The model ID to use for inference. If this is a model deployed in a deployment space, the model_id should be in the format 'deployment/' and the space_id to the deploymend space should be provided. - url (str): The URL of the watsonx.ai instance. - api_key (str): The API key for the watsonx.ai instance. - project_id (str): The project ID for the watsonx.ai instance. - space_id (str): The space ID for the deployment space. - wx_credentials (dict): A dictionary containing 'apikey' and 'url' keys for the watsonx.ai instance. - region_name (str): The region name for the watsonx.ai instance (e.g. 'us-south'). - verify (bool): Whether to verify the SSL certificate of calls to the watsonx url. - validate (bool): Whether to validate the model_id at initialization. - watsonx_client (object): An instance of the ibm_watsonx_ai.APIClient class. If this is provided, the model will be initialized using the provided client. - model_params (dict): A dictionary containing additional parameters to pass to the model (see IBMWatsonXConfig for a list of supported parameters). 
- """ - - from ibm_watsonx_ai import APIClient - from ibm_watsonx_ai.foundation_models import ModelInference - - - if wx_credentials is not None: - if 'apikey' not in wx_credentials and 'api_key' in wx_credentials: - wx_credentials['apikey'] = wx_credentials.pop('api_key') - if 'apikey' not in wx_credentials: - raise WatsonxError(500, "Error: key 'apikey' expected in wx_credentials") - - if url is None: - url = get_secret("WX_URL") or get_secret("WATSONX_URL") or get_secret("WML_URL") - if api_key is None: - api_key = get_secret("WX_API_KEY") or get_secret("WML_API_KEY") - if project_id is None: - project_id = get_secret("WX_PROJECT_ID") or get_secret("PROJECT_ID") - if region_name is None: - region_name = get_secret("WML_REGION_NAME") or get_secret("WX_REGION_NAME") or get_secret("REGION_NAME") - if space_id is None: - space_id = get_secret("WX_SPACE_ID") or get_secret("WML_DEPLOYMENT_SPACE_ID") or get_secret("SPACE_ID") - - - ## CHECK IS 'os.environ/' passed in - # Define the list of parameters to check - params_to_check = (url, api_key, project_id, space_id, region_name) - # Iterate over parameters and update if needed - for i, param in enumerate(params_to_check): - if param and param.startswith("os.environ/"): - params_to_check[i] = get_secret(param) - # Assign updated values back to parameters - url, api_key, project_id, space_id, region_name = params_to_check - - ### SET WATSONX URL - if url is not None or watsonx_client is not None or wx_credentials is not None: - pass - elif region_name is not None: - url = f"https://{region_name}.ml.cloud.ibm.com" - else: - raise WatsonxError( - message="Watsonx URL not set: set WX_URL env variable or in .env file", - status_code=401, - ) - if watsonx_client is not None and project_id is None: - project_id = watsonx_client.project_id - - if model_id.startswith("deployment/"): - # deployment models are passed in as 'deployment/' - assert space_id is not None, "space_id is required for deployment models" - deployment_id = '/'.join(model_id.split("/")[1:]) - model_id = None - else: - deployment_id = None - - if watsonx_client is not None: - model = ModelInference( - model_id=model_id, - params=model_params, - api_client=watsonx_client, - project_id=project_id, - deployment_id=deployment_id, - verify=verify, - validate=validate, - space_id=space_id, - ) - elif wx_credentials is not None: - model = ModelInference( - model_id=model_id, - params=model_params, - credentials=wx_credentials, - project_id=project_id, - deployment_id=deployment_id, - verify=verify, - validate=validate, - space_id=space_id, - ) - elif api_key is not None: - model = ModelInference( - model_id=model_id, - params=model_params, - credentials={ - "apikey": api_key, - "url": url, - }, - project_id=project_id, - deployment_id=deployment_id, - verify=verify, - validate=validate, - space_id=space_id, - ) - else: - raise WatsonxError(500, "WatsonX credentials not passed or could not be found.") - - - return model - - def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): # handle anthropic prompts and amazon titan prompts if model in custom_prompt_dict: @@ -294,8 +139,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): model_prompt_dict = custom_prompt_dict[model] prompt = ptf.custom_prompt( messages=messages, - role_dict=model_prompt_dict.get("role_dict", model_prompt_dict.get("roles")), - initial_prompt_value=model_prompt_dict.get("initial_prompt_value",""), + role_dict=model_prompt_dict.get( + "role_dict", 
model_prompt_dict.get("roles") + ), + initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""), final_prompt_value=model_prompt_dict.get("final_prompt_value", ""), bos_token=model_prompt_dict.get("bos_token", ""), eos_token=model_prompt_dict.get("eos_token", ""), @@ -308,173 +155,408 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): elif provider == "ibm-mistralai": prompt = ptf.mistral_instruct_pt(messages=messages) else: - prompt = ptf.prompt_factory(model=model, messages=messages, custom_llm_provider='watsonx') + prompt = ptf.prompt_factory( + model=model, messages=messages, custom_llm_provider="watsonx" + ) return prompt -""" -IBM watsonx.ai AUTH Keys/Vars -os.environ['WX_URL'] = "" -os.environ['WX_API_KEY'] = "" -os.environ['WX_PROJECT_ID'] = "" -""" +class IBMWatsonXAI(BaseLLM): + """ + Class to interface with IBM Watsonx.ai API for text generation and embeddings. -def completion( - model: str, - messages: list, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - optional_params:Optional[dict]=None, - litellm_params:Optional[dict]=None, - logger_fn=None, - timeout:float=None, -): - from ibm_watsonx_ai.foundation_models import Model, ModelInference + Reference: https://cloud.ibm.com/apidocs/watsonx-ai + """ - try: - stream = optional_params.pop("stream", False) - extra_generate_params = dict( - guardrails=optional_params.pop("guardrails", False), - guardrails_hap_params=optional_params.pop("guardrails_hap_params", None), - guardrails_pii_params=optional_params.pop("guardrails_pii_params", None), - concurrency_limit=optional_params.pop("concurrency_limit", 10), - async_mode=optional_params.pop("async_mode", False), - ) - if timeout is not None and optional_params.get("time_limit") is None: - # the time_limit in watsonx.ai is in milliseconds (as opposed to OpenAI which is in seconds) - optional_params['time_limit'] = max(0, int(timeout*1000)) + api_version = "2024-03-13" + _text_gen_endpoint = "/ml/v1/text/generation" + _text_gen_stream_endpoint = "/ml/v1/text/generation_stream" + _deployment_text_gen_endpoint = "/ml/v1/deployments/{deployment_id}/text/generation" + _deployment_text_gen_stream_endpoint = ( + "/ml/v1/deployments/{deployment_id}/text/generation_stream" + ) + _embeddings_endpoint = "/ml/v1/text/embeddings" + _prompts_endpoint = "/ml/v1/prompts" + + def __init__(self) -> None: + super().__init__() + + def _prepare_text_generation_req( + self, + model_id: str, + prompt: str, + stream: bool, + optional_params: dict, + print_verbose: Callable = None, + ) -> httpx.Request: + """ + Get the request parameters for text generation. 
+ """ + api_params = self._get_api_params(optional_params, print_verbose=print_verbose) + # build auth headers + api_token = api_params.get("token") + + headers = { + "Authorization": f"Bearer {api_token}", + "Content-Type": "application/json", + "Accept": "application/json", + } extra_body_params = optional_params.pop("extra_body", {}) optional_params.update(extra_body_params) - # LOAD CONFIG - config = IBMWatsonXConfig.get_config() + # init the payload to the text generation call + payload = { + "input": prompt, + "moderations": optional_params.pop("moderations", {}), + "parameters": optional_params, + } + request_params = dict(version=api_params["api_version"]) + # text generation endpoint deployment or model / stream or not + if model_id.startswith("deployment/"): + # deployment models are passed in as 'deployment/' + if api_params.get("space_id") is None: + raise WatsonXAIError( + status_code=401, + url=api_params["url"], + message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.", + ) + deployment_id = "/".join(model_id.split("/")[1:]) + endpoint = ( + self._deployment_text_gen_stream_endpoint + if stream + else self._deployment_text_gen_endpoint + ) + endpoint = endpoint.format(deployment_id=deployment_id) + else: + payload["model_id"] = model_id + payload["project_id"] = api_params["project_id"] + endpoint = ( + self._text_gen_stream_endpoint if stream else self._text_gen_endpoint + ) + url = api_params["url"].rstrip("/") + endpoint + return httpx.Request( + "POST", url, headers=headers, json=payload, params=request_params + ) + + def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict: + """ + Find watsonx.ai credentials in the params or environment variables and return the headers for authentication. + """ + # Load auth variables from params + url = params.pop("url", None) + api_key = params.pop("apikey", None) + token = params.pop("token", None) + project_id = params.pop("project_id", None) # watsonx.ai project_id + space_id = params.pop("space_id", None) # watsonx.ai deployment space_id + region_name = params.pop("region_name", params.pop("region", None)) + wx_credentials = params.pop("wx_credentials", None) + api_version = params.pop("api_version", IBMWatsonXAI.api_version) + # Load auth variables from environment variables + if url is None: + url = ( + get_secret("WATSONX_URL") + or get_secret("WX_URL") + or get_secret("WML_URL") + ) + if api_key is None: + api_key = get_secret("WATSONX_API_KEY") or get_secret("WX_API_KEY") + if token is None: + token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN") + if project_id is None: + project_id = ( + get_secret("WATSONX_PROJECT_ID") + or get_secret("WX_PROJECT_ID") + or get_secret("PROJECT_ID") + ) + if region_name is None: + region_name = ( + get_secret("WATSONX_REGION") + or get_secret("WX_REGION") + or get_secret("REGION") + ) + if space_id is None: + space_id = ( + get_secret("WATSONX_DEPLOYMENT_SPACE_ID") + or get_secret("WATSONX_SPACE_ID") + or get_secret("WX_SPACE_ID") + or get_secret("SPACE_ID") + ) + + # credentials parsing + if wx_credentials is not None: + url = wx_credentials.get("url", url) + api_key = wx_credentials.get( + "apikey", wx_credentials.get("api_key", api_key) + ) + token = wx_credentials.get("token", token) + + # verify that all required credentials are present + if url is None: + raise WatsonXAIError( + status_code=401, + message="Error: Watsonx URL not set. 
Set WX_URL in environment variables or pass in as a parameter.", + ) + if token is None and api_key is not None: + # generate the auth token + if print_verbose: + print_verbose("Generating IAM token for Watsonx.ai") + token = self.generate_iam_token(api_key) + elif token is None and api_key is None: + raise WatsonXAIError( + status_code=401, + url=url, + message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.", + ) + if project_id is None: + raise WatsonXAIError( + status_code=401, + url=url, + message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.", + ) + + return { + "url": url, + "api_key": api_key, + "token": token, + "project_id": project_id, + "space_id": space_id, + "region_name": region_name, + "api_version": api_version, + } + + def completion( + self, + model: str, + messages: list, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: Optional[dict] = None, + litellm_params: Optional[dict] = None, + logger_fn=None, + timeout: float = None, + ): + """ + Send a text generation request to the IBM Watsonx.ai API. + Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation + """ + stream = optional_params.pop("stream", False) + + # Load default configs + config = IBMWatsonXAIConfig.get_config() for k, v in config.items(): if k not in optional_params: optional_params[k] = v - model_inference = optional_params.pop("model_inference", None) - if model_inference is None: - # INIT MODEL - model_client:ModelInference = init_watsonx_model( - model_id=model, - url=optional_params.pop("url", None), - api_key=optional_params.pop("api_key", None), - project_id=optional_params.pop("project_id", None), - space_id=optional_params.pop("space_id", None), - wx_credentials=optional_params.pop("wx_credentials", None), - region_name=optional_params.pop("region_name", None), - verify=optional_params.pop("verify", None), - validate=optional_params.pop("validate", False), - watsonx_client=optional_params.pop("watsonx_client", None), - model_params=optional_params, - ) - else: - model_client:ModelInference = model_inference - model = model_client.model_id - - # MAKE PROMPT + # Make prompt to send to model provider = model.split("/")[0] - model_name = '/'.join(model.split("/")[1:]) + # model_name = "/".join(model.split("/")[1:]) prompt = convert_messages_to_prompt( model, messages, provider, custom_prompt_dict ) - ## COMPLETION CALL - if stream is True: - request_str = ( - "response = model.generate_text_stream(\n" - f"\tprompt={prompt},\n" - "\traw_response=True\n)" - ) - logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - # remove params that are not needed for streaming - del extra_generate_params["async_mode"] - del extra_generate_params["concurrency_limit"] - # make generate call - response = model_client.generate_text_stream( - prompt=prompt, - raw_response=True, - **extra_generate_params - ) - return litellm.CustomStreamWrapper( - response, - model=model, - custom_llm_provider="watsonx", - logging_obj=logging_obj, - ) - else: - try: - ## LOGGING - request_str = ( - "response = model.generate(\n" - f"\tprompt={prompt},\n" - "\traw_response=True\n)" - ) - logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={ - "complete_input_dict": optional_params, - "request_str": 
request_str, - }, - ) - response = model_client.generate( - prompt=prompt, - **extra_generate_params - ) - except Exception as e: - raise WatsonxError(status_code=500, message=str(e)) - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key="", - original_response=json.dumps(response), - additional_args={"complete_input_dict": optional_params}, - ) - print_verbose(f"raw model_response: {response}") - ## BUILD RESPONSE OBJECT - output_text = response['results'][0]['generated_text'] + def process_text_request(request: httpx.Request) -> ModelResponse: + with self._manage_response( + request, logging_obj=logging_obj, input=prompt, timeout=timeout + ) as resp: + json_resp = resp.json() + + generated_text = json_resp["results"][0]["generated_text"] + prompt_tokens = json_resp["results"][0]["input_token_count"] + completion_tokens = json_resp["results"][0]["generated_token_count"] + model_response["choices"][0]["message"]["content"] = generated_text + model_response["finish_reason"] = json_resp["results"][0]["stop_reason"] + model_response["created"] = int(time.time()) + model_response["model"] = model + model_response.usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + return model_response + + def process_stream_request( + request: httpx.Request, + ) -> litellm.CustomStreamWrapper: + # stream the response - generated chunks will be handled + # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream + with self._manage_response( + request, + logging_obj=logging_obj, + stream=True, + input=prompt, + timeout=timeout, + ) as resp: + response = litellm.CustomStreamWrapper( + resp.iter_lines(), + model=model, + custom_llm_provider="watsonx", + logging_obj=logging_obj, + ) + return response try: - if ( - len(output_text) > 0 - and hasattr(model_response.choices[0], "message") - ): - model_response["choices"][0]["message"]["content"] = output_text - model_response["finish_reason"] = response['results'][0]['stop_reason'] - prompt_tokens = response['results'][0]['input_token_count'] - completion_tokens = response['results'][0]['generated_token_count'] - else: - raise Exception() - except: - raise WatsonxError( - message=json.dumps(output_text), - status_code=500, + ## Get the response from the model + request = self._prepare_text_generation_req( + model_id=model, + prompt=prompt, + stream=stream, + optional_params=optional_params, + print_verbose=print_verbose, ) - model_response['created'] = int(time.time()) - model_response['model'] = model_name - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, + if stream: + return process_stream_request(request) + else: + return process_text_request(request) + except WatsonXAIError as e: + raise e + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + + def embedding( + self, + model: str, + input: Union[list, str], + api_key: Optional[str] = None, + logging_obj=None, + model_response=None, + optional_params=None, + encoding=None, + ): + """ + Send a text embedding request to the IBM Watsonx.ai API. 
+ """ + if optional_params is None: + optional_params = {} + # Load default configs + config = IBMWatsonXAIConfig.get_config() + for k, v in config.items(): + if k not in optional_params: + optional_params[k] = v + + # Load auth variables from environment variables + if isinstance(input, str): + input = [input] + if api_key is not None: + optional_params["api_key"] = api_key + api_params = self._get_api_params(optional_params) + # build auth headers + api_token = api_params.get("token") + headers = { + "Authorization": f"Bearer {api_token}", + "Content-Type": "application/json", + "Accept": "application/json", + } + # init the payload to the text generation call + payload = { + "inputs": input, + "model_id": model, + "project_id": api_params["project_id"], + "parameters": optional_params, + } + request_params = dict(version=api_params["api_version"]) + url = api_params["url"].rstrip("/") + self._embeddings_endpoint + request = httpx.Request( + "POST", url, headers=headers, json=payload, params=request_params + ) + with self._manage_response( + request, logging_obj=logging_obj, input=input + ) as resp: + json_resp = resp.json() + + results = json_resp.get("results", []) + embedding_response = [] + for idx, result in enumerate(results): + embedding_response.append( + {"object": "embedding", "index": idx, "embedding": result["embedding"]} + ) + model_response["object"] = "list" + model_response["data"] = embedding_response + model_response["model"] = model + input_tokens = json_resp.get("input_token_count", 0) + model_response.usage = Usage( + prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens ) - model_response.usage = usage return model_response - except WatsonxError as e: - raise e - except Exception as e: - raise WatsonxError(status_code=500, message=str(e)) + def generate_iam_token(self, api_key=None, **params): + headers = {} + headers["Content-Type"] = "application/x-www-form-urlencoded" + if api_key is None: + api_key = get_secret("WX_API_KEY") or get_secret("WATSONX_API_KEY") + if api_key is None: + raise ValueError("API key is required") + headers["Accept"] = "application/json" + data = { + "grant_type": "urn:ibm:params:oauth:grant-type:apikey", + "apikey": api_key, + } + response = httpx.post( + "https://iam.cloud.ibm.com/identity/token", data=data, headers=headers + ) + response.raise_for_status() + json_data = response.json() + iam_access_token = json_data["access_token"] + self.token = iam_access_token + return iam_access_token -def embedding(): - # logic for parsing in - calling - parsing out model embedding calls - pass \ No newline at end of file + @contextmanager + def _manage_response( + self, + request: httpx.Request, + logging_obj: Any, + stream: bool = False, + input: Optional[Any] = None, + timeout: float = None, + ): + request_str = ( + f"response = {request.method}(\n" + f"\turl={request.url},\n" + f"\tjson={request.content.decode()},\n" + f")" + ) + json_input = json.loads(request.content.decode()) + headers = dict(request.headers) + logging_obj.pre_call( + input=input, + api_key=request.headers.get("Authorization"), + additional_args={ + "complete_input_dict": json_input, + "request_str": request_str, + }, + ) + try: + if stream: + resp = requests.request( + method=request.method, + url=str(request.url), + headers=headers, + json=json_input, + stream=True, + timeout=timeout, + ) + # resp.raise_for_status() + yield resp + else: + resp = requests.request( + method=request.method, + url=str(request.url), + headers=headers, + json=json_input, + 
timeout=timeout, + ) + resp.raise_for_status() + yield resp + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + if not stream: + logging_obj.post_call( + input=input, + api_key=request.headers.get("Authorization"), + original_response=json.dumps(resp.json()), + additional_args={ + "status_code": resp.status_code, + "complete_input_dict": request, + }, + ) diff --git a/litellm/main.py b/litellm/main.py index b61df8c12..8f357b834 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1862,7 +1862,7 @@ def completion( response = response elif custom_llm_provider == "watsonx": custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - response = watsonx.completion( + response = watsonx.IBMWatsonXAI().completion( model=model, messages=messages, custom_prompt_dict=custom_prompt_dict, @@ -2976,6 +2976,15 @@ def embedding( client=client, aembedding=aembedding, ) + elif custom_llm_provider == "watsonx": + response = watsonx.IBMWatsonXAI().embedding( + model=model, + input=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + ) else: args = locals() raise ValueError(f"No valid embedding model args passed in - {args}") diff --git a/litellm/utils.py b/litellm/utils.py index 836587fb1..89061c3bf 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5771,7 +5771,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str): "presence_penalty", ] elif custom_llm_provider == "watsonx": - return litellm.IBMWatsonXConfig().get_supported_openai_params() + return litellm.IBMWatsonXAIConfig().get_supported_openai_params() def get_formatted_prompt( @@ -9682,20 +9682,31 @@ class CustomStreamWrapper: def handle_watsonx_stream(self, chunk): try: if isinstance(chunk, dict): - pass - elif isinstance(chunk, str): - chunk = json.loads(chunk) - result = chunk.get("results", []) - if len(result) > 0: - text = result[0].get("generated_text", "") - finish_reason = result[0].get("stop_reason") + parsed_response = chunk + elif isinstance(chunk, (str, bytes)): + if isinstance(chunk, bytes): + chunk = chunk.decode("utf-8") + if 'generated_text' in chunk: + response = chunk.replace('data: ', '').strip() + parsed_response = json.loads(response) + else: + return {"text": "", "is_finished": False} + else: + print_verbose(f"chunk: {chunk} (Type: {type(chunk)})") + raise ValueError(f"Unable to parse response. 
Original response: {chunk}") + results = parsed_response.get("results", []) + if len(results) > 0: + text = results[0].get("generated_text", "") + finish_reason = results[0].get("stop_reason") is_finished = finish_reason != 'not_finished' return { "text": text, "is_finished": is_finished, "finish_reason": finish_reason, + "prompt_tokens": results[0].get("input_token_count", None), + "completion_tokens": results[0].get("generated_token_count", None), } - return "" + return {"text": "", "is_finished": False} except Exception as e: raise e @@ -9957,6 +9968,15 @@ class CustomStreamWrapper: response_obj = self.handle_watsonx_stream(chunk) completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") + if response_obj.get("prompt_tokens") is not None: + prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0) + model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"]) + if response_obj.get("completion_tokens") is not None: + model_response.usage.completion_tokens = response_obj["completion_tokens"] + model_response.usage.total_tokens = ( + getattr(model_response.usage, "prompt_tokens", 0) + + getattr(model_response.usage, "completion_tokens", 0) + ) if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "text-completion-openai":
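Usage notes (not part of the patch). A minimal sketch of how the new provider is expected to be called through litellm once this lands, assuming litellm's usual `<provider>/<model>` naming; the model id and credential values are placeholders, and the environment variables are the ones read by `IBMWatsonXAI._get_api_params`:

```python
import os
import litellm

# Credentials are resolved by IBMWatsonXAI._get_api_params; an API key is
# exchanged for an IAM bearer token via generate_iam_token() on the first call.
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"  # placeholder region
os.environ["WATSONX_API_KEY"] = "<ibm-cloud-api-key>"            # or set WATSONX_TOKEN
os.environ["WATSONX_PROJECT_ID"] = "<watsonx-project-id>"

# Non-streaming call (POST /ml/v1/text/generation under the hood).
response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # placeholder model id
    messages=[{"role": "user", "content": "What is IBM watsonx.ai?"}],
    max_tokens=100,
    temperature=0.2,
)
print(response.choices[0].message.content)

# Streaming call (POST /ml/v1/text/generation_stream); chunks are parsed by
# CustomStreamWrapper.handle_watsonx_stream.
for chunk in litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Write a haiku about rain."}],
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="")
```

Models deployed to a deployment space are addressed as `deployment/<deployment_id>` after the provider prefix and additionally require a `space_id` (passed as a parameter or via `WX_SPACE_ID`/`WATSONX_SPACE_ID`), per `_prepare_text_generation_req`. Other credentials (`url`, `token`, `project_id`, `wx_credentials`) can likewise be supplied per call; `_get_api_params` pops them from the optional parameters before falling back to the environment.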
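A similar sketch for the embedding route wired up in `litellm/main.py`; the embedding model id is again a placeholder:

```python
import litellm

embedding_response = litellm.embedding(
    model="watsonx/ibm/slate-30m-english-rtrvr",  # placeholder embedding model id
    input=["What is a token?", "What is a prompt?"],
)
# IBMWatsonXAI.embedding fills .data with OpenAI-style embedding objects and
# .usage with the input_token_count reported by the API.
for item in embedding_response.data:
    print(item)  # {"object": "embedding", "index": ..., "embedding": [...]}
print(embedding_response.usage)
```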
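For reference, a hand-rolled sketch of the REST request that `_prepare_text_generation_req` assembles for a project-scoped (non-deployment) model, assuming an IAM token has already been generated; all values are placeholders:

```python
import httpx

token = "<iam-access-token>"  # e.g. from IBMWatsonXAI().generate_iam_token(api_key)
base_url = "https://us-south.ml.cloud.ibm.com"

payload = {
    "input": "What is IBM watsonx.ai?",
    "moderations": {},
    "parameters": {"max_new_tokens": 100, "temperature": 0.2},
    "model_id": "ibm/granite-13b-chat-v2",   # placeholder
    "project_id": "<watsonx-project-id>",
}

resp = httpx.post(
    base_url + "/ml/v1/text/generation",
    params={"version": "2024-03-13"},        # IBMWatsonXAI.api_version
    headers={
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    },
    json=payload,
)
result = resp.json()["results"][0]
print(result["generated_text"])
print(result["stop_reason"], result["input_token_count"], result["generated_token_count"])
```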