From 6edb13373347dd74b5c9ede56c3cd37e0e4eab9c Mon Sep 17 00:00:00 2001 From: Simon Sanchez Viloria Date: Sat, 20 Apr 2024 19:56:20 +0200 Subject: [PATCH] Added support for IBM watsonx.ai models --- litellm/__init__.py | 7 + litellm/llms/prompt_templates/factory.py | 44 +++ litellm/llms/watsonx.py | 480 +++++++++++++++++++++++ litellm/main.py | 38 ++ litellm/utils.py | 69 ++++ 5 files changed, 638 insertions(+) create mode 100644 litellm/llms/watsonx.py diff --git a/litellm/__init__.py b/litellm/__init__.py index b9d9891ca..95dd33f1c 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -298,6 +298,7 @@ aleph_alpha_models: List = [] bedrock_models: List = [] deepinfra_models: List = [] perplexity_models: List = [] +watsonx_models: List = [] for key, value in model_cost.items(): if value.get("litellm_provider") == "openai": open_ai_chat_completion_models.append(key) @@ -342,6 +343,8 @@ for key, value in model_cost.items(): deepinfra_models.append(key) elif value.get("litellm_provider") == "perplexity": perplexity_models.append(key) + elif value.get("litellm_provider") == "watsonx": + watsonx_models.append(key) # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary openai_compatible_endpoints: List = [ @@ -478,6 +481,7 @@ model_list = ( + perplexity_models + maritalk_models + vertex_language_models + + watsonx_models ) provider_list: List = [ @@ -516,6 +520,7 @@ provider_list: List = [ "cloudflare", "xinference", "fireworks_ai", + "watsonx", "custom", # custom apis ] @@ -537,6 +542,7 @@ models_by_provider: dict = { "deepinfra": deepinfra_models, "perplexity": perplexity_models, "maritalk": maritalk_models, + "watsonx": watsonx_models, } # mapping for those models which have larger equivalents @@ -650,6 +656,7 @@ from .llms.bedrock import ( ) from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig from .llms.azure import AzureOpenAIConfig, AzureOpenAIError +from .llms.watsonx import IBMWatsonXConfig from .main import * # type: ignore from .integrations import * from .exceptions import ( diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 176c81d5d..8ebd2a38f 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -416,6 +416,32 @@ def format_prompt_togetherai(messages, prompt_format, chat_template): prompt = default_pt(messages) return prompt +### IBM Granite + +def ibm_granite_pt(messages: list): + """ + IBM's Granite models uses the template: + <|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message} + + See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models + """ + return custom_prompt( + messages=messages, + role_dict={ + 'system': { + 'pre_message': '<|system|>\n', + 'post_message': '\n', + }, + 'user': { + 'pre_message': '<|user|>\n', + 'post_message': '\n', + }, + 'assistant': { + 'pre_message': '<|assistant|>\n', + 'post_message': '\n', + } + } + ).strip() ### ANTHROPIC ### @@ -1327,6 +1353,24 @@ def prompt_factory( return messages elif custom_llm_provider == "azure_text": return azure_text_pt(messages=messages) + elif custom_llm_provider == "watsonx": + if "granite" in model and "chat" in model: + # granite-13b-chat-v1 and granite-13b-chat-v2 use a specific prompt template + return ibm_granite_pt(messages=messages) + elif "ibm-mistral" in model: + # models like ibm-mistral/mixtral-8x7b-instruct-v01-q use the mistral instruct 
prompt template
+            return mistral_instruct_pt(messages=messages)
+        elif "meta-llama/llama-3" in model and "instruct" in model:
+            return custom_prompt(
+                role_dict={
+                    "system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"},
+                    "user": {"pre_message": "<|start_header_id|>user<|end_header_id|>\n", "post_message": "<|eot_id|>"},
+                    "assistant": {"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", "post_message": "<|eot_id|>"},
+                },
+                messages=messages,
+                initial_prompt_value="<|begin_of_text|>",
+                # final_prompt_value="\n",
+            )
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py
new file mode 100644
index 000000000..7cb45730b
--- /dev/null
+++ b/litellm/llms/watsonx.py
@@ -0,0 +1,480 @@
+import json, types, time
+from typing import Callable, Optional, Any, Union, List
+
+import httpx
+import litellm
+from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
+
+from .prompt_templates import factory as ptf
+
+class WatsonxError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url="https://us-south.ml.cloud.ibm.com"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+class IBMWatsonXConfig:
+    """
+    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#deployments-text-generation
+    (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)
+
+    Supported params for all available watsonx.ai foundation models.
+
+    - `decoding_method` (str): One of "greedy" or "sample".
+
+    - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'.
+
+    - `max_new_tokens` (integer): Maximum length of the generated tokens.
+
+    - `min_new_tokens` (integer): Minimum length of the generated tokens.
+
+    - `stop_sequences` (string[]): List of strings to use as stop sequences.
+
+    - `time_limit` (integer): Time limit in milliseconds. If the generation is not completed within the time limit, the model will return the generated text up to that point.
+
+    - `top_p` (float): Top p for sampling - not available when decoding_method='greedy'.
+
+    - `top_k` (integer): Top k for sampling - not available when decoding_method='greedy'.
+
+    - `repetition_penalty` (float): Token repetition penalty during text generation.
+
+    - `stream` (bool): If True, the model will return a stream of responses.
+
+    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks".
+
+    - `truncate_input_tokens` (integer): Truncate input tokens to this length.
+
+    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".
+
+    - `random_seed` (integer): Random seed for text generation.
+
+    - `guardrails` (bool): Enable guardrails for harmful content.
+
+    - `guardrails_hap_params` (dict): Guardrails for harmful content.
+
+    - `guardrails_pii_params` (dict): Guardrails for Personally Identifiable Information.
+
+    - `concurrency_limit` (integer): Maximum number of concurrent requests.
+
+    - `async_mode` (bool): Enable async mode.
+
+    - `verify` (bool): Verify the SSL certificate of calls to the watsonx url.
+
+    - `validate` (bool): Validate the model_id at initialization.
+
+    - `model_inference` (ibm_watsonx_ai.ModelInference): An instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance.
+
+    - `watsonx_client` (ibm_watsonx_ai.APIClient): An instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with.
+    """
+    decoding_method: Optional[str] = "sample"  # 'sample' or 'greedy'. "sample" follows the default openai API behavior
+    temperature: Optional[float] = None
+    min_new_tokens: Optional[int] = None
+    max_new_tokens: Optional[int] = litellm.max_tokens
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    random_seed: Optional[int] = None  # e.g. 42
+    repetition_penalty: Optional[float] = None
+    stop_sequences: Optional[List[str]] = None  # e.g. ["}", ")", "."]
+    time_limit: Optional[int] = None  # e.g. 10000 (timeout in milliseconds)
+    return_options: Optional[dict] = None  # e.g. {"input_text": True, "generated_tokens": True, "input_tokens": True, "token_ranks": False}
+    truncate_input_tokens: Optional[int] = None  # e.g. 512
+    length_penalty: Optional[dict] = None  # e.g. {"decay_factor": 2.5, "start_index": 5}
+    stream: Optional[bool] = False
+    # other inference params
+    guardrails: Optional[bool] = False  # enable guardrails
+    guardrails_hap_params: Optional[dict] = None  # guardrails for harmful content
+    guardrails_pii_params: Optional[dict] = None  # guardrails for Personally Identifiable Information
+    concurrency_limit: Optional[int] = 10  # max number of concurrent requests
+    async_mode: Optional[bool] = False  # enable async mode
+    verify: Optional[Union[bool, str]] = None  # verify the SSL certificate of calls to the watsonx url
+    validate: Optional[bool] = False  # validate the model_id at initialization
+    model_inference: Optional[object] = None  # an instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance
+    watsonx_client: Optional[object] = None  # an instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with
+
+    def __init__(
+        self,
+        decoding_method: Optional[str] = None,
+        temperature: Optional[float] = None,
+        min_new_tokens: Optional[int] = None,
+        max_new_tokens: Optional[int] = litellm.max_tokens,  # defaults to litellm.max_tokens if not provided
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+        time_limit: Optional[int] = None,
+        return_options: Optional[dict] = None,
+        truncate_input_tokens: Optional[int] = None,
+        length_penalty: Optional[dict] = None,
+        stream: Optional[bool] = False,
+        guardrails: Optional[bool] = False,
+        guardrails_hap_params: Optional[dict] = None,
+        guardrails_pii_params: Optional[dict] = None,
+        concurrency_limit: Optional[int] = 10,
+        async_mode: Optional[bool] = False,
+        verify: Optional[Union[bool, str]] = None,
+        validate: Optional[bool] = False,
+        model_inference: Optional[object] = None,
+        watsonx_client: Optional[object] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
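+        """OpenAI-style params accepted for watsonx.ai; each entry maps to a native text-generation param (see the inline comments below)."""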
+        return [
+            "temperature",  # maps to watsonx `temperature`
+            "max_tokens",  # maps to watsonx `max_new_tokens`
+            "top_p",  # maps to watsonx `top_p`
+            "frequency_penalty",  # maps to watsonx `repetition_penalty`
+            "stop",  # maps to watsonx `stop_sequences`
+            "seed",  # maps to watsonx `random_seed`
+            "stream",  # maps to watsonx `stream`
+        ]
+
+
+def init_watsonx_model(
+    model_id: str,
+    url: Optional[str] = None,
+    api_key: Optional[str] = None,
+    project_id: Optional[str] = None,
+    space_id: Optional[str] = None,
+    wx_credentials: Optional[dict] = None,
+    region_name: Optional[str] = None,
+    verify: Optional[Union[bool, str]] = None,
+    validate: Optional[bool] = False,
+    watsonx_client: Optional[object] = None,
+    model_params: Optional[dict] = None,
+):
+    """
+    Initialize a watsonx.ai model for inference.
+
+    Args:
+
+        model_id (str): The model ID to use for inference. If this is a model deployed in a deployment space, the model_id should be in the format 'deployment/<deployment_id>' and the space_id of the deployment space should be provided.
+        url (str): The URL of the watsonx.ai instance.
+        api_key (str): The API key for the watsonx.ai instance.
+        project_id (str): The project ID for the watsonx.ai instance.
+        space_id (str): The space ID for the deployment space.
+        wx_credentials (dict): A dictionary containing 'apikey' and 'url' keys for the watsonx.ai instance.
+        region_name (str): The region name for the watsonx.ai instance (e.g. 'us-south').
+        verify (bool): Whether to verify the SSL certificate of calls to the watsonx url.
+        validate (bool): Whether to validate the model_id at initialization.
+        watsonx_client (object): An instance of the ibm_watsonx_ai.APIClient class. If this is provided, the model will be initialized using the provided client.
+        model_params (dict): A dictionary containing additional parameters to pass to the model (see IBMWatsonXConfig for a list of supported parameters).
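+
+        Example (illustrative sketch only - the model ID, URL, credentials and params below are placeholders):
+            model = init_watsonx_model(
+                model_id="ibm/granite-13b-chat-v2",
+                url="https://us-south.ml.cloud.ibm.com",
+                api_key="<ibm-cloud-api-key>",
+                project_id="<watsonx-project-id>",
+                model_params={"max_new_tokens": 100, "decoding_method": "greedy"},
+            )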
+    """
+
+    from ibm_watsonx_ai import APIClient
+    from ibm_watsonx_ai.foundation_models import ModelInference
+
+    if wx_credentials is not None:
+        if 'apikey' not in wx_credentials and 'api_key' in wx_credentials:
+            wx_credentials['apikey'] = wx_credentials.pop('api_key')
+        if 'apikey' not in wx_credentials:
+            raise WatsonxError(500, "Error: key 'apikey' expected in wx_credentials")
+
+    if url is None:
+        url = get_secret("WX_URL") or get_secret("WATSONX_URL") or get_secret("WML_URL")
+    if api_key is None:
+        api_key = get_secret("WX_API_KEY") or get_secret("WML_API_KEY")
+    if project_id is None:
+        project_id = get_secret("WX_PROJECT_ID") or get_secret("PROJECT_ID")
+    if region_name is None:
+        region_name = get_secret("WML_REGION_NAME") or get_secret("WX_REGION_NAME") or get_secret("REGION_NAME")
+    if space_id is None:
+        space_id = get_secret("WX_SPACE_ID") or get_secret("WML_DEPLOYMENT_SPACE_ID") or get_secret("SPACE_ID")
+
+    ## CHECK IF 'os.environ/' passed in
+    # Define the list of parameters to check
+    params_to_check = [url, api_key, project_id, space_id, region_name]
+    # Iterate over parameters and update if needed
+    for i, param in enumerate(params_to_check):
+        if param and param.startswith("os.environ/"):
+            params_to_check[i] = get_secret(param)
+    # Assign updated values back to parameters
+    url, api_key, project_id, space_id, region_name = params_to_check
+
+    ### SET WATSONX URL
+    if url is not None or watsonx_client is not None or wx_credentials is not None:
+        pass
+    elif region_name is not None:
+        url = f"https://{region_name}.ml.cloud.ibm.com"
+    else:
+        raise WatsonxError(
+            message="Watsonx URL not set: set WX_URL env variable or in .env file",
+            status_code=401,
+        )
+    if watsonx_client is not None and project_id is None:
+        project_id = watsonx_client.project_id
+
+    if model_id.startswith("deployment/"):
+        # deployment models are passed in as 'deployment/<deployment_id>'
+        assert space_id is not None, "space_id is required for deployment models"
+        deployment_id = '/'.join(model_id.split("/")[1:])
+        model_id = None
+    else:
+        deployment_id = None
+
+    if watsonx_client is not None:
+        model = ModelInference(
+            model_id=model_id,
+            params=model_params,
+            api_client=watsonx_client,
+            project_id=project_id,
+            deployment_id=deployment_id,
+            verify=verify,
+            validate=validate,
+            space_id=space_id,
+        )
+    elif wx_credentials is not None:
+        model = ModelInference(
+            model_id=model_id,
+            params=model_params,
+            credentials=wx_credentials,
+            project_id=project_id,
+            deployment_id=deployment_id,
+            verify=verify,
+            validate=validate,
+            space_id=space_id,
+        )
+    elif api_key is not None:
+        model = ModelInference(
+            model_id=model_id,
+            params=model_params,
+            credentials={
+                "apikey": api_key,
+                "url": url,
+            },
+            project_id=project_id,
+            deployment_id=deployment_id,
+            verify=verify,
+            validate=validate,
+            space_id=space_id,
+        )
+    else:
+        raise WatsonxError(500, "WatsonX credentials not passed or could not be found.")
+
+    return model
+
+
+def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
+    # convert the messages into a single prompt string, using a registered custom prompt template if one exists for this model
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_dict = custom_prompt_dict[model]
+        prompt = ptf.custom_prompt(
+            messages=messages,
+            role_dict=model_prompt_dict.get("role_dict", model_prompt_dict.get("roles")),
+            initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
bos_token=model_prompt_dict.get("bos_token", ""), + eos_token=model_prompt_dict.get("eos_token", ""), + ) + return prompt + elif provider == "ibm": + prompt = ptf.prompt_factory( + model=model, messages=messages, custom_llm_provider="watsonx" + ) + elif provider == "ibm-mistralai": + prompt = ptf.mistral_instruct_pt(messages=messages) + else: + prompt = ptf.prompt_factory(model=model, messages=messages, custom_llm_provider='watsonx') + return prompt + + +""" +IBM watsonx.ai AUTH Keys/Vars +os.environ['WX_URL'] = "" +os.environ['WX_API_KEY'] = "" +os.environ['WX_PROJECT_ID'] = "" +""" + +def completion( + model: str, + messages: list, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params:Optional[dict]=None, + litellm_params:Optional[dict]=None, + logger_fn=None, + timeout:float=None, +): + from ibm_watsonx_ai.foundation_models import Model, ModelInference + + try: + stream = optional_params.pop("stream", False) + extra_generate_params = dict( + guardrails=optional_params.pop("guardrails", False), + guardrails_hap_params=optional_params.pop("guardrails_hap_params", None), + guardrails_pii_params=optional_params.pop("guardrails_pii_params", None), + concurrency_limit=optional_params.pop("concurrency_limit", 10), + async_mode=optional_params.pop("async_mode", False), + ) + if timeout is not None and optional_params.get("time_limit") is None: + # the time_limit in watsonx.ai is in milliseconds (as opposed to OpenAI which is in seconds) + optional_params['time_limit'] = max(0, int(timeout*1000)) + extra_body_params = optional_params.pop("extra_body", {}) + optional_params.update(extra_body_params) + # LOAD CONFIG + config = IBMWatsonXConfig.get_config() + for k, v in config.items(): + if k not in optional_params: + optional_params[k] = v + + model_inference = optional_params.pop("model_inference", None) + if model_inference is None: + # INIT MODEL + model_client:ModelInference = init_watsonx_model( + model_id=model, + url=optional_params.pop("url", None), + api_key=optional_params.pop("api_key", None), + project_id=optional_params.pop("project_id", None), + space_id=optional_params.pop("space_id", None), + wx_credentials=optional_params.pop("wx_credentials", None), + region_name=optional_params.pop("region_name", None), + verify=optional_params.pop("verify", None), + validate=optional_params.pop("validate", False), + watsonx_client=optional_params.pop("watsonx_client", None), + model_params=optional_params, + ) + else: + model_client:ModelInference = model_inference + model = model_client.model_id + + # MAKE PROMPT + provider = model.split("/")[0] + model_name = '/'.join(model.split("/")[1:]) + prompt = convert_messages_to_prompt( + model, messages, provider, custom_prompt_dict + ) + ## COMPLETION CALL + if stream is True: + request_str = ( + "response = model.generate_text_stream(\n" + f"\tprompt={prompt},\n" + "\traw_response=True\n)" + ) + logging_obj.pre_call( + input=prompt, + api_key="", + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + # remove params that are not needed for streaming + del extra_generate_params["async_mode"] + del extra_generate_params["concurrency_limit"] + # make generate call + response = model_client.generate_text_stream( + prompt=prompt, + raw_response=True, + **extra_generate_params + ) + return litellm.CustomStreamWrapper( + response, + model=model, + custom_llm_provider="watsonx", + logging_obj=logging_obj, + ) + else: + try: + ## 
LOGGING + request_str = ( + "response = model.generate(\n" + f"\tprompt={prompt},\n" + "\traw_response=True\n)" + ) + logging_obj.pre_call( + input=prompt, + api_key="", + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + response = model_client.generate( + prompt=prompt, + **extra_generate_params + ) + except Exception as e: + raise WatsonxError(status_code=500, message=str(e)) + + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key="", + original_response=json.dumps(response), + additional_args={"complete_input_dict": optional_params}, + ) + print_verbose(f"raw model_response: {response}") + ## BUILD RESPONSE OBJECT + output_text = response['results'][0]['generated_text'] + + try: + if ( + len(output_text) > 0 + and hasattr(model_response.choices[0], "message") + ): + model_response["choices"][0]["message"]["content"] = output_text + model_response["finish_reason"] = response['results'][0]['stop_reason'] + prompt_tokens = response['results'][0]['input_token_count'] + completion_tokens = response['results'][0]['generated_token_count'] + else: + raise Exception() + except: + raise WatsonxError( + message=json.dumps(output_text), + status_code=500, + ) + model_response['created'] = int(time.time()) + model_response['model'] = model_name + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + model_response.usage = usage + return model_response + except WatsonxError as e: + raise e + except Exception as e: + raise WatsonxError(status_code=500, message=str(e)) + + +def embedding(): + # logic for parsing in - calling - parsing out model embedding calls + pass \ No newline at end of file diff --git a/litellm/main.py b/litellm/main.py index 65696b3c0..753193f96 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -63,6 +63,7 @@ from .llms import ( vertex_ai, vertex_ai_anthropic, maritalk, + watsonx, ) from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.azure import AzureChatCompletion @@ -1858,6 +1859,43 @@ def completion( ## RESPONSE OBJECT response = response + elif custom_llm_provider == "watsonx": + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + response = watsonx.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + timeout=timeout, + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + and not isinstance(response, CustomStreamWrapper) + ): + # don't try to access stream object, + response = CustomStreamWrapper( + iter(response), + model, + custom_llm_provider="watsonx", + logging_obj=logging, + ) + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=None, + original_response=response, + ) + ## RESPONSE OBJECT + response = response elif custom_llm_provider == "vllm": custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict model_response = vllm.completion( diff --git a/litellm/utils.py b/litellm/utils.py index e230675e6..19118acbe 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5331,6 +5331,45 @@ def get_optional_params( optional_params["extra_body"] = ( extra_body # openai client supports `extra_body` param ) + elif custom_llm_provider == "watsonx": + 
supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        if max_tokens is not None:
+            optional_params["max_new_tokens"] = max_tokens
+        if stream:
+            optional_params["stream"] = stream
+        if temperature is not None:
+            optional_params["temperature"] = temperature
+        if top_p is not None:
+            optional_params["top_p"] = top_p
+        if frequency_penalty is not None:
+            optional_params["repetition_penalty"] = frequency_penalty
+        if seed is not None:
+            optional_params["random_seed"] = seed
+        if stop is not None:
+            optional_params["stop_sequences"] = stop
+
+        # WatsonX-only parameters
+        extra_body = {}
+        if "decoding_method" in passed_params:
+            extra_body["decoding_method"] = passed_params.pop("decoding_method")
+        if "min_tokens" in passed_params:
+            extra_body["min_new_tokens"] = passed_params.pop("min_tokens")
+        elif "min_new_tokens" in passed_params:
+            extra_body["min_new_tokens"] = passed_params.pop("min_new_tokens")
+        if "top_k" in passed_params:
+            extra_body["top_k"] = passed_params.pop("top_k")
+        if "truncate_input_tokens" in passed_params:
+            extra_body["truncate_input_tokens"] = passed_params.pop("truncate_input_tokens")
+        if "length_penalty" in passed_params:
+            extra_body["length_penalty"] = passed_params.pop("length_penalty")
+        if "time_limit" in passed_params:
+            extra_body["time_limit"] = passed_params.pop("time_limit")
+        if "return_options" in passed_params:
+            extra_body["return_options"] = passed_params.pop("return_options")
+        optional_params["extra_body"] = (
+            extra_body  # openai client supports `extra_body` param
+        )
     else:  # assume passing in params for openai/azure openai
         print_verbose(
             f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
@@ -5688,6 +5727,8 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "frequency_penalty",
             "presence_penalty",
         ]
+    elif custom_llm_provider == "watsonx":
+        return litellm.IBMWatsonXConfig().get_supported_openai_params()


 def get_formatted_prompt(
@@ -5914,6 +5955,8 @@ def get_llm_provider(
                 model in litellm.bedrock_models
                 or model in litellm.bedrock_embedding_models
             ):
                 custom_llm_provider = "bedrock"
+            elif model in litellm.watsonx_models:
+                custom_llm_provider = "watsonx"
             # openai embeddings
             elif model in litellm.open_ai_embedding_models:
                 custom_llm_provider = "openai"
@@ -9590,6 +9633,26 @@ class CustomStreamWrapper:
                 "is_finished": chunk["is_finished"],
                 "finish_reason": finish_reason,
             }
+
+    def handle_watsonx_stream(self, chunk):
+        try:
+            if isinstance(chunk, dict):
+                pass
+            elif isinstance(chunk, str):
+                chunk = json.loads(chunk)
+            result = chunk.get("results", [])
+            if len(result) > 0:
+                text = result[0].get("generated_text", "")
+                finish_reason = result[0].get("stop_reason")
+                is_finished = finish_reason != 'not_finished'
+                return {
+                    "text": text,
+                    "is_finished": is_finished,
+                    "finish_reason": finish_reason,
+                }
+            # no results in this chunk - return an empty delta so the caller can keep streaming
+            return {"text": "", "is_finished": False, "finish_reason": ""}
+        except Exception as e:
+            raise e

     def model_response_creator(self):
         model_response = ModelResponse(stream=True, model=self.model)
@@ -9845,6 +9908,12 @@ class CustomStreamWrapper:
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider == "watsonx":
+                response_obj = self.handle_watsonx_stream(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "text-completion-openai": response_obj = self.handle_openai_text_completion_chunk(chunk) completion_obj["content"] = response_obj["text"]