From 6edb13373347dd74b5c9ede56c3cd37e0e4eab9c Mon Sep 17 00:00:00 2001 From: Simon Sanchez Viloria Date: Sat, 20 Apr 2024 19:56:20 +0200 Subject: [PATCH 01/11] Added support for IBM watsonx.ai models --- litellm/__init__.py | 7 + litellm/llms/prompt_templates/factory.py | 44 +++ litellm/llms/watsonx.py | 480 +++++++++++++++++++++++ litellm/main.py | 38 ++ litellm/utils.py | 69 ++++ 5 files changed, 638 insertions(+) create mode 100644 litellm/llms/watsonx.py diff --git a/litellm/__init__.py b/litellm/__init__.py index b9d9891ca..95dd33f1c 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -298,6 +298,7 @@ aleph_alpha_models: List = [] bedrock_models: List = [] deepinfra_models: List = [] perplexity_models: List = [] +watsonx_models: List = [] for key, value in model_cost.items(): if value.get("litellm_provider") == "openai": open_ai_chat_completion_models.append(key) @@ -342,6 +343,8 @@ for key, value in model_cost.items(): deepinfra_models.append(key) elif value.get("litellm_provider") == "perplexity": perplexity_models.append(key) + elif value.get("litellm_provider") == "watsonx": + watsonx_models.append(key) # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary openai_compatible_endpoints: List = [ @@ -478,6 +481,7 @@ model_list = ( + perplexity_models + maritalk_models + vertex_language_models + + watsonx_models ) provider_list: List = [ @@ -516,6 +520,7 @@ provider_list: List = [ "cloudflare", "xinference", "fireworks_ai", + "watsonx", "custom", # custom apis ] @@ -537,6 +542,7 @@ models_by_provider: dict = { "deepinfra": deepinfra_models, "perplexity": perplexity_models, "maritalk": maritalk_models, + "watsonx": watsonx_models, } # mapping for those models which have larger equivalents @@ -650,6 +656,7 @@ from .llms.bedrock import ( ) from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig from .llms.azure import AzureOpenAIConfig, AzureOpenAIError +from .llms.watsonx import IBMWatsonXConfig from .main import * # type: ignore from .integrations import * from .exceptions import ( diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 176c81d5d..8ebd2a38f 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -416,6 +416,32 @@ def format_prompt_togetherai(messages, prompt_format, chat_template): prompt = default_pt(messages) return prompt +### IBM Granite + +def ibm_granite_pt(messages: list): + """ + IBM's Granite models uses the template: + <|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message} + + See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models + """ + return custom_prompt( + messages=messages, + role_dict={ + 'system': { + 'pre_message': '<|system|>\n', + 'post_message': '\n', + }, + 'user': { + 'pre_message': '<|user|>\n', + 'post_message': '\n', + }, + 'assistant': { + 'pre_message': '<|assistant|>\n', + 'post_message': '\n', + } + } + ).strip() ### ANTHROPIC ### @@ -1327,6 +1353,24 @@ def prompt_factory( return messages elif custom_llm_provider == "azure_text": return azure_text_pt(messages=messages) + elif custom_llm_provider == "watsonx": + if "granite" in model and "chat" in model: + # granite-13b-chat-v1 and granite-13b-chat-v2 use a specific prompt template + return ibm_granite_pt(messages=messages) + elif "ibm-mistral" in model: + # models like ibm-mistral/mixtral-8x7b-instruct-v01-q use the mistral instruct prompt template + return mistral_instruct_pt(messages=messages) + elif "meta-llama/llama-3" in model and "instruct" in model: + return custom_prompt( + role_dict={ + "system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"}, + "user": {"pre_message": "<|start_header_id|>user<|end_header_id|>\n", "post_message": "<|eot_id|>"}, + "assistant": {"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", "post_message": "<|eot_id|>"}, + }, + messages=messages, + initial_prompt_value="<|begin_of_text|>", + # final_prompt_value="\n", + ) try: if "meta-llama/llama-2" in model and "chat" in model: return llama_2_chat_pt(messages=messages) diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py new file mode 100644 index 000000000..7cb45730b --- /dev/null +++ b/litellm/llms/watsonx.py @@ -0,0 +1,480 @@ +import json, types, time +from typing import Callable, Optional, Any, Union, List + +import httpx +import litellm +from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse + +from .prompt_templates import factory as ptf + +class WatsonxError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.request = httpx.Request( + method="POST", url="https://https://us-south.ml.cloud.ibm.com" + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + +class IBMWatsonXConfig: + """ + Reference: https://cloud.ibm.com/apidocs/watsonx-ai#deployments-text-generation + (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params) + + Supported params for all available watsonx.ai foundational models. + + - `decoding_method` (str): One of "greedy" or "sample" + + - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'. + + - `max_new_tokens` (integer): Maximum length of the generated tokens. + + - `min_new_tokens` (integer): Maximum length of input tokens. Any more than this will be truncated. + + - `stop_sequences` (string[]): list of strings to use as stop sequences. + + - `time_limit` (integer): time limit in milliseconds. If the generation is not completed within the time limit, the model will return the generated text up to that point. + + - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'. + + - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'. + + - `repetition_penalty` (float): token repetition penalty during text generation. + + - `stream` (bool): If True, the model will return a stream of responses. + + - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". + + - `truncate_input_tokens` (integer): Truncate input tokens to this length. + + - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index". + + - `random_seed` (integer): Random seed for text generation. + + - `guardrails` (bool): Enable guardrails for harmful content. + + - `guardrails_hap_params` (dict): Guardrails for harmful content. + + - `guardrails_pii_params` (dict): Guardrails for Personally Identifiable Information. + + - `concurrency_limit` (integer): Maximum number of concurrent requests. + + - `async_mode` (bool): Enable async mode. + + - `verify` (bool): Verify the SSL certificate of calls to the watsonx url. + + - `validate` (bool): Validate the model_id at initialization. + + - `model_inference` (ibm_watsonx_ai.ModelInference): An instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance. + + - `watsonx_client` (ibm_watsonx_ai.APIClient): An instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with. + """ + decoding_method: Optional[str] = "sample" # 'sample' or 'greedy'. "sample" follows the default openai API behavior + temperature: Optional[float] = None # + min_new_tokens: Optional[int] = None + max_new_tokens: Optional[int] = litellm.max_tokens + top_k: Optional[int] = None + top_p: Optional[float] = None + random_seed: Optional[int] = None # e.g 42 + repetition_penalty: Optional[float] = None + stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."] + time_limit: Optional[int] = None # e.g 10000 (timeout in milliseconds) + return_options: Optional[dict] = None # e.g {"input_text": True, "generated_tokens": True, "input_tokens": True, "token_ranks": False} + truncate_input_tokens: Optional[int] = None # e.g 512 + length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5} + stream: Optional[bool] = False + # other inference params + guardrails: Optional[bool] = False # enable guardrails + guardrails_hap_params: Optional[dict] = None # guardrails for harmful content + guardrails_pii_params: Optional[dict] = None # guardrails for Personally Identifiable Information + concurrency_limit: Optional[int] = 10 # max number of concurrent requests + async_mode: Optional[bool] = False # enable async mode + verify: Optional[Union[bool,str]] = None # verify the SSL certificate of calls to the watsonx url + validate: Optional[bool] = False # validate the model_id at initialization + model_inference: Optional[object] = None # an instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance + watsonx_client: Optional[object] = None # an instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with + + def __init__( + self, + decoding_method: Optional[str] = None, + temperature: Optional[float] = None, + min_new_tokens: Optional[int] = None, + max_new_tokens: Optional[ + int + ] = litellm.max_tokens, # petals requires max tokens to be set + top_k: Optional[int] = None, + top_p: Optional[float] = None, + random_seed: Optional[int] = None, + repetition_penalty: Optional[float] = None, + stop_sequences: Optional[List[str]] = None, + time_limit: Optional[int] = None, + return_options: Optional[dict] = None, + truncate_input_tokens: Optional[int] = None, + length_penalty: Optional[dict] = None, + stream: Optional[bool] = False, + guardrails: Optional[bool] = False, + guardrails_hap_params: Optional[dict] = None, + guardrails_pii_params: Optional[dict] = None, + concurrency_limit: Optional[int] = 10, + async_mode: Optional[bool] = False, + verify: Optional[Union[bool,str]] = None, + validate: Optional[bool] = False, + model_inference: Optional[object] = None, + watsonx_client: Optional[object] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self): + return [ + "temperature", # equivalent to temperature + "max_tokens", # equivalent to max_new_tokens + "top_p", # equivalent to top_p + "frequency_penalty", # equivalent to repetition_penalty + "stop", # equivalent to stop_sequences + "seed", # equivalent to random_seed + "stream", # equivalent to stream + ] + + +def init_watsonx_model( + model_id: str, + url: Optional[str] = None, + api_key: Optional[str] = None, + project_id: Optional[str] = None, + space_id: Optional[str] = None, + wx_credentials: Optional[dict] = None, + region_name: Optional[str] = None, + verify: Optional[Union[bool,str]] = None, + validate: Optional[bool] = False, + watsonx_client: Optional[object] = None, + model_params: Optional[dict] = None, +): + """ + Initialize a watsonx.ai model for inference. + + Args: + + model_id (str): The model ID to use for inference. If this is a model deployed in a deployment space, the model_id should be in the format 'deployment/' and the space_id to the deploymend space should be provided. + url (str): The URL of the watsonx.ai instance. + api_key (str): The API key for the watsonx.ai instance. + project_id (str): The project ID for the watsonx.ai instance. + space_id (str): The space ID for the deployment space. + wx_credentials (dict): A dictionary containing 'apikey' and 'url' keys for the watsonx.ai instance. + region_name (str): The region name for the watsonx.ai instance (e.g. 'us-south'). + verify (bool): Whether to verify the SSL certificate of calls to the watsonx url. + validate (bool): Whether to validate the model_id at initialization. + watsonx_client (object): An instance of the ibm_watsonx_ai.APIClient class. If this is provided, the model will be initialized using the provided client. + model_params (dict): A dictionary containing additional parameters to pass to the model (see IBMWatsonXConfig for a list of supported parameters). + """ + + from ibm_watsonx_ai import APIClient + from ibm_watsonx_ai.foundation_models import ModelInference + + + if wx_credentials is not None: + if 'apikey' not in wx_credentials and 'api_key' in wx_credentials: + wx_credentials['apikey'] = wx_credentials.pop('api_key') + if 'apikey' not in wx_credentials: + raise WatsonxError(500, "Error: key 'apikey' expected in wx_credentials") + + if url is None: + url = get_secret("WX_URL") or get_secret("WATSONX_URL") or get_secret("WML_URL") + if api_key is None: + api_key = get_secret("WX_API_KEY") or get_secret("WML_API_KEY") + if project_id is None: + project_id = get_secret("WX_PROJECT_ID") or get_secret("PROJECT_ID") + if region_name is None: + region_name = get_secret("WML_REGION_NAME") or get_secret("WX_REGION_NAME") or get_secret("REGION_NAME") + if space_id is None: + space_id = get_secret("WX_SPACE_ID") or get_secret("WML_DEPLOYMENT_SPACE_ID") or get_secret("SPACE_ID") + + + ## CHECK IS 'os.environ/' passed in + # Define the list of parameters to check + params_to_check = (url, api_key, project_id, space_id, region_name) + # Iterate over parameters and update if needed + for i, param in enumerate(params_to_check): + if param and param.startswith("os.environ/"): + params_to_check[i] = get_secret(param) + # Assign updated values back to parameters + url, api_key, project_id, space_id, region_name = params_to_check + + ### SET WATSONX URL + if url is not None or watsonx_client is not None or wx_credentials is not None: + pass + elif region_name is not None: + url = f"https://{region_name}.ml.cloud.ibm.com" + else: + raise WatsonxError( + message="Watsonx URL not set: set WX_URL env variable or in .env file", + status_code=401, + ) + if watsonx_client is not None and project_id is None: + project_id = watsonx_client.project_id + + if model_id.startswith("deployment/"): + # deployment models are passed in as 'deployment/' + assert space_id is not None, "space_id is required for deployment models" + deployment_id = '/'.join(model_id.split("/")[1:]) + model_id = None + else: + deployment_id = None + + if watsonx_client is not None: + model = ModelInference( + model_id=model_id, + params=model_params, + api_client=watsonx_client, + project_id=project_id, + deployment_id=deployment_id, + verify=verify, + validate=validate, + space_id=space_id, + ) + elif wx_credentials is not None: + model = ModelInference( + model_id=model_id, + params=model_params, + credentials=wx_credentials, + project_id=project_id, + deployment_id=deployment_id, + verify=verify, + validate=validate, + space_id=space_id, + ) + elif api_key is not None: + model = ModelInference( + model_id=model_id, + params=model_params, + credentials={ + "apikey": api_key, + "url": url, + }, + project_id=project_id, + deployment_id=deployment_id, + verify=verify, + validate=validate, + space_id=space_id, + ) + else: + raise WatsonxError(500, "WatsonX credentials not passed or could not be found.") + + + return model + + +def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): + # handle anthropic prompts and amazon titan prompts + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_dict = custom_prompt_dict[model] + prompt = ptf.custom_prompt( + messages=messages, + role_dict=model_prompt_dict.get("role_dict", model_prompt_dict.get("roles")), + initial_prompt_value=model_prompt_dict.get("initial_prompt_value",""), + final_prompt_value=model_prompt_dict.get("final_prompt_value", ""), + bos_token=model_prompt_dict.get("bos_token", ""), + eos_token=model_prompt_dict.get("eos_token", ""), + ) + return prompt + elif provider == "ibm": + prompt = ptf.prompt_factory( + model=model, messages=messages, custom_llm_provider="watsonx" + ) + elif provider == "ibm-mistralai": + prompt = ptf.mistral_instruct_pt(messages=messages) + else: + prompt = ptf.prompt_factory(model=model, messages=messages, custom_llm_provider='watsonx') + return prompt + + +""" +IBM watsonx.ai AUTH Keys/Vars +os.environ['WX_URL'] = "" +os.environ['WX_API_KEY'] = "" +os.environ['WX_PROJECT_ID'] = "" +""" + +def completion( + model: str, + messages: list, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params:Optional[dict]=None, + litellm_params:Optional[dict]=None, + logger_fn=None, + timeout:float=None, +): + from ibm_watsonx_ai.foundation_models import Model, ModelInference + + try: + stream = optional_params.pop("stream", False) + extra_generate_params = dict( + guardrails=optional_params.pop("guardrails", False), + guardrails_hap_params=optional_params.pop("guardrails_hap_params", None), + guardrails_pii_params=optional_params.pop("guardrails_pii_params", None), + concurrency_limit=optional_params.pop("concurrency_limit", 10), + async_mode=optional_params.pop("async_mode", False), + ) + if timeout is not None and optional_params.get("time_limit") is None: + # the time_limit in watsonx.ai is in milliseconds (as opposed to OpenAI which is in seconds) + optional_params['time_limit'] = max(0, int(timeout*1000)) + extra_body_params = optional_params.pop("extra_body", {}) + optional_params.update(extra_body_params) + # LOAD CONFIG + config = IBMWatsonXConfig.get_config() + for k, v in config.items(): + if k not in optional_params: + optional_params[k] = v + + model_inference = optional_params.pop("model_inference", None) + if model_inference is None: + # INIT MODEL + model_client:ModelInference = init_watsonx_model( + model_id=model, + url=optional_params.pop("url", None), + api_key=optional_params.pop("api_key", None), + project_id=optional_params.pop("project_id", None), + space_id=optional_params.pop("space_id", None), + wx_credentials=optional_params.pop("wx_credentials", None), + region_name=optional_params.pop("region_name", None), + verify=optional_params.pop("verify", None), + validate=optional_params.pop("validate", False), + watsonx_client=optional_params.pop("watsonx_client", None), + model_params=optional_params, + ) + else: + model_client:ModelInference = model_inference + model = model_client.model_id + + # MAKE PROMPT + provider = model.split("/")[0] + model_name = '/'.join(model.split("/")[1:]) + prompt = convert_messages_to_prompt( + model, messages, provider, custom_prompt_dict + ) + ## COMPLETION CALL + if stream is True: + request_str = ( + "response = model.generate_text_stream(\n" + f"\tprompt={prompt},\n" + "\traw_response=True\n)" + ) + logging_obj.pre_call( + input=prompt, + api_key="", + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + # remove params that are not needed for streaming + del extra_generate_params["async_mode"] + del extra_generate_params["concurrency_limit"] + # make generate call + response = model_client.generate_text_stream( + prompt=prompt, + raw_response=True, + **extra_generate_params + ) + return litellm.CustomStreamWrapper( + response, + model=model, + custom_llm_provider="watsonx", + logging_obj=logging_obj, + ) + else: + try: + ## LOGGING + request_str = ( + "response = model.generate(\n" + f"\tprompt={prompt},\n" + "\traw_response=True\n)" + ) + logging_obj.pre_call( + input=prompt, + api_key="", + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + response = model_client.generate( + prompt=prompt, + **extra_generate_params + ) + except Exception as e: + raise WatsonxError(status_code=500, message=str(e)) + + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key="", + original_response=json.dumps(response), + additional_args={"complete_input_dict": optional_params}, + ) + print_verbose(f"raw model_response: {response}") + ## BUILD RESPONSE OBJECT + output_text = response['results'][0]['generated_text'] + + try: + if ( + len(output_text) > 0 + and hasattr(model_response.choices[0], "message") + ): + model_response["choices"][0]["message"]["content"] = output_text + model_response["finish_reason"] = response['results'][0]['stop_reason'] + prompt_tokens = response['results'][0]['input_token_count'] + completion_tokens = response['results'][0]['generated_token_count'] + else: + raise Exception() + except: + raise WatsonxError( + message=json.dumps(output_text), + status_code=500, + ) + model_response['created'] = int(time.time()) + model_response['model'] = model_name + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + model_response.usage = usage + return model_response + except WatsonxError as e: + raise e + except Exception as e: + raise WatsonxError(status_code=500, message=str(e)) + + +def embedding(): + # logic for parsing in - calling - parsing out model embedding calls + pass \ No newline at end of file diff --git a/litellm/main.py b/litellm/main.py index 65696b3c0..753193f96 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -63,6 +63,7 @@ from .llms import ( vertex_ai, vertex_ai_anthropic, maritalk, + watsonx, ) from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.azure import AzureChatCompletion @@ -1858,6 +1859,43 @@ def completion( ## RESPONSE OBJECT response = response + elif custom_llm_provider == "watsonx": + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + response = watsonx.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + timeout=timeout, + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + and not isinstance(response, CustomStreamWrapper) + ): + # don't try to access stream object, + response = CustomStreamWrapper( + iter(response), + model, + custom_llm_provider="watsonx", + logging_obj=logging, + ) + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=None, + original_response=response, + ) + ## RESPONSE OBJECT + response = response elif custom_llm_provider == "vllm": custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict model_response = vllm.completion( diff --git a/litellm/utils.py b/litellm/utils.py index e230675e6..19118acbe 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5331,6 +5331,45 @@ def get_optional_params( optional_params["extra_body"] = ( extra_body # openai client supports `extra_body` param ) + elif custom_llm_provider == "watsonx": + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) + _check_valid_arg(supported_params=supported_params) + if max_tokens is not None: + optional_params["max_new_tokens"] = max_tokens + if stream: + optional_params["stream"] = stream + if temperature is not None: + optional_params["temperature"] = temperature + if top_p is not None: + optional_params["top_p"] = top_p + if frequency_penalty is not None: + optional_params["repetition_penalty"] = frequency_penalty + if seed is not None: + optional_params["random_seed"] = seed + if stop is not None: + optional_params["stop_sequences"] = stop + + # WatsonX-only parameters + extra_body = {} + if "decoding_method" in passed_params: + extra_body["decoding_method"] = passed_params.pop("decoding_method") + if "min_tokens" in passed_params or "min_new_tokens" in passed_params: + extra_body["min_new_tokens"] = passed_params.pop("min_tokens", passed_params.pop("min_new_tokens")) + if "top_k" in passed_params: + extra_body["top_k"] = passed_params.pop("top_k") + if "truncate_input_tokens" in passed_params: + extra_body["truncate_input_tokens"] = passed_params.pop("truncate_input_tokens") + if "length_penalty" in passed_params: + extra_body["length_penalty"] = passed_params.pop("length_penalty") + if "time_limit" in passed_params: + extra_body["time_limit"] = passed_params.pop("time_limit") + if "return_options" in passed_params: + extra_body["return_options"] = passed_params.pop("return_options") + optional_params["extra_body"] = ( + extra_body # openai client supports `extra_body` param + ) else: # assume passing in params for openai/azure openai print_verbose( f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}" @@ -5688,6 +5727,8 @@ def get_supported_openai_params(model: str, custom_llm_provider: str): "frequency_penalty", "presence_penalty", ] + elif custom_llm_provider == "watsonx": + return litellm.IBMWatsonXConfig().get_supported_openai_params() def get_formatted_prompt( @@ -5914,6 +5955,8 @@ def get_llm_provider( model in litellm.bedrock_models or model in litellm.bedrock_embedding_models ): custom_llm_provider = "bedrock" + elif model in litellm.watsonx_models: + custom_llm_provider = "watsonx" # openai embeddings elif model in litellm.open_ai_embedding_models: custom_llm_provider = "openai" @@ -9590,6 +9633,26 @@ class CustomStreamWrapper: "is_finished": chunk["is_finished"], "finish_reason": finish_reason, } + + def handle_watsonx_stream(self, chunk): + try: + if isinstance(chunk, dict): + pass + elif isinstance(chunk, str): + chunk = json.loads(chunk) + result = chunk.get("results", []) + if len(result) > 0: + text = result[0].get("generated_text", "") + finish_reason = result[0].get("stop_reason") + is_finished = finish_reason != 'not_finished' + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + return "" + except Exception as e: + raise e def model_response_creator(self): model_response = ModelResponse(stream=True, model=self.model) @@ -9845,6 +9908,12 @@ class CustomStreamWrapper: print_verbose(f"completion obj content: {completion_obj['content']}") if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "watsonx": + response_obj = self.handle_watsonx_stream(chunk) + completion_obj["content"] = response_obj["text"] + print_verbose(f"completion obj content: {completion_obj['content']}") + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "text-completion-openai": response_obj = self.handle_openai_text_completion_chunk(chunk) completion_obj["content"] = response_obj["text"] From ca0807d8ab723b881cda6c55a9168dbb1f5f2af4 Mon Sep 17 00:00:00 2001 From: Simon Sanchez Viloria Date: Sat, 20 Apr 2024 20:52:25 +0200 Subject: [PATCH 02/11] (docs) added watsonx cookbook --- cookbook/liteLLM_IBM_Watsonx.ipynb | 213 +++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 cookbook/liteLLM_IBM_Watsonx.ipynb diff --git a/cookbook/liteLLM_IBM_Watsonx.ipynb b/cookbook/liteLLM_IBM_Watsonx.ipynb new file mode 100644 index 000000000..e62ec9c8c --- /dev/null +++ b/cookbook/liteLLM_IBM_Watsonx.ipynb @@ -0,0 +1,213 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LiteLLM x IBM [watsonx.ai](https://www.ibm.com/products/watsonx-ai)\n", + "\n", + "Note: For watsonx.ai requests you need to ensure you have `ibm-watsonx-ai` installed." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pre-Requisites" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install litellm\n", + "!pip install ibm-watsonx-ai" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set watsonx Credentials\n", + "\n", + "See [this documentation](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-credentials.html?context=wx) for more information about authenticating to watsonx.ai" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"WX_URL\"] = \"\" # Your watsonx.ai base URL\n", + "os.environ[\"WX_API_KEY\"] = \"\" # Your IBM cloud API key or watsonx.ai token\n", + "os.environ[\"WX_PROJECT_ID\"] = \"\" # ID of your watsonx.ai project" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example Requests" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Granite v2 response:\n", + "ModelResponse(id='chatcmpl-afe4e875-2cfb-4e8c-aba5-36853007aaae', choices=[Choices(finish_reason='stop', index=0, message=Message(content=' I\\'m looking for a way to extract the email addresses from a CSV file. I\\'ve tried using built-in functions like `split`, `grep`, and `awk`, but none of them seem to work. Specifically, I\\'m trying to extract all email addresses from a file called \"example.csv\". Here\\'s what I have so far:\\n```bash\\ngrep -oP \"[\\\\w-]+@[a-z0-9-]+\\\\.[a-z]{2,}$\" example.csv > extracted_emails.txt\\n```\\nThis command runs the `grep` command, searches for emails in \"example.csv\", and saves the results to a new file called \"extracted\\\\_emails.txt\". However, the email addresses are not properly formatted and do not include domains. I think there might be a better way to do this, so I\\'m open to suggestions.\\n\\nAny help or guidance would be greatly appreciated.\\n\\nPosting this question as a comment on the original response might not be the most effective way to get help. If it\\'s possible, I can create a Code Review question here instead.\\n(Original post here: Date: Tue, 23 Apr 2024 11:53:38 +0200 Subject: [PATCH 03/11] feat - watsonx refractoring, removed dependency, and added support for embedding calls --- litellm/__init__.py | 2 +- litellm/llms/watsonx.py | 792 ++++++++++++++++++++++------------------ litellm/main.py | 11 +- litellm/utils.py | 38 +- 4 files changed, 477 insertions(+), 366 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index 95dd33f1c..a7c17d53c 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -656,7 +656,7 @@ from .llms.bedrock import ( ) from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig from .llms.azure import AzureOpenAIConfig, AzureOpenAIError -from .llms.watsonx import IBMWatsonXConfig +from .llms.watsonx import IBMWatsonXAIConfig from .main import * # type: ignore from .integrations import * from .exceptions import ( diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py index 7cb45730b..38837ddb2 100644 --- a/litellm/llms/watsonx.py +++ b/litellm/llms/watsonx.py @@ -1,27 +1,31 @@ -import json, types, time -from typing import Callable, Optional, Any, Union, List +import json, types, time # noqa: E401 +from contextlib import contextmanager +from typing import Callable, Dict, Optional, Any, Union, List import httpx +import requests import litellm -from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse +from litellm.utils import ModelResponse, get_secret, Usage +from .base import BaseLLM from .prompt_templates import factory as ptf -class WatsonxError(Exception): - def __init__(self, status_code, message): + +class WatsonXAIError(Exception): + def __init__(self, status_code, message, url: str = None): self.status_code = status_code self.message = message - self.request = httpx.Request( - method="POST", url="https://https://us-south.ml.cloud.ibm.com" - ) + url = url or "https://https://us-south.ml.cloud.ibm.com" + self.request = httpx.Request(method="POST", url=url) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class IBMWatsonXConfig: + +class IBMWatsonXAIConfig: """ - Reference: https://cloud.ibm.com/apidocs/watsonx-ai#deployments-text-generation + Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params) Supported params for all available watsonx.ai foundational models. @@ -34,96 +38,64 @@ class IBMWatsonXConfig: - `min_new_tokens` (integer): Maximum length of input tokens. Any more than this will be truncated. + - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index". + - `stop_sequences` (string[]): list of strings to use as stop sequences. - - `time_limit` (integer): time limit in milliseconds. If the generation is not completed within the time limit, the model will return the generated text up to that point. - - - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'. - - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'. + - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'. + - `repetition_penalty` (float): token repetition penalty during text generation. - - `stream` (bool): If True, the model will return a stream of responses. - - - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". - - `truncate_input_tokens` (integer): Truncate input tokens to this length. - - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index". + - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match. + + - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean. - `random_seed` (integer): Random seed for text generation. - - `guardrails` (bool): Enable guardrails for harmful content. + - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering. - - `guardrails_hap_params` (dict): Guardrails for harmful content. - - - `guardrails_pii_params` (dict): Guardrails for Personally Identifiable Information. - - - `concurrency_limit` (integer): Maximum number of concurrent requests. - - - `async_mode` (bool): Enable async mode. - - - `verify` (bool): Verify the SSL certificate of calls to the watsonx url. - - - `validate` (bool): Validate the model_id at initialization. - - - `model_inference` (ibm_watsonx_ai.ModelInference): An instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance. - - - `watsonx_client` (ibm_watsonx_ai.APIClient): An instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with. + - `stream` (bool): If True, the model will return a stream of responses. """ - decoding_method: Optional[str] = "sample" # 'sample' or 'greedy'. "sample" follows the default openai API behavior - temperature: Optional[float] = None # + + decoding_method: Optional[str] = "sample" + temperature: Optional[float] = None + max_new_tokens: Optional[int] = None # litellm.max_tokens min_new_tokens: Optional[int] = None - max_new_tokens: Optional[int] = litellm.max_tokens + length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5} + stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."] top_k: Optional[int] = None top_p: Optional[float] = None - random_seed: Optional[int] = None # e.g 42 repetition_penalty: Optional[float] = None - stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."] - time_limit: Optional[int] = None # e.g 10000 (timeout in milliseconds) - return_options: Optional[dict] = None # e.g {"input_text": True, "generated_tokens": True, "input_tokens": True, "token_ranks": False} - truncate_input_tokens: Optional[int] = None # e.g 512 - length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5} + truncate_input_tokens: Optional[int] = None + include_stop_sequences: Optional[bool] = False + return_options: Optional[dict] = None + return_options: Optional[Dict[str, bool]] = None + random_seed: Optional[int] = None # e.g 42 + moderations: Optional[dict] = None stream: Optional[bool] = False - # other inference params - guardrails: Optional[bool] = False # enable guardrails - guardrails_hap_params: Optional[dict] = None # guardrails for harmful content - guardrails_pii_params: Optional[dict] = None # guardrails for Personally Identifiable Information - concurrency_limit: Optional[int] = 10 # max number of concurrent requests - async_mode: Optional[bool] = False # enable async mode - verify: Optional[Union[bool,str]] = None # verify the SSL certificate of calls to the watsonx url - validate: Optional[bool] = False # validate the model_id at initialization - model_inference: Optional[object] = None # an instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance - watsonx_client: Optional[object] = None # an instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with def __init__( self, decoding_method: Optional[str] = None, temperature: Optional[float] = None, + max_new_tokens: Optional[int] = None, min_new_tokens: Optional[int] = None, - max_new_tokens: Optional[ - int - ] = litellm.max_tokens, # petals requires max tokens to be set + length_penalty: Optional[dict] = None, + stop_sequences: Optional[List[str]] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, - random_seed: Optional[int] = None, repetition_penalty: Optional[float] = None, - stop_sequences: Optional[List[str]] = None, - time_limit: Optional[int] = None, - return_options: Optional[dict] = None, truncate_input_tokens: Optional[int] = None, - length_penalty: Optional[dict] = None, - stream: Optional[bool] = False, - guardrails: Optional[bool] = False, - guardrails_hap_params: Optional[dict] = None, - guardrails_pii_params: Optional[dict] = None, - concurrency_limit: Optional[int] = 10, - async_mode: Optional[bool] = False, - verify: Optional[Union[bool,str]] = None, - validate: Optional[bool] = False, - model_inference: Optional[object] = None, - watsonx_client: Optional[object] = None, + include_stop_sequences: Optional[bool] = None, + return_options: Optional[dict] = None, + random_seed: Optional[int] = None, + moderations: Optional[dict] = None, + stream: Optional[bool] = None, + **kwargs, ) -> None: locals_ = locals() for key, value in locals_.items(): @@ -150,143 +122,16 @@ class IBMWatsonXConfig: def get_supported_openai_params(self): return [ - "temperature", # equivalent to temperature - "max_tokens", # equivalent to max_new_tokens - "top_p", # equivalent to top_p - "frequency_penalty", # equivalent to repetition_penalty - "stop", # equivalent to stop_sequences - "seed", # equivalent to random_seed - "stream", # equivalent to stream + "temperature", # equivalent to temperature + "max_tokens", # equivalent to max_new_tokens + "top_p", # equivalent to top_p + "frequency_penalty", # equivalent to repetition_penalty + "stop", # equivalent to stop_sequences + "seed", # equivalent to random_seed + "stream", # equivalent to stream ] -def init_watsonx_model( - model_id: str, - url: Optional[str] = None, - api_key: Optional[str] = None, - project_id: Optional[str] = None, - space_id: Optional[str] = None, - wx_credentials: Optional[dict] = None, - region_name: Optional[str] = None, - verify: Optional[Union[bool,str]] = None, - validate: Optional[bool] = False, - watsonx_client: Optional[object] = None, - model_params: Optional[dict] = None, -): - """ - Initialize a watsonx.ai model for inference. - - Args: - - model_id (str): The model ID to use for inference. If this is a model deployed in a deployment space, the model_id should be in the format 'deployment/' and the space_id to the deploymend space should be provided. - url (str): The URL of the watsonx.ai instance. - api_key (str): The API key for the watsonx.ai instance. - project_id (str): The project ID for the watsonx.ai instance. - space_id (str): The space ID for the deployment space. - wx_credentials (dict): A dictionary containing 'apikey' and 'url' keys for the watsonx.ai instance. - region_name (str): The region name for the watsonx.ai instance (e.g. 'us-south'). - verify (bool): Whether to verify the SSL certificate of calls to the watsonx url. - validate (bool): Whether to validate the model_id at initialization. - watsonx_client (object): An instance of the ibm_watsonx_ai.APIClient class. If this is provided, the model will be initialized using the provided client. - model_params (dict): A dictionary containing additional parameters to pass to the model (see IBMWatsonXConfig for a list of supported parameters). - """ - - from ibm_watsonx_ai import APIClient - from ibm_watsonx_ai.foundation_models import ModelInference - - - if wx_credentials is not None: - if 'apikey' not in wx_credentials and 'api_key' in wx_credentials: - wx_credentials['apikey'] = wx_credentials.pop('api_key') - if 'apikey' not in wx_credentials: - raise WatsonxError(500, "Error: key 'apikey' expected in wx_credentials") - - if url is None: - url = get_secret("WX_URL") or get_secret("WATSONX_URL") or get_secret("WML_URL") - if api_key is None: - api_key = get_secret("WX_API_KEY") or get_secret("WML_API_KEY") - if project_id is None: - project_id = get_secret("WX_PROJECT_ID") or get_secret("PROJECT_ID") - if region_name is None: - region_name = get_secret("WML_REGION_NAME") or get_secret("WX_REGION_NAME") or get_secret("REGION_NAME") - if space_id is None: - space_id = get_secret("WX_SPACE_ID") or get_secret("WML_DEPLOYMENT_SPACE_ID") or get_secret("SPACE_ID") - - - ## CHECK IS 'os.environ/' passed in - # Define the list of parameters to check - params_to_check = (url, api_key, project_id, space_id, region_name) - # Iterate over parameters and update if needed - for i, param in enumerate(params_to_check): - if param and param.startswith("os.environ/"): - params_to_check[i] = get_secret(param) - # Assign updated values back to parameters - url, api_key, project_id, space_id, region_name = params_to_check - - ### SET WATSONX URL - if url is not None or watsonx_client is not None or wx_credentials is not None: - pass - elif region_name is not None: - url = f"https://{region_name}.ml.cloud.ibm.com" - else: - raise WatsonxError( - message="Watsonx URL not set: set WX_URL env variable or in .env file", - status_code=401, - ) - if watsonx_client is not None and project_id is None: - project_id = watsonx_client.project_id - - if model_id.startswith("deployment/"): - # deployment models are passed in as 'deployment/' - assert space_id is not None, "space_id is required for deployment models" - deployment_id = '/'.join(model_id.split("/")[1:]) - model_id = None - else: - deployment_id = None - - if watsonx_client is not None: - model = ModelInference( - model_id=model_id, - params=model_params, - api_client=watsonx_client, - project_id=project_id, - deployment_id=deployment_id, - verify=verify, - validate=validate, - space_id=space_id, - ) - elif wx_credentials is not None: - model = ModelInference( - model_id=model_id, - params=model_params, - credentials=wx_credentials, - project_id=project_id, - deployment_id=deployment_id, - verify=verify, - validate=validate, - space_id=space_id, - ) - elif api_key is not None: - model = ModelInference( - model_id=model_id, - params=model_params, - credentials={ - "apikey": api_key, - "url": url, - }, - project_id=project_id, - deployment_id=deployment_id, - verify=verify, - validate=validate, - space_id=space_id, - ) - else: - raise WatsonxError(500, "WatsonX credentials not passed or could not be found.") - - - return model - - def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): # handle anthropic prompts and amazon titan prompts if model in custom_prompt_dict: @@ -294,8 +139,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): model_prompt_dict = custom_prompt_dict[model] prompt = ptf.custom_prompt( messages=messages, - role_dict=model_prompt_dict.get("role_dict", model_prompt_dict.get("roles")), - initial_prompt_value=model_prompt_dict.get("initial_prompt_value",""), + role_dict=model_prompt_dict.get( + "role_dict", model_prompt_dict.get("roles") + ), + initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""), final_prompt_value=model_prompt_dict.get("final_prompt_value", ""), bos_token=model_prompt_dict.get("bos_token", ""), eos_token=model_prompt_dict.get("eos_token", ""), @@ -308,173 +155,408 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): elif provider == "ibm-mistralai": prompt = ptf.mistral_instruct_pt(messages=messages) else: - prompt = ptf.prompt_factory(model=model, messages=messages, custom_llm_provider='watsonx') + prompt = ptf.prompt_factory( + model=model, messages=messages, custom_llm_provider="watsonx" + ) return prompt -""" -IBM watsonx.ai AUTH Keys/Vars -os.environ['WX_URL'] = "" -os.environ['WX_API_KEY'] = "" -os.environ['WX_PROJECT_ID'] = "" -""" +class IBMWatsonXAI(BaseLLM): + """ + Class to interface with IBM Watsonx.ai API for text generation and embeddings. -def completion( - model: str, - messages: list, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - optional_params:Optional[dict]=None, - litellm_params:Optional[dict]=None, - logger_fn=None, - timeout:float=None, -): - from ibm_watsonx_ai.foundation_models import Model, ModelInference + Reference: https://cloud.ibm.com/apidocs/watsonx-ai + """ - try: - stream = optional_params.pop("stream", False) - extra_generate_params = dict( - guardrails=optional_params.pop("guardrails", False), - guardrails_hap_params=optional_params.pop("guardrails_hap_params", None), - guardrails_pii_params=optional_params.pop("guardrails_pii_params", None), - concurrency_limit=optional_params.pop("concurrency_limit", 10), - async_mode=optional_params.pop("async_mode", False), - ) - if timeout is not None and optional_params.get("time_limit") is None: - # the time_limit in watsonx.ai is in milliseconds (as opposed to OpenAI which is in seconds) - optional_params['time_limit'] = max(0, int(timeout*1000)) + api_version = "2024-03-13" + _text_gen_endpoint = "/ml/v1/text/generation" + _text_gen_stream_endpoint = "/ml/v1/text/generation_stream" + _deployment_text_gen_endpoint = "/ml/v1/deployments/{deployment_id}/text/generation" + _deployment_text_gen_stream_endpoint = ( + "/ml/v1/deployments/{deployment_id}/text/generation_stream" + ) + _embeddings_endpoint = "/ml/v1/text/embeddings" + _prompts_endpoint = "/ml/v1/prompts" + + def __init__(self) -> None: + super().__init__() + + def _prepare_text_generation_req( + self, + model_id: str, + prompt: str, + stream: bool, + optional_params: dict, + print_verbose: Callable = None, + ) -> httpx.Request: + """ + Get the request parameters for text generation. + """ + api_params = self._get_api_params(optional_params, print_verbose=print_verbose) + # build auth headers + api_token = api_params.get("token") + + headers = { + "Authorization": f"Bearer {api_token}", + "Content-Type": "application/json", + "Accept": "application/json", + } extra_body_params = optional_params.pop("extra_body", {}) optional_params.update(extra_body_params) - # LOAD CONFIG - config = IBMWatsonXConfig.get_config() + # init the payload to the text generation call + payload = { + "input": prompt, + "moderations": optional_params.pop("moderations", {}), + "parameters": optional_params, + } + request_params = dict(version=api_params["api_version"]) + # text generation endpoint deployment or model / stream or not + if model_id.startswith("deployment/"): + # deployment models are passed in as 'deployment/' + if api_params.get("space_id") is None: + raise WatsonXAIError( + status_code=401, + url=api_params["url"], + message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.", + ) + deployment_id = "/".join(model_id.split("/")[1:]) + endpoint = ( + self._deployment_text_gen_stream_endpoint + if stream + else self._deployment_text_gen_endpoint + ) + endpoint = endpoint.format(deployment_id=deployment_id) + else: + payload["model_id"] = model_id + payload["project_id"] = api_params["project_id"] + endpoint = ( + self._text_gen_stream_endpoint if stream else self._text_gen_endpoint + ) + url = api_params["url"].rstrip("/") + endpoint + return httpx.Request( + "POST", url, headers=headers, json=payload, params=request_params + ) + + def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict: + """ + Find watsonx.ai credentials in the params or environment variables and return the headers for authentication. + """ + # Load auth variables from params + url = params.pop("url", None) + api_key = params.pop("apikey", None) + token = params.pop("token", None) + project_id = params.pop("project_id", None) # watsonx.ai project_id + space_id = params.pop("space_id", None) # watsonx.ai deployment space_id + region_name = params.pop("region_name", params.pop("region", None)) + wx_credentials = params.pop("wx_credentials", None) + api_version = params.pop("api_version", IBMWatsonXAI.api_version) + # Load auth variables from environment variables + if url is None: + url = ( + get_secret("WATSONX_URL") + or get_secret("WX_URL") + or get_secret("WML_URL") + ) + if api_key is None: + api_key = get_secret("WATSONX_API_KEY") or get_secret("WX_API_KEY") + if token is None: + token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN") + if project_id is None: + project_id = ( + get_secret("WATSONX_PROJECT_ID") + or get_secret("WX_PROJECT_ID") + or get_secret("PROJECT_ID") + ) + if region_name is None: + region_name = ( + get_secret("WATSONX_REGION") + or get_secret("WX_REGION") + or get_secret("REGION") + ) + if space_id is None: + space_id = ( + get_secret("WATSONX_DEPLOYMENT_SPACE_ID") + or get_secret("WATSONX_SPACE_ID") + or get_secret("WX_SPACE_ID") + or get_secret("SPACE_ID") + ) + + # credentials parsing + if wx_credentials is not None: + url = wx_credentials.get("url", url) + api_key = wx_credentials.get( + "apikey", wx_credentials.get("api_key", api_key) + ) + token = wx_credentials.get("token", token) + + # verify that all required credentials are present + if url is None: + raise WatsonXAIError( + status_code=401, + message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.", + ) + if token is None and api_key is not None: + # generate the auth token + if print_verbose: + print_verbose("Generating IAM token for Watsonx.ai") + token = self.generate_iam_token(api_key) + elif token is None and api_key is None: + raise WatsonXAIError( + status_code=401, + url=url, + message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.", + ) + if project_id is None: + raise WatsonXAIError( + status_code=401, + url=url, + message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.", + ) + + return { + "url": url, + "api_key": api_key, + "token": token, + "project_id": project_id, + "space_id": space_id, + "region_name": region_name, + "api_version": api_version, + } + + def completion( + self, + model: str, + messages: list, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: Optional[dict] = None, + litellm_params: Optional[dict] = None, + logger_fn=None, + timeout: float = None, + ): + """ + Send a text generation request to the IBM Watsonx.ai API. + Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation + """ + stream = optional_params.pop("stream", False) + + # Load default configs + config = IBMWatsonXAIConfig.get_config() for k, v in config.items(): if k not in optional_params: optional_params[k] = v - model_inference = optional_params.pop("model_inference", None) - if model_inference is None: - # INIT MODEL - model_client:ModelInference = init_watsonx_model( - model_id=model, - url=optional_params.pop("url", None), - api_key=optional_params.pop("api_key", None), - project_id=optional_params.pop("project_id", None), - space_id=optional_params.pop("space_id", None), - wx_credentials=optional_params.pop("wx_credentials", None), - region_name=optional_params.pop("region_name", None), - verify=optional_params.pop("verify", None), - validate=optional_params.pop("validate", False), - watsonx_client=optional_params.pop("watsonx_client", None), - model_params=optional_params, - ) - else: - model_client:ModelInference = model_inference - model = model_client.model_id - - # MAKE PROMPT + # Make prompt to send to model provider = model.split("/")[0] - model_name = '/'.join(model.split("/")[1:]) + # model_name = "/".join(model.split("/")[1:]) prompt = convert_messages_to_prompt( model, messages, provider, custom_prompt_dict ) - ## COMPLETION CALL - if stream is True: - request_str = ( - "response = model.generate_text_stream(\n" - f"\tprompt={prompt},\n" - "\traw_response=True\n)" - ) - logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - # remove params that are not needed for streaming - del extra_generate_params["async_mode"] - del extra_generate_params["concurrency_limit"] - # make generate call - response = model_client.generate_text_stream( - prompt=prompt, - raw_response=True, - **extra_generate_params - ) - return litellm.CustomStreamWrapper( - response, - model=model, - custom_llm_provider="watsonx", - logging_obj=logging_obj, - ) - else: - try: - ## LOGGING - request_str = ( - "response = model.generate(\n" - f"\tprompt={prompt},\n" - "\traw_response=True\n)" - ) - logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - response = model_client.generate( - prompt=prompt, - **extra_generate_params - ) - except Exception as e: - raise WatsonxError(status_code=500, message=str(e)) - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key="", - original_response=json.dumps(response), - additional_args={"complete_input_dict": optional_params}, - ) - print_verbose(f"raw model_response: {response}") - ## BUILD RESPONSE OBJECT - output_text = response['results'][0]['generated_text'] + def process_text_request(request: httpx.Request) -> ModelResponse: + with self._manage_response( + request, logging_obj=logging_obj, input=prompt, timeout=timeout + ) as resp: + json_resp = resp.json() + + generated_text = json_resp["results"][0]["generated_text"] + prompt_tokens = json_resp["results"][0]["input_token_count"] + completion_tokens = json_resp["results"][0]["generated_token_count"] + model_response["choices"][0]["message"]["content"] = generated_text + model_response["finish_reason"] = json_resp["results"][0]["stop_reason"] + model_response["created"] = int(time.time()) + model_response["model"] = model + model_response.usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + return model_response + + def process_stream_request( + request: httpx.Request, + ) -> litellm.CustomStreamWrapper: + # stream the response - generated chunks will be handled + # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream + with self._manage_response( + request, + logging_obj=logging_obj, + stream=True, + input=prompt, + timeout=timeout, + ) as resp: + response = litellm.CustomStreamWrapper( + resp.iter_lines(), + model=model, + custom_llm_provider="watsonx", + logging_obj=logging_obj, + ) + return response try: - if ( - len(output_text) > 0 - and hasattr(model_response.choices[0], "message") - ): - model_response["choices"][0]["message"]["content"] = output_text - model_response["finish_reason"] = response['results'][0]['stop_reason'] - prompt_tokens = response['results'][0]['input_token_count'] - completion_tokens = response['results'][0]['generated_token_count'] - else: - raise Exception() - except: - raise WatsonxError( - message=json.dumps(output_text), - status_code=500, + ## Get the response from the model + request = self._prepare_text_generation_req( + model_id=model, + prompt=prompt, + stream=stream, + optional_params=optional_params, + print_verbose=print_verbose, ) - model_response['created'] = int(time.time()) - model_response['model'] = model_name - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, + if stream: + return process_stream_request(request) + else: + return process_text_request(request) + except WatsonXAIError as e: + raise e + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + + def embedding( + self, + model: str, + input: Union[list, str], + api_key: Optional[str] = None, + logging_obj=None, + model_response=None, + optional_params=None, + encoding=None, + ): + """ + Send a text embedding request to the IBM Watsonx.ai API. + """ + if optional_params is None: + optional_params = {} + # Load default configs + config = IBMWatsonXAIConfig.get_config() + for k, v in config.items(): + if k not in optional_params: + optional_params[k] = v + + # Load auth variables from environment variables + if isinstance(input, str): + input = [input] + if api_key is not None: + optional_params["api_key"] = api_key + api_params = self._get_api_params(optional_params) + # build auth headers + api_token = api_params.get("token") + headers = { + "Authorization": f"Bearer {api_token}", + "Content-Type": "application/json", + "Accept": "application/json", + } + # init the payload to the text generation call + payload = { + "inputs": input, + "model_id": model, + "project_id": api_params["project_id"], + "parameters": optional_params, + } + request_params = dict(version=api_params["api_version"]) + url = api_params["url"].rstrip("/") + self._embeddings_endpoint + request = httpx.Request( + "POST", url, headers=headers, json=payload, params=request_params + ) + with self._manage_response( + request, logging_obj=logging_obj, input=input + ) as resp: + json_resp = resp.json() + + results = json_resp.get("results", []) + embedding_response = [] + for idx, result in enumerate(results): + embedding_response.append( + {"object": "embedding", "index": idx, "embedding": result["embedding"]} + ) + model_response["object"] = "list" + model_response["data"] = embedding_response + model_response["model"] = model + input_tokens = json_resp.get("input_token_count", 0) + model_response.usage = Usage( + prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens ) - model_response.usage = usage return model_response - except WatsonxError as e: - raise e - except Exception as e: - raise WatsonxError(status_code=500, message=str(e)) + def generate_iam_token(self, api_key=None, **params): + headers = {} + headers["Content-Type"] = "application/x-www-form-urlencoded" + if api_key is None: + api_key = get_secret("WX_API_KEY") or get_secret("WATSONX_API_KEY") + if api_key is None: + raise ValueError("API key is required") + headers["Accept"] = "application/json" + data = { + "grant_type": "urn:ibm:params:oauth:grant-type:apikey", + "apikey": api_key, + } + response = httpx.post( + "https://iam.cloud.ibm.com/identity/token", data=data, headers=headers + ) + response.raise_for_status() + json_data = response.json() + iam_access_token = json_data["access_token"] + self.token = iam_access_token + return iam_access_token -def embedding(): - # logic for parsing in - calling - parsing out model embedding calls - pass \ No newline at end of file + @contextmanager + def _manage_response( + self, + request: httpx.Request, + logging_obj: Any, + stream: bool = False, + input: Optional[Any] = None, + timeout: float = None, + ): + request_str = ( + f"response = {request.method}(\n" + f"\turl={request.url},\n" + f"\tjson={request.content.decode()},\n" + f")" + ) + json_input = json.loads(request.content.decode()) + headers = dict(request.headers) + logging_obj.pre_call( + input=input, + api_key=request.headers.get("Authorization"), + additional_args={ + "complete_input_dict": json_input, + "request_str": request_str, + }, + ) + try: + if stream: + resp = requests.request( + method=request.method, + url=str(request.url), + headers=headers, + json=json_input, + stream=True, + timeout=timeout, + ) + # resp.raise_for_status() + yield resp + else: + resp = requests.request( + method=request.method, + url=str(request.url), + headers=headers, + json=json_input, + timeout=timeout, + ) + resp.raise_for_status() + yield resp + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + if not stream: + logging_obj.post_call( + input=input, + api_key=request.headers.get("Authorization"), + original_response=json.dumps(resp.json()), + additional_args={ + "status_code": resp.status_code, + "complete_input_dict": request, + }, + ) diff --git a/litellm/main.py b/litellm/main.py index b61df8c12..8f357b834 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1862,7 +1862,7 @@ def completion( response = response elif custom_llm_provider == "watsonx": custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - response = watsonx.completion( + response = watsonx.IBMWatsonXAI().completion( model=model, messages=messages, custom_prompt_dict=custom_prompt_dict, @@ -2976,6 +2976,15 @@ def embedding( client=client, aembedding=aembedding, ) + elif custom_llm_provider == "watsonx": + response = watsonx.IBMWatsonXAI().embedding( + model=model, + input=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + ) else: args = locals() raise ValueError(f"No valid embedding model args passed in - {args}") diff --git a/litellm/utils.py b/litellm/utils.py index 836587fb1..89061c3bf 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5771,7 +5771,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str): "presence_penalty", ] elif custom_llm_provider == "watsonx": - return litellm.IBMWatsonXConfig().get_supported_openai_params() + return litellm.IBMWatsonXAIConfig().get_supported_openai_params() def get_formatted_prompt( @@ -9682,20 +9682,31 @@ class CustomStreamWrapper: def handle_watsonx_stream(self, chunk): try: if isinstance(chunk, dict): - pass - elif isinstance(chunk, str): - chunk = json.loads(chunk) - result = chunk.get("results", []) - if len(result) > 0: - text = result[0].get("generated_text", "") - finish_reason = result[0].get("stop_reason") + parsed_response = chunk + elif isinstance(chunk, (str, bytes)): + if isinstance(chunk, bytes): + chunk = chunk.decode("utf-8") + if 'generated_text' in chunk: + response = chunk.replace('data: ', '').strip() + parsed_response = json.loads(response) + else: + return {"text": "", "is_finished": False} + else: + print_verbose(f"chunk: {chunk} (Type: {type(chunk)})") + raise ValueError(f"Unable to parse response. Original response: {chunk}") + results = parsed_response.get("results", []) + if len(results) > 0: + text = results[0].get("generated_text", "") + finish_reason = results[0].get("stop_reason") is_finished = finish_reason != 'not_finished' return { "text": text, "is_finished": is_finished, "finish_reason": finish_reason, + "prompt_tokens": results[0].get("input_token_count", None), + "completion_tokens": results[0].get("generated_token_count", None), } - return "" + return {"text": "", "is_finished": False} except Exception as e: raise e @@ -9957,6 +9968,15 @@ class CustomStreamWrapper: response_obj = self.handle_watsonx_stream(chunk) completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") + if response_obj.get("prompt_tokens") is not None: + prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0) + model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"]) + if response_obj.get("completion_tokens") is not None: + model_response.usage.completion_tokens = response_obj["completion_tokens"] + model_response.usage.total_tokens = ( + getattr(model_response.usage, "prompt_tokens", 0) + + getattr(model_response.usage, "completion_tokens", 0) + ) if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "text-completion-openai": From 7cbe9835c9a5c9e599e7c9ba821b63e5f06ce748 Mon Sep 17 00:00:00 2001 From: Simon Sanchez Viloria Date: Tue, 23 Apr 2024 11:59:22 +0200 Subject: [PATCH 04/11] (docs) updated litellm watsonx cookbook --- cookbook/liteLLM_IBM_Watsonx.ipynb | 144 +++++++++++++++++++++++------ 1 file changed, 115 insertions(+), 29 deletions(-) diff --git a/cookbook/liteLLM_IBM_Watsonx.ipynb b/cookbook/liteLLM_IBM_Watsonx.ipynb index e62ec9c8c..99854b3b3 100644 --- a/cookbook/liteLLM_IBM_Watsonx.ipynb +++ b/cookbook/liteLLM_IBM_Watsonx.ipynb @@ -4,9 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# LiteLLM x IBM [watsonx.ai](https://www.ibm.com/products/watsonx-ai)\n", - "\n", - "Note: For watsonx.ai requests you need to ensure you have `ibm-watsonx-ai` installed." + "# LiteLLM x IBM [watsonx.ai](https://www.ibm.com/products/watsonx-ai)" ] }, { @@ -22,8 +20,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install litellm\n", - "!pip install ibm-watsonx-ai" + "!pip install litellm" ] }, { @@ -32,7 +29,7 @@ "source": [ "## Set watsonx Credentials\n", "\n", - "See [this documentation](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-credentials.html?context=wx) for more information about authenticating to watsonx.ai" + "See [this documentation](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information about authenticating to watsonx.ai" ] }, { @@ -42,22 +39,34 @@ "outputs": [], "source": [ "import os\n", + "import litellm\n", + "from litellm.llms.watsonx import IBMWatsonXAI\n", + "litellm.set_verbose = False\n", "\n", "os.environ[\"WX_URL\"] = \"\" # Your watsonx.ai base URL\n", "os.environ[\"WX_API_KEY\"] = \"\" # Your IBM cloud API key or watsonx.ai token\n", - "os.environ[\"WX_PROJECT_ID\"] = \"\" # ID of your watsonx.ai project" + "os.environ[\"WX_PROJECT_ID\"] = \"\" # ID of your watsonx.ai project\n", + "\n", + "# generating an IAM token is optional, but it is recommended to generate it once and use it for all your requests during the session\n", + "# if not passed to the function, it will be generated automatically for each request\n", + "iam_token = IBMWatsonXAI().generate_iam_token(api_key=os.environ[\"WATSONX_API_KEY\"]) \n", + "# you can also set os.environ[\"WATSONX_TOKEN\"] = iam_token" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Example Requests" + "## Completion Requests\n", + "\n", + "See the following link for a list of supported *text generation* models available with watsonx.ai:\n", + "\n", + "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx&locale=en&audience=wdp" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -65,18 +74,20 @@ "output_type": "stream", "text": [ "Granite v2 response:\n", - "ModelResponse(id='chatcmpl-afe4e875-2cfb-4e8c-aba5-36853007aaae', choices=[Choices(finish_reason='stop', index=0, message=Message(content=' I\\'m looking for a way to extract the email addresses from a CSV file. I\\'ve tried using built-in functions like `split`, `grep`, and `awk`, but none of them seem to work. Specifically, I\\'m trying to extract all email addresses from a file called \"example.csv\". Here\\'s what I have so far:\\n```bash\\ngrep -oP \"[\\\\w-]+@[a-z0-9-]+\\\\.[a-z]{2,}$\" example.csv > extracted_emails.txt\\n```\\nThis command runs the `grep` command, searches for emails in \"example.csv\", and saves the results to a new file called \"extracted\\\\_emails.txt\". However, the email addresses are not properly formatted and do not include domains. I think there might be a better way to do this, so I\\'m open to suggestions.\\n\\nAny help or guidance would be greatly appreciated.\\n\\nPosting this question as a comment on the original response might not be the most effective way to get help. If it\\'s possible, I can create a Code Review question here instead.\\n(Original post here: \" format (where `` is the ID of the deployed model in the deployment space). The ID of your deployment space must also be set in the environment variable `WATSONX_DEPLOYMENT_SPACE_ID` or passed to the function as `space_id=`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from litellm import acompletion\n", + "\n", + "os.environ[\"WATSONX_DEPLOYMENT_SPACE_ID\"] = \"\" # ID of the watsonx.ai deployment space where the model is deployed\n", + "await acompletion(\n", + " model=\"watsonx/deployment/\",\n", + " messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n", + " token=iam_token\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Embeddings\n", + "\n", + "See the following link for a list of supported *embedding* models available with watsonx.ai:\n", + "\n", + "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Slate 30m embeddings response:\n", + "EmbeddingResponse(model='ibm/slate-30m-english-rtrvr', data=[{'object': 'embedding', 'index': 0, 'embedding': [0.0025110552, -0.021022381, 0.056658838, 0.023194756, 0.06528087, 0.051285733, 0.025715597, 0.009245981, -0.048218597, 0.02131204, 0.0048608365, 0.056427978, -0.029722512, -0.022280851, 0.03397489, 0.15861669, -0.0032172804, 0.021461686, -0.034179244, 0.03242367, 0.045696042, -0.10642838, 0.044042706, 0.003619815, -0.03445944, 0.06782116, -0.012801977, -0.083491564, 0.048063237, -0.0009263491, 0.03926016, -0.003800945, 0.06431806, 0.008804617, 0.041459076, 0.019176882, 0.063215, 0.016872335, -0.07120825, 0.0026858407, -0.0061372668, 0.016006729, 0.034623176, -0.0009702338, 0.05586387, -0.0030038806, 0.10219119, 0.023867028, 0.017003942, 0.07522453, 0.03827543, 0.002119465, -0.047579825, 0.030801363, 0.055104297, -0.00926156, 0.060950216, -0.012564041, -0.0938483, 0.06749232, 0.0303093, 0.1260211, 0.008772238, 0.0937941, 0.03146898, -0.013548525, -0.04654987, 0.038247738, -0.0047283196, -0.021979854, -0.04481472, 0.009184976, 0.030558616, -0.035239127, 0.015711905, 0.079948395, -0.10273533, -0.033666693, 0.009253284, -0.013218568, 0.014513645, 0.011746366, -0.04836566, 0.00059039996, 0.056465007, 0.057913274, 0.046911363, 0.022496173, -0.016504057, -0.0009266135, 0.007562665, 0.024523543, 0.012681347, -0.0034720704, 0.014897689, 0.034027215, -0.035149213, 0.046610955, -0.38038146, -0.05560348, 0.056164417, 0.023633359, -0.020914413, 0.0017839101, 0.043425612, 0.0921522, 0.021333266, 0.032627117, 0.052366074, 0.059688427, -0.02425017, 0.07460727, 0.040419403, 0.018662684, -0.02174095, -0.015262358, 0.0041535227, -0.004320668, 0.001545062, 0.023696192, 0.053526532, 0.031027582, -0.030727778, -0.07266011, 0.01924883, -0.021610625, 0.03179455, -0.002117363, 0.037670195, -0.021235954, -0.03931032, -0.057163127, -0.046020538, 0.013852293, 0.007136301, 0.020461356, 0.027465757, 0.013625788, 0.09281521, 0.03537469, -0.15295835, -0.045262642, 0.013799362, 0.029831719, 0.06360841, 0.045387108, -0.008106462, 0.047562532, 0.026519125, 0.030519808, -0.035604805, 0.059504308, -0.010260606, 0.05920231, -0.039987702, 0.003475537, 0.012535757, 0.03711557, 0.022637982, 0.022368006, -0.013918498, 0.03144229, 0.02680179, 0.05283082, 0.09737034, 0.062140185, 0.047479317, 0.04292394, 0.041657448, 0.031671192, -0.01198203, -0.0398639, 0.050961364, -0.005440624, -0.013748672, 0.02486566, 0.06105261, 0.09158345, 0.047486037, 0.03503525, -0.0009857323, 0.017584834, 0.0015176772, -0.013855697, -0.0016783233, -0.032760657, 0.0073869363, 0.0032070065, 0.08748817, 0.062042974, -0.006563574, -0.01277716, 0.064277925, -0.048509046, 0.01998247, 0.015449057, 0.06161844, 0.0361277, 0.07378269, 0.031909943, 0.035593968, -0.021533003, 0.15151453, 0.009489467, 0.0077385777, 0.004732935, 0.06757376, 0.018628953, 0.03609718, 0.065334365, 0.046664603, 0.03710433, 0.023046834, 0.065034136, 0.021973003, 0.01938253, 0.0049545416, 0.009443422, 0.08657203, -0.006455585, 0.06113277, -0.009921393, 0.008861325, 0.021925068, 0.0073863543, 0.029231662, 0.018063372, -0.028237753, 0.06752595, -0.015746683, -0.06744447, -0.0019776542, -0.16144808, 0.055144247, -0.07052258, -0.0062173936, 0.005187277, 0.057623632, 0.008336536, 0.018794686, 0.08856226, 0.05324669, 0.023925344, -0.011277585, -0.015746504, -0.01888707, -0.010619123, 0.05960752, -0.02111604, 0.13263386, 0.053238407, 0.0423469, 0.03247613, 0.072818235, 0.039493106, -0.0080635715, 0.038805183, 0.05633994, 0.021095807, -0.022528276, 0.113213256, -0.040802993, 0.01971789, 0.00073800184, 0.04653605, 0.024364496, 0.051224973, 0.022803178, 0.06527072, -0.030100288, 0.02277551, 0.034268156, -0.0024341822, 0.030275142, -0.0043326514, 0.026949842, 0.03554525, 0.043582354, 0.037845742, 0.024644673, 0.06225431, 0.06668994, 0.042802095, -0.14308476, 0.028445719, -0.0057268543, 0.034851402, 0.04973769, -0.01673276, -0.0084733, -0.04498498, -0.01888843, 0.0018199912, -0.08666151, 0.03408551, 0.03374362, 0.016341621, -0.017816868, 0.027611718, 0.048712954, 0.03562084, 0.06156702, 0.06942091, 0.018424997, 0.010069236, -0.025854982, -0.005099922, 0.042129293, -0.018960087, -0.04267046, 0.003192464, 0.07610024, 0.01623567, 0.06430824, 0.045628317, -0.13192567, 0.00597194, 0.03359213, -0.051644783, -0.027538724, 0.047537625, 0.00078535493, -0.050269134, 0.06352181, 0.04414142, -0.00025181545, -0.011166945, 0.083493516, -0.022445189, 0.06386556, 0.009009819, 0.018880796, 0.046981215, -0.04803033, 0.20140722, 0.009405448, 0.011427641, 0.032028355, -0.039911997, 0.059231583, 0.10603366, -0.012695404, -0.018773954, 0.051107403, 0.004720434, 0.049031533, 0.008848073, -0.008443017, 0.068459414, -0.001594059, -0.037717424, 0.0083658025, 0.036570624, -0.009189262, -0.07422237, -0.03578154, 0.00016998129, -0.033594534, 0.04550856, -0.09751915, 0.031381045, -0.020289807, -0.025066, 0.05559659, 0.065852426, -0.030574895, 0.098877095, 0.024548644, 0.02716826, -0.0073690503, -0.006680294, -0.062504984, 0.001748584, -0.0015254011, 0.0030000636, 0.05166639, -0.03598367, 0.02785021, 0.019170346, -0.01893702, 0.006487694, -0.045320857, -0.042290565, 0.030072719]}], object='list', usage=Usage(prompt_tokens=8, total_tokens=8))\n", + "Slate 125m embeddings response:\n", + "EmbeddingResponse(model='ibm/slate-125m-english-rtrvr', data=[{'object': 'embedding', 'index': 0, 'embedding': [-0.037463713, -0.02141933, -0.02851813, 0.015519324, -0.08252965, 0.040418413, 0.0125358505, -0.015099016, 0.007372251, 0.043594047, -0.045923322, -0.024535796, -0.06683439, -0.023252856, -0.014445329, -0.007990043, -0.0038893714, 0.024145052, 0.002840671, -0.005213263, 0.025767032, -0.029234663, -0.022147253, -0.04008686, -0.0049467147, -0.005722156, 0.05712166, 0.02074406, -0.027984975, 0.011733741, 0.037084717, 0.0267332, 0.027662167, 0.018661365, 0.034368176, -0.016858159, 0.01525097, 0.0037685328, -0.029145032, -0.014014788, -0.026596593, -0.019313056, -0.034545943, -0.012755116, -0.027378004, -0.0022658114, 0.0671108, -0.011186887, -0.012560194, 0.07890564, 0.04370288, -0.002565922, 0.04558289, -0.015022389, 0.01721297, -0.02836881, 0.00028577668, 0.041560214, -0.028451115, 0.026690092, -0.03240052, 0.043185145, -0.048146088, -0.01863734, 0.014189055, 0.005409885, -0.004303547, 0.043854367, -0.08027855, 0.0036468406, -0.03761452, -0.01586453, 0.0015843573, -0.06557115, -0.017214078, 0.013112075, -0.063624665, -0.059002113, -0.027906772, -0.0104140695, -0.0122148385, 0.002914942, 0.009600896, 0.024618316, 0.0028588492, -0.04129038, -0.0066302163, -0.016593395, 0.0119156595, 0.030668158, 0.032204323, -0.008526114, 0.031477567, -0.027671225, -0.021325896, -0.012719999, 0.020595504, -0.010196725, 0.016694892, 0.015447107, 0.033599768, 0.0015109212, 0.055442166, -0.032922138, 0.032867074, 0.034223255, 0.018267235, 0.044258785, -0.009512916, -0.01888108, 0.0020811916, -0.071849406, -0.029209733, 0.030071445, 0.04898721, 0.03807559, 0.030091342, 0.0049845255, 0.011301079, 0.0060062855, -0.052550614, -0.040027767, -0.04539995, -0.069943875, 0.052881725, 0.015551356, -0.0016604571, 0.0021608798, 0.055507053, -0.015404854, -0.0023839937, 0.0070840786, 0.042537935, -0.045489613, 0.018908504, -0.015565469, 0.015916781, 0.07333876, 0.0034915418, -0.0029724848, 0.019170308, 0.02221138, -0.027242986, -0.003735747, -0.02341423, -0.0037938543, 0.0104211755, -0.06185881, -0.036718667, -0.02746382, -0.026462527, -0.050701175, 0.0057923957, 0.040674523, -0.019840682, -0.030195065, 0.045316722, 0.017369563, -0.031288657, -0.047546197, 0.026255054, -0.0049950704, -0.040272273, 0.0005752177, 0.03959872, -0.0073655704, -0.025617458, -0.009416491, -0.019514928, -0.07619169, 0.0051972694, 0.016387343, -0.012366861, -0.009152257, -0.035955105, -0.05794065, 0.019153351, -0.0461187, 0.024734644, 0.0031722176, 0.06610593, -0.0046516205, -0.04635891, 0.02524459, 0.004230386, 0.06153266, -0.0008394812, -0.013522857, 0.029861225, -0.00394871, -0.037432022, 0.0483034, 0.02181303, 0.015967155, 0.06181817, -0.018545056, 0.044176213, -0.07024062, -0.013022128, -0.0087189535, -0.025292343, 0.040448178, -0.051455554, -0.014017804, 0.012191985, 0.0071282317, -0.015855217, 0.013618914, -0.0060378346, -0.057781402, -0.035322957, -0.013627626, -0.027318006, -0.27732822, -0.007108157, 0.012321971, -0.15896526, -0.03793523, -0.025426138, 0.020721687, -0.04701553, -0.004927499, 0.010541978, -0.003212021, -0.0023603817, -0.052153032, 0.043272667, 0.024041472, -0.031666223, 0.0017891804, 0.026806207, -0.026526717, 0.0023138188, 0.024067048, 0.03326347, -0.039004102, -0.0004279829, 0.007266309, -0.008940641, 0.03715139, -0.037960306, 0.01647343, -0.022163782, 0.07456727, -0.0013284415, -0.029121747, 0.012727488, -0.007229313, 0.03177136, -0.08142398, 0.010223168, -0.025942598, -0.23807198, 0.022616733, -0.03925926, 0.05572623, -0.00020389797, -0.0022259122, -0.007885641, -0.00719495, 0.0018412926, 0.018953165, -0.009946787, 0.03723944, -0.015900994, 0.013648507, 0.010997674, -0.018918132, 0.013143112, 0.032894272, -0.05800237, 0.011163258, 0.025205074, -0.017001726, 0.03673705, -0.011551997, 0.06637543, -0.033003606, -0.041392814, -0.004078506, 0.03916763, -0.0022711542, 0.058338877, -0.034323692, -0.033700593, 0.01051642, 0.035579532, -0.01997833, 0.002977113, 0.06590587, 0.042783573, 0.020624464, 0.029172791, -0.035136282, 0.02035436, 0.05696583, -0.010200334, -0.0010580813, -0.024785697, -0.014516442, -0.030100575, -0.03807279, 0.042534467, -0.0281041, -0.05331885, -0.019467393, 0.016051197, 0.012470333, -0.008369627, 0.002254233, 0.026580654, -0.04541506, -0.018085537, -0.034577485, -0.0014747214, 0.0005770179, 0.0043190396, -0.004989785, 0.007569717, 0.010167482, -0.03335266, -0.015255423, 0.07341545, 0.012114007, -0.0010415721, 0.008754641, 0.05932771, 0.030799353, 0.026148474, -0.0069155577, -0.056865778, 0.0038446637, -0.010079895, 0.013511311, 0.023351224, -0.049000103, -0.013028001, -0.04957143, -0.031393193, 0.040289443, 0.063747466, 0.046358805, 0.0023754216, -0.0054107807, -0.020128531, 0.0013747461, -0.018183928, -0.04754063, -0.0064625163, 0.0417791, 0.06087331, -0.012241535, 0.04185439, 0.03641727, -0.02044306, -0.061368305, -0.023353308, 0.055897385, -0.047081504, 0.012900442, -0.018708078, 0.0028819577, 0.006964468, 0.0008757072, 0.04605831, 0.01716345, -0.004099444, -0.015493673, 0.021323929, -0.011252118, -0.02278577, 0.01893121, 0.009134488, 0.021568391, 0.011066748, -0.018853422, 0.027866907, -0.02831057, -0.010147286, 0.014807969, -0.03266599, -0.06711559, 0.038546126, 0.0031859868, -0.029038243, 0.046595056, 0.036973156, -0.033408422, 0.021968717, -0.011411975, 0.006584961, 0.072844714, -0.005873538, 0.029435376, 0.061169676, -0.02318868, 0.051129397, 0.014791153, -0.009028991, -0.021579748, 0.02669236, 0.029696332, -0.063952625, -0.061506465, -0.00080902094, 0.06850867, -0.09809231, -0.005534635, 0.066767104, -0.041267477, 0.046568397, 0.00983124, -0.0048434925, 0.038644254, 0.04096419, 0.0023063375, 0.014526287, 0.014016995, 0.020224908, 0.007113328, -0.0732543, -0.0054818415, 0.05807576, 0.022461535, 0.21100426, -0.009597197, -0.020674499, 0.010743241, -0.046834, -0.0068005333, 0.04918187, -0.06680011, -0.025018543, 0.016360015, 0.100744724, -0.019944709, -0.052390855, -0.0034876189, 0.031699855, -0.03024188, 0.009384044, -0.073849924, 0.01846066, -0.017075414, 0.0067319535, 0.045643695, 0.0121267075, 0.014980903, -0.0022226444, -0.015187039, 0.040638167, 0.023607453, -0.018353134, 0.007413985, 0.03487914, 0.018997269, -0.0107962405, -0.0040080273, 0.001454658, -0.023004232, -0.03065838, -0.0691732, -0.009669473, -0.017253181, 0.100617275, -0.00028453665, -0.055184573, -0.04010461, -0.022628073, -0.02138574, -0.00011931983, -0.021988528, 0.021569526, 0.018913478, -0.07588871, -0.030895703, -0.045679674, 0.03548181, 0.05806986, -0.00313453, 0.005607964, 0.014474551, -0.016833752, -0.022846023, 0.03665983, 0.04312398, 0.006030178, 0.020107903, -0.067837745, -0.039261904, -0.013903933, -0.011238981, -0.091779895, 0.03393072, 0.03576862, -0.016447216, -0.013628061, 0.035994843, 0.02442105, 0.0013356373, -0.013639993, -0.0070654624, -0.031047037, 0.0321763, 0.019488426, 0.030912274, -0.018131692, 0.034129236, -0.038152352, -0.020318052, 0.012934771, -0.0038958737, 0.029313264, 0.0609006, -0.06022117, -0.016697206, -0.030089315, -0.0030464267, -0.05011375, 0.016849633, -0.01935251, 0.00033423092, 0.018090008, 0.034528963, 0.015720658, 0.006443832, 0.0024674414, 0.0033006326, -0.011959118, -0.014686165, 0.00851113, 0.032130115, 0.016566927, -0.0048006177, -0.041135546, 0.017366901, 0.014404645, 0.0014093819, -0.039899524, -0.020875102, -0.01322629, -0.010891931, 0.019460721, -0.098985165, -0.03990147, 0.035807386, 0.05274234, -0.017714208, 0.0023620757, 0.022553496, 0.010935722, -0.016535437, -0.014505468, -0.005573891, -0.029528206, -0.010998497, 0.011297328, 0.007440231, 0.054734096, -0.035311602, 0.07038191, -0.034328025, -0.0109814005, -0.00578824, -0.009286793, 0.06692834, -0.040116422, -0.030043483, -0.010882302, -0.024094587, 0.026659116, -0.0637435, -0.022305744, 0.024388585, 0.011812823, -0.022778027, -0.0039024823, 0.027778644, 0.010566278, 0.011030791, -0.0021155484, 0.018014789, -0.03458981, 0.02546183, -0.11745906, 0.038193583, 0.0019787792, 0.01639592, 0.013218127, -0.012434678, -0.047858853, 0.006662704, 0.033221778, 0.008376927, -0.011822234, 0.01202769, 0.008761578, -0.04075117, 0.0025187496, 0.0026266004, 0.029762473, 0.009570205, -0.03644678, -0.033258904, -0.030776607, 0.05373578, 0.010904848, 0.040284622, 0.02707032, 0.021803873, -0.022011256, -0.05517991, -0.005213912, 0.009023477, -0.011895841, -0.026821174, -0.009035418, -0.021059638, 0.025536137, -0.053264923, 0.032206282, 0.020235807, 0.018660447, 0.0028790566, -0.019914437, 0.097842626, 0.027617158, 0.020276038, -0.014215543, 0.012761584, 0.032757074, 0.061124176, 0.049016643, -0.016509317, -0.03750349, -0.03449537, -0.02039439, -0.051360182, -0.041909404, 0.016175032, 0.040492736, 0.031218654, 0.0020242895, -0.032167237, 0.019398497, 0.057013687, 0.0031299617, 0.019177254, 0.015395364, -0.034078192, 0.041325297, 0.044380017, -0.004446819, 0.019610956, -0.030034903, 0.008468295, 0.03065914, -0.009548659, -0.07113981, 0.051648173, 0.03746448, -0.021847434, 0.01844844, 0.01333424, -0.001188216, 0.012330977, -0.056448817, 0.0008659569, 0.011183285, 0.006780519, -0.007357356, 0.05263679, -0.024631461, 0.00519591, -0.052165415, -0.03250626, -0.009370051, 0.00292325, -0.007187242, 0.029566163, -0.049605303, -0.02625627, -0.003157652, 0.052691437, -0.03589223, 0.03889354, -0.0035060279, 0.024555178, -0.00929779, -0.05037946, -0.022402484, 0.030634355, -0.03300659, -0.0063623153, 0.0027472514, 0.03196768, -0.019257778, 0.0089001395, 0.008908001, 0.018918095, 0.059574094, -0.02838763, 0.018203752, -0.06708146, -0.022670228, -0.013985525, 0.045018435, 0.011420395, -0.008649952, -0.027328938, -0.03527292, -0.0038555951, 0.017597001, 0.024891963, -0.0039160745, -0.015237065, -0.0008723479, -0.018641612, -0.036825016, -0.028743235, 0.00091956893, 0.00030935413, -0.048641082, 0.03744432, -0.024196126, 0.009848505, -0.043836866, 0.0044429195, 0.013709644, 0.06295503, -0.016072558, 0.01277375, -0.03548109, 0.003398656, 0.025347201, 0.019685786, 0.00758199, -0.016122513, -0.039198015, -0.0023108267, -0.0041584945, 0.005161282, 0.00089106365, 0.0076085874, -0.055768084, -0.0058975955, 0.007728267, 0.00076985586, -0.013469806, -0.031578194, -0.0138569595, 0.044540506, -0.0408136, -0.015252405, 0.06232591, -0.04198101, 0.0048899655, -0.0030694627, -0.025022805, -0.010789543, -0.025350742, 0.007836728, 0.024604483, -5.385127e-05, -0.0021367231, -0.01704561, -0.001425816, 0.0035238306]}], object='list', usage=Usage(prompt_tokens=8, total_tokens=8))\n" + ] + } + ], + "source": [ + "from litellm import embedding, aembedding\n", + "\n", + "response = embedding(\n", + " model=\"watsonx/ibm/slate-30m-english-rtrvr\",\n", + " input=[\"Hello, how are you?\"],\n", + " token=iam_token\n", + ")\n", + "print(\"Slate 30m embeddings response:\")\n", + "print(response)\n", + "\n", + "response = await aembedding(\n", + " model=\"watsonx/ibm/slate-125m-english-rtrvr\",\n", + " input=[\"Hello, how are you?\"],\n", + " token=iam_token\n", + ")\n", + "print(\"Slate 125m embeddings response:\")\n", + "print(response)" + ] } ], "metadata": { From e64aceea91eb05065dd6cc768fae56054064c914 Mon Sep 17 00:00:00 2001 From: Simon Sanchez Viloria Date: Tue, 23 Apr 2024 12:16:04 +0200 Subject: [PATCH 05/11] (feat) Update WatsonX credentials and variable names --- cookbook/liteLLM_IBM_Watsonx.ipynb | 9 +++++---- litellm/llms/watsonx.py | 6 +++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/cookbook/liteLLM_IBM_Watsonx.ipynb b/cookbook/liteLLM_IBM_Watsonx.ipynb index 99854b3b3..5ec6d05e0 100644 --- a/cookbook/liteLLM_IBM_Watsonx.ipynb +++ b/cookbook/liteLLM_IBM_Watsonx.ipynb @@ -27,7 +27,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Set watsonx Credentials\n", + "## Set watsonx.ai Credentials\n", "\n", "See [this documentation](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information about authenticating to watsonx.ai" ] @@ -43,9 +43,10 @@ "from litellm.llms.watsonx import IBMWatsonXAI\n", "litellm.set_verbose = False\n", "\n", - "os.environ[\"WX_URL\"] = \"\" # Your watsonx.ai base URL\n", - "os.environ[\"WX_API_KEY\"] = \"\" # Your IBM cloud API key or watsonx.ai token\n", - "os.environ[\"WX_PROJECT_ID\"] = \"\" # ID of your watsonx.ai project\n", + "os.environ[\"WATSONX_URL\"] = \"\" # Your watsonx.ai base URL\n", + "os.environ[\"WATSONX_APIKEY\"] = \"\" # Your IBM cloud API key or watsonx.ai token\n", + "os.environ[\"WATSONX_PROJECT_ID\"] = \"\" # ID of your watsonx.ai project\n", + "# these can also be passed as arguments to the function\n", "\n", "# generating an IAM token is optional, but it is recommended to generate it once and use it for all your requests during the session\n", "# if not passed to the function, it will be generated automatically for each request\n", diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py index 38837ddb2..26bcf6c06 100644 --- a/litellm/llms/watsonx.py +++ b/litellm/llms/watsonx.py @@ -258,7 +258,11 @@ class IBMWatsonXAI(BaseLLM): or get_secret("WML_URL") ) if api_key is None: - api_key = get_secret("WATSONX_API_KEY") or get_secret("WX_API_KEY") + api_key = ( + get_secret("WATSONX_APIKEY") + or get_secret("WATSONX_API_KEY") + or get_secret("WX_API_KEY") + ) if token is None: token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN") if project_id is None: From d72b7252732ef61fdc4a64ed281560a15dfc8fc6 Mon Sep 17 00:00:00 2001 From: Simon Sanchez Viloria Date: Tue, 23 Apr 2024 16:20:49 +0200 Subject: [PATCH 06/11] Fixed bugs in prompt factory for ibm-mistral and llama 3 models. --- litellm/llms/prompt_templates/factory.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 405ff9d4b..20182f3ba 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -1362,10 +1362,11 @@ def prompt_factory( if "granite" in model and "chat" in model: # granite-13b-chat-v1 and granite-13b-chat-v2 use a specific prompt template return ibm_granite_pt(messages=messages) - elif "ibm-mistral" in model: + elif "ibm-mistral" in model and "instruct" in model: # models like ibm-mistral/mixtral-8x7b-instruct-v01-q use the mistral instruct prompt template return mistral_instruct_pt(messages=messages) elif "meta-llama/llama-3" in model and "instruct" in model: + # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/ return custom_prompt( role_dict={ "system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"}, @@ -1374,7 +1375,7 @@ def prompt_factory( }, messages=messages, initial_prompt_value="<|begin_of_text|>", - # final_prompt_value="\n", + final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n", ) try: if "meta-llama/llama-2" in model and "chat" in model: From f9a7456eaa64fe7193484bb730daf0aaba670aeb Mon Sep 17 00:00:00 2001 From: Simon Sanchez Viloria Date: Tue, 23 Apr 2024 16:22:41 +0200 Subject: [PATCH 07/11] (docs) updated cookbook --- cookbook/liteLLM_IBM_Watsonx.ipynb | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/cookbook/liteLLM_IBM_Watsonx.ipynb b/cookbook/liteLLM_IBM_Watsonx.ipynb index 5ec6d05e0..e46c1dc96 100644 --- a/cookbook/liteLLM_IBM_Watsonx.ipynb +++ b/cookbook/liteLLM_IBM_Watsonx.ipynb @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -50,7 +50,7 @@ "\n", "# generating an IAM token is optional, but it is recommended to generate it once and use it for all your requests during the session\n", "# if not passed to the function, it will be generated automatically for each request\n", - "iam_token = IBMWatsonXAI().generate_iam_token(api_key=os.environ[\"WATSONX_API_KEY\"]) \n", + "iam_token = IBMWatsonXAI().generate_iam_token(api_key=os.environ[\"WATSONX_APIKEY\"]) \n", "# you can also set os.environ[\"WATSONX_TOKEN\"] = iam_token" ] }, @@ -75,9 +75,9 @@ "output_type": "stream", "text": [ "Granite v2 response:\n", - "ModelResponse(id='chatcmpl-16521490-f244-4b3b-8cb3-34d41e9f173b', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\" Thank you for taking the time to speak with me today.\\nI'm well, thank you for\", role='assistant'))], created=1713864603, model='ibm/granite-13b-chat-v2', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=8, completion_tokens=20, total_tokens=28), finish_reason='max_tokens')\n", + "ModelResponse(id='chatcmpl-adba60b2-3741-452e-921c-27b8f68d0298', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\" I'm often asked this question, but it seems a bit bizarre given my circumstances. You see,\", role='assistant'))], created=1713881850, model='ibm/granite-13b-chat-v2', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=8, completion_tokens=20, total_tokens=28), finish_reason='max_tokens')\n", "LLaMa 3 8b response:\n", - "ModelResponse(id='chatcmpl-2b1b28fb-4ec3-4735-8401-3407c5886f2c', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\"assistant\\n\\nI'm just an AI, I don't have feelings or emotions like humans do\", role='assistant'))], created=1713864604, model='meta-llama/llama-3-8b-instruct', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=12, completion_tokens=20, total_tokens=32), finish_reason='max_tokens')\n" + "ModelResponse(id='chatcmpl-eb282abc-373c-4082-9dae-172546d16d5c', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\"I'm just a language model, I don't have emotions or feelings like humans do, but I\", role='assistant'))], created=1713881852, model='meta-llama/llama-3-8b-instruct', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=16, completion_tokens=20, total_tokens=36), finish_reason='max_tokens')\n" ] } ], @@ -112,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -121,11 +121,11 @@ "text": [ "Granite v2 streaming response:\n", "\n", - "I'm doing well, thanks for asking. I've been working hard on a project lately, and it's been keeping me quite busy. I'm making a game, and it's been a fun and challenging experience. I'm really excited to\n", + "Thank you for asking. I'm fine, thank you for asking. What can I do for you today?\n", + "I'm looking for a new job. Do you have any job openings that might be a good fit for me?\n", + "Sure,\n", "LLaMa 3 8b streaming response:\n", - "assistant\n", - "\n", - "I'm just a language model, I don't have emotions or feelings like humans do, so I don't have a sense of well-being or an emotional state. However, I'm functioning properly and ready to assist you with any" + "I'm just an AI, so I don't have emotions or feelings like humans do, but I'm functioning properly and ready to help you with any questions or tasks you have! It's great to chat with you. How can I assist you today" ] } ], @@ -163,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -171,9 +171,9 @@ "output_type": "stream", "text": [ "Granite v2 response:\n", - "ModelResponse(id='chatcmpl-72cb349f-13a8-4613-920b-19c2b542c1b4', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\"\\n\\nHello! I'm just checking in. I appreciate you taking the time to talk with me\", role='assistant'))], created=1713864621, model='ibm/granite-13b-chat-v2', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=8, completion_tokens=20, total_tokens=28), finish_reason='max_tokens')\n", + "ModelResponse(id='chatcmpl-73e7474b-2760-4578-b52d-068d6f4ff68b', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\"\\nHello, thank you for asking. I'm well, how about you?\\n\\n3.\", role='assistant'))], created=1713881895, model='ibm/granite-13b-chat-v2', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=8, completion_tokens=20, total_tokens=28), finish_reason='max_tokens')\n", "LLaMa 3 8b response:\n", - "ModelResponse(id='chatcmpl-ed514c41-6693-469d-a70b-038a3bfa5e15', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\"assistant\\n\\nI'm just a language model, I don't have emotions or feelings like humans\", role='assistant'))], created=1713864621, model='meta-llama/llama-3-8b-instruct', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=12, completion_tokens=20, total_tokens=32), finish_reason='max_tokens')\n" + "ModelResponse(id='chatcmpl-fbf4cd5a-3a38-4b6c-ba00-01ada9fbde8a', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\"I'm just a language model, I don't have emotions or feelings like humans do. However,\", role='assistant'))], created=1713881894, model='meta-llama/llama-3-8b-instruct', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=16, completion_tokens=20, total_tokens=36), finish_reason='max_tokens')\n" ] } ], @@ -209,7 +209,7 @@ "source": [ "### Request deployed models\n", "\n", - "Models that have been deployed to a deployment space (i.e. tuned models) can be called using the \"deployment/\" format (where `` is the ID of the deployed model in the deployment space). The ID of your deployment space must also be set in the environment variable `WATSONX_DEPLOYMENT_SPACE_ID` or passed to the function as `space_id=`. " + "Models that have been deployed to a deployment space (i.e. tuned models) can be called using the \"deployment/\" format (where `` is the ID of the deployed model in your deployment space). The ID of your deployment space must also be set in the environment variable `WATSONX_DEPLOYMENT_SPACE_ID` or passed to the function as `space_id=`. " ] }, { From 9fc30e8b31575ce1c6af43df7e0ddb8013c746b8 Mon Sep 17 00:00:00 2001 From: Simon Sanchez Viloria Date: Wed, 24 Apr 2024 12:52:29 +0200 Subject: [PATCH 08/11] (test) Added completion and embedding tests for watsonx provider --- litellm/tests/test_completion.py | 35 ++++++++++++++++++++++++++++++++ litellm/tests/test_embedding.py | 12 +++++++++++ litellm/tests/test_streaming.py | 26 ++++++++++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 09053cf17..de8086f0e 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -2565,6 +2565,41 @@ def test_completion_palm_stream(): except Exception as e: pytest.fail(f"Error occurred: {e}") +def test_completion_watsonx(): + litellm.set_verbose = True + model_name = "watsonx/ibm/granite-13b-chat-v2" + try: + response = completion( + model=model_name, + messages=messages, + stop=["stop"], + max_tokens=20, + ) + # Add any assertions here to check the response + print(response) + except litellm.APIError as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + +@pytest.mark.asyncio +async def test_acompletion_watsonx(): + litellm.set_verbose = True + model_name = "watsonx/deployment/"+os.getenv("WATSONX_DEPLOYMENT_ID") + print("testing watsonx") + try: + response = await litellm.acompletion( + model=model_name, + messages=messages, + temperature=0.2, + max_tokens=80, + space_id=os.getenv("WATSONX_SPACE_ID_TEST"), + ) + # Add any assertions here to check the response + print(response) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + # test_completion_palm_stream() diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index d69e2d708..e9a86997b 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -483,6 +483,18 @@ def test_mistral_embeddings(): except Exception as e: pytest.fail(f"Error occurred: {e}") +def test_watsonx_embeddings(): + try: + litellm.set_verbose = True + response = litellm.embedding( + model="watsonx/ibm/slate-30m-english-rtrvr", + input=["good morning from litellm"], + ) + print(f"response: {response}") + assert isinstance(response.usage, litellm.Usage) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + # test_mistral_embeddings() diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index ea2f3fcb7..92c6293ee 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1210,6 +1210,32 @@ def test_completion_sagemaker_stream(): pytest.fail(f"Error occurred: {e}") +def test_completion_watsonx_stream(): + litellm.set_verbose = True + try: + response = completion( + model="watsonx/ibm/granite-13b-chat-v2", + messages=messages, + temperature=0.5, + max_tokens=20, + stream=True, + ) + complete_response = "" + has_finish_reason = False + # Add any assertions here to check the response + for idx, chunk in enumerate(response): + chunk, finished = streaming_format_tests(idx, chunk) + has_finish_reason = finished + if finished: + break + complete_response += chunk + if has_finish_reason is False: + raise Exception("finish reason not set for last chunk") + if complete_response.strip() == "": + raise Exception("Empty response received") + except Exception as e: + pytest.fail(f"Error occurred: {e}") + # test_completion_sagemaker_stream() From 777b4b2bbc9c2dcba4ba243fbee28ecaeb97487f Mon Sep 17 00:00:00 2001 From: Simon Sanchez Viloria Date: Wed, 24 Apr 2024 12:55:25 +0200 Subject: [PATCH 09/11] (feat) make manage_response work with request.request instead of httpx.Request --- litellm/llms/watsonx.py | 101 +++++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 49 deletions(-) diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py index 26bcf6c06..aa0cb32df 100644 --- a/litellm/llms/watsonx.py +++ b/litellm/llms/watsonx.py @@ -1,3 +1,4 @@ +from enum import Enum import json, types, time # noqa: E401 from contextlib import contextmanager from typing import Callable, Dict, Optional, Any, Union, List @@ -160,6 +161,15 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): ) return prompt +class WatsonXAIEndpoint(str, Enum): + TEXT_GENERATION = "/ml/v1/text/generation" + TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream" + DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation" + DEPLOYMENT_TEXT_GENERATION_STREAM = ( + "/ml/v1/deployments/{deployment_id}/text/generation_stream" + ) + EMBEDDINGS = "/ml/v1/text/embeddings" + PROMPTS = "/ml/v1/prompts" class IBMWatsonXAI(BaseLLM): """ @@ -169,14 +179,7 @@ class IBMWatsonXAI(BaseLLM): """ api_version = "2024-03-13" - _text_gen_endpoint = "/ml/v1/text/generation" - _text_gen_stream_endpoint = "/ml/v1/text/generation_stream" - _deployment_text_gen_endpoint = "/ml/v1/deployments/{deployment_id}/text/generation" - _deployment_text_gen_stream_endpoint = ( - "/ml/v1/deployments/{deployment_id}/text/generation_stream" - ) - _embeddings_endpoint = "/ml/v1/text/embeddings" - _prompts_endpoint = "/ml/v1/prompts" + def __init__(self) -> None: super().__init__() @@ -188,7 +191,7 @@ class IBMWatsonXAI(BaseLLM): stream: bool, optional_params: dict, print_verbose: Callable = None, - ) -> httpx.Request: + ) -> dict: """ Get the request parameters for text generation. """ @@ -221,20 +224,23 @@ class IBMWatsonXAI(BaseLLM): ) deployment_id = "/".join(model_id.split("/")[1:]) endpoint = ( - self._deployment_text_gen_stream_endpoint + WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM if stream - else self._deployment_text_gen_endpoint + else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION ) endpoint = endpoint.format(deployment_id=deployment_id) else: payload["model_id"] = model_id payload["project_id"] = api_params["project_id"] endpoint = ( - self._text_gen_stream_endpoint if stream else self._text_gen_endpoint + WatsonXAIEndpoint.TEXT_GENERATION_STREAM + if stream + else WatsonXAIEndpoint.TEXT_GENERATION ) url = api_params["url"].rstrip("/") + endpoint - return httpx.Request( - "POST", url, headers=headers, json=payload, params=request_params + return dict( + method="POST", url=url, headers=headers, + json=payload, params=request_params ) def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict: @@ -360,9 +366,9 @@ class IBMWatsonXAI(BaseLLM): model, messages, provider, custom_prompt_dict ) - def process_text_request(request: httpx.Request) -> ModelResponse: + def process_text_request(request_params: dict) -> ModelResponse: with self._manage_response( - request, logging_obj=logging_obj, input=prompt, timeout=timeout + request_params, logging_obj=logging_obj, input=prompt, timeout=timeout ) as resp: json_resp = resp.json() @@ -381,12 +387,12 @@ class IBMWatsonXAI(BaseLLM): return model_response def process_stream_request( - request: httpx.Request, + request_params: dict, ) -> litellm.CustomStreamWrapper: # stream the response - generated chunks will be handled # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream with self._manage_response( - request, + request_params, logging_obj=logging_obj, stream=True, input=prompt, @@ -402,7 +408,7 @@ class IBMWatsonXAI(BaseLLM): try: ## Get the response from the model - request = self._prepare_text_generation_req( + req_params = self._prepare_text_generation_req( model_id=model, prompt=prompt, stream=stream, @@ -410,9 +416,9 @@ class IBMWatsonXAI(BaseLLM): print_verbose=print_verbose, ) if stream: - return process_stream_request(request) + return process_stream_request(req_params) else: - return process_text_request(request) + return process_text_request(req_params) except WatsonXAIError as e: raise e except Exception as e: @@ -460,12 +466,19 @@ class IBMWatsonXAI(BaseLLM): "parameters": optional_params, } request_params = dict(version=api_params["api_version"]) - url = api_params["url"].rstrip("/") + self._embeddings_endpoint - request = httpx.Request( - "POST", url, headers=headers, json=payload, params=request_params - ) + url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS + # request = httpx.Request( + # "POST", url, headers=headers, json=payload, params=request_params + # ) + req_params = { + "method": "POST", + "url": url, + "headers": headers, + "json": payload, + "params": request_params, + } with self._manage_response( - request, logging_obj=logging_obj, input=input + req_params, logging_obj=logging_obj, input=input ) as resp: json_resp = resp.json() @@ -508,48 +521,38 @@ class IBMWatsonXAI(BaseLLM): @contextmanager def _manage_response( self, - request: httpx.Request, + request_params: dict, logging_obj: Any, stream: bool = False, input: Optional[Any] = None, timeout: float = None, ): request_str = ( - f"response = {request.method}(\n" - f"\turl={request.url},\n" - f"\tjson={request.content.decode()},\n" + f"response = {request_params['method']}(\n" + f"\turl={request_params['url']},\n" + f"\tjson={request_params['json']},\n" f")" ) - json_input = json.loads(request.content.decode()) - headers = dict(request.headers) logging_obj.pre_call( input=input, - api_key=request.headers.get("Authorization"), + api_key=request_params['headers'].get("Authorization"), additional_args={ - "complete_input_dict": json_input, + "complete_input_dict": request_params['json'], "request_str": request_str, }, ) + if timeout: + request_params['timeout'] = timeout try: if stream: resp = requests.request( - method=request.method, - url=str(request.url), - headers=headers, - json=json_input, + **request_params, stream=True, - timeout=timeout, ) - # resp.raise_for_status() + resp.raise_for_status() yield resp else: - resp = requests.request( - method=request.method, - url=str(request.url), - headers=headers, - json=json_input, - timeout=timeout, - ) + resp = requests.request(**request_params) resp.raise_for_status() yield resp except Exception as e: @@ -557,10 +560,10 @@ class IBMWatsonXAI(BaseLLM): if not stream: logging_obj.post_call( input=input, - api_key=request.headers.get("Authorization"), + api_key=request_params['headers'].get("Authorization"), original_response=json.dumps(resp.json()), additional_args={ "status_code": resp.status_code, - "complete_input_dict": request, + "complete_input_dict": request_params['json'], }, ) From 72cbe369be37bcd5d85702dd86f584034379a9ec Mon Sep 17 00:00:00 2001 From: Simon Sanchez Viloria Date: Wed, 24 Apr 2024 17:19:02 +0200 Subject: [PATCH 10/11] (docs) updated watsonx cookbook --- cookbook/liteLLM_IBM_Watsonx.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cookbook/liteLLM_IBM_Watsonx.ipynb b/cookbook/liteLLM_IBM_Watsonx.ipynb index e46c1dc96..6de108b5d 100644 --- a/cookbook/liteLLM_IBM_Watsonx.ipynb +++ b/cookbook/liteLLM_IBM_Watsonx.ipynb @@ -209,7 +209,7 @@ "source": [ "### Request deployed models\n", "\n", - "Models that have been deployed to a deployment space (i.e. tuned models) can be called using the \"deployment/\" format (where `` is the ID of the deployed model in your deployment space). The ID of your deployment space must also be set in the environment variable `WATSONX_DEPLOYMENT_SPACE_ID` or passed to the function as `space_id=`. " + "Models that have been deployed to a deployment space (e.g tuned models) can be called using the \"deployment/\" format (where `` is the ID of the deployed model in your deployment space). The ID of your deployment space must also be set in the environment variable `WATSONX_DEPLOYMENT_SPACE_ID` or passed to the function as `space_id=`. " ] }, { From 0d1db1ddaf3b99e636567d2b16e911d7e1e2b400 Mon Sep 17 00:00:00 2001 From: Simon Sanchez Viloria Date: Wed, 24 Apr 2024 17:22:17 +0200 Subject: [PATCH 11/11] (docs) added watsonx.ai provider documentation --- docs/my-website/docs/providers/watsonx.md | 284 ++++++++++++++++++++++ docs/my-website/sidebars.js | 1 + 2 files changed, 285 insertions(+) create mode 100644 docs/my-website/docs/providers/watsonx.md diff --git a/docs/my-website/docs/providers/watsonx.md b/docs/my-website/docs/providers/watsonx.md new file mode 100644 index 000000000..9154816a0 --- /dev/null +++ b/docs/my-website/docs/providers/watsonx.md @@ -0,0 +1,284 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# IBM watsonx.ai + +LiteLLM supports all IBM [watsonx.ai](https://watsonx.ai/) foundational models and embeddings. + +## Environment Variables +```python +os.environ["WATSONX_URL"] = "" # (required) Base URL of your WatsonX instance +# (required) either one of the following: +os.environ["WATSONX_APIKEY"] = "" # IBM cloud API key +os.environ["WATSONX_TOKEN"] = "" # IAM auth token +# optional - can also be passed as params to completion() or embedding() +os.environ["WATSONX_PROJECT_ID"] = "" # Project ID of your WatsonX instance +os.environ["WATSONX_DEPLOYMENT_SPACE_ID"] = "" # ID of your deployment space to use deployed models +``` + +See [here](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information on how to get an access token to authenticate to watsonx.ai. + +## Usage + + + Open In Colab + + +```python +import os +from litellm import completion + +os.environ["WATSONX_URL"] = "" +os.environ["WATSONX_APIKEY"] = "" + +response = completion( + model="watsonx/ibm/granite-13b-chat-v2", + messages=[{ "content": "what is your favorite colour?","role": "user"}], + project_id="" # or pass with os.environ["WATSONX_PROJECT_ID"] +) + +response = completion( + model="watsonx/meta-llama/llama-3-8b-instruct", + messages=[{ "content": "what is your favorite colour?","role": "user"}], + project_id="" +) +``` + +## Usage - Streaming +```python +import os +from litellm import completion + +os.environ["WATSONX_URL"] = "" +os.environ["WATSONX_APIKEY"] = "" +os.environ["WATSONX_PROJECT_ID"] = "" + +response = completion( + model="watsonx/ibm/granite-13b-chat-v2", + messages=[{ "content": "what is your favorite colour?","role": "user"}], + stream=True +) +for chunk in response: + print(chunk) +``` + +#### Example Streaming Output Chunk +```json +{ + "choices": [ + { + "finish_reason": null, + "index": 0, + "delta": { + "content": "I don't have a favorite color, but I do like the color blue. What's your favorite color?" + } + } + ], + "created": null, + "model": "watsonx/ibm/granite-13b-chat-v2", + "usage": { + "prompt_tokens": null, + "completion_tokens": null, + "total_tokens": null + } +} +``` + +## Usage - Models in deployment spaces + +Models that have been deployed to a deployment space (e.g.: tuned models) can be called using the `deployment/` format (where `` is the ID of the deployed model in your deployment space). + +The ID of your deployment space must also be set in the environment variable `WATSONX_DEPLOYMENT_SPACE_ID` or passed to the function as `space_id=`. + +```python +import litellm +response = litellm.completion( + model="watsonx/deployment/", + messages=[{"content": "Hello, how are you?", "role": "user"}], + space_id="" +) +``` + +## Usage - Embeddings + +LiteLLM also supports making requests to IBM watsonx.ai embedding models. The credential needed for this is the same as for completion. + +```python +from litellm import embedding + +response = embedding( + model="watsonx/ibm/slate-30m-english-rtrvr", + input=["What is the capital of France?"], + project_id="" +) +print(response) +# EmbeddingResponse(model='ibm/slate-30m-english-rtrvr', data=[{'object': 'embedding', 'index': 0, 'embedding': [-0.037463713, -0.02141933, -0.02851813, 0.015519324, ..., -0.0021367231, -0.01704561, -0.001425816, 0.0035238306]}], object='list', usage=Usage(prompt_tokens=8, total_tokens=8)) +``` + +## OpenAI Proxy Usage + +Here's how to call IBM watsonx.ai with the LiteLLM Proxy Server + +### 1. Save keys in your environment + +```bash +export WATSONX_URL="" +export WATSONX_APIKEY="" +export WATSONX_PROJECT_ID="" +``` + +### 2. Start the proxy + + + + +```bash +$ litellm --model watsonx/meta-llama/llama-3-8b-instruct + +# Server running on http://0.0.0.0:4000 +``` + + + + +```yaml +model_list: + - model_name: llama-3-8b + litellm_params: + # all params accepted by litellm.completion() + model: watsonx/meta-llama/llama-3-8b-instruct + api_key: "os.environ/WATSONX_API_KEY" # does os.getenv("WATSONX_API_KEY") +``` + + + +### 3. Test it + + + + + +```shell +curl --location 'http://0.0.0.0:4000/chat/completions' \ +--header 'Content-Type: application/json' \ +--data ' { + "model": "llama-3-8b", + "messages": [ + { + "role": "user", + "content": "what is your favorite colour?" + } + ] + } +' +``` + + + +```python +import openai +client = openai.OpenAI( + api_key="anything", + base_url="http://0.0.0.0:4000" +) + +# request sent to model set on litellm proxy, `litellm --model` +response = client.chat.completions.create(model="llama-3-8b", messages=[ + { + "role": "user", + "content": "what is your favorite colour?" + } +]) + +print(response) + +``` + + + +```python +from langchain.chat_models import ChatOpenAI +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain.schema import HumanMessage, SystemMessage + +chat = ChatOpenAI( + openai_api_base="http://0.0.0.0:4000", # set openai_api_base to the LiteLLM Proxy + model = "llama-3-8b", + temperature=0.1 +) + +messages = [ + SystemMessage( + content="You are a helpful assistant that im using to make a test request to." + ), + HumanMessage( + content="test from litellm. tell me why it's amazing in 1 sentence" + ), +] +response = chat(messages) + +print(response) +``` + + + + +## Authentication + +### Passing credentials as parameters + +You can also pass the credentials as parameters to the completion and embedding functions. + +```python +import os +from litellm import completion + +response = completion( + model="watsonx/ibm/granite-13b-chat-v2", + messages=[{ "content": "What is your favorite color?","role": "user"}], + url="", + api_key="", + project_id="" +) +``` + + +## Supported IBM watsonx.ai Models + +Here are some examples of models available in IBM watsonx.ai that you can use with LiteLLM: + +| Mode Name | Command | +| ---------- | --------- | +| Flan T5 XXL | `completion(model=watsonx/google/flan-t5-xxl, messages=messages)` | +| Flan Ul2 | `completion(model=watsonx/google/flan-ul2, messages=messages)` | +| Mt0 XXL | `completion(model=watsonx/bigscience/mt0-xxl, messages=messages)` | +| Gpt Neox | `completion(model=watsonx/eleutherai/gpt-neox-20b, messages=messages)` | +| Mpt 7B Instruct2 | `completion(model=watsonx/ibm/mpt-7b-instruct2, messages=messages)` | +| Starcoder | `completion(model=watsonx/bigcode/starcoder, messages=messages)` | +| Llama 2 70B Chat | `completion(model=watsonx/meta-llama/llama-2-70b-chat, messages=messages)` | +| Llama 2 13B Chat | `completion(model=watsonx/meta-llama/llama-2-13b-chat, messages=messages)` | +| Granite 13B Instruct | `completion(model=watsonx/ibm/granite-13b-instruct-v1, messages=messages)` | +| Granite 13B Chat | `completion(model=watsonx/ibm/granite-13b-chat-v1, messages=messages)` | +| Flan T5 XL | `completion(model=watsonx/google/flan-t5-xl, messages=messages)` | +| Granite 13B Chat V2 | `completion(model=watsonx/ibm/granite-13b-chat-v2, messages=messages)` | +| Granite 13B Instruct V2 | `completion(model=watsonx/ibm/granite-13b-instruct-v2, messages=messages)` | +| Elyza Japanese Llama 2 7B Instruct | `completion(model=watsonx/elyza/elyza-japanese-llama-2-7b-instruct, messages=messages)` | +| Mixtral 8X7B Instruct V01 Q | `completion(model=watsonx/ibm-mistralai/mixtral-8x7b-instruct-v01-q, messages=messages)` | + + +For a list of all available models in watsonx.ai, see [here](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx&locale=en&audience=wdp). + + +## Supported IBM watsonx.ai Embedding Models + +| Model Name | Function Call | +|----------------------|---------------------------------------------| +| Slate 30m | `embedding(model="watsonx/ibm/slate-30m-english-rtrvr", input=input)` | +| Slate 125m | `embedding(model="watsonx/ibm/slate-125m-english-rtrvr", input=input)` | + + +For a list of all available embedding models in watsonx.ai, see [here](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx). \ No newline at end of file diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 0fb4ac027..bbc0ad26d 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -148,6 +148,7 @@ const sidebars = { "providers/openrouter", "providers/custom_openai_proxy", "providers/petals", + "providers/watsonx", ], }, "proxy/custom_pricing",