feat - watsonx refactoring: removed the ibm_watsonx_ai SDK dependency and added support for embedding calls

Simon Sanchez Viloria 2024-04-23 11:53:38 +02:00
parent a77537ddd4
commit 74d2ba0a23
4 changed files with 477 additions and 366 deletions

View file

@@ -656,7 +656,7 @@ from .llms.bedrock import (
 )
 from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
 from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
-from .llms.watsonx import IBMWatsonXConfig
+from .llms.watsonx import IBMWatsonXAIConfig
 from .main import *  # type: ignore
 from .integrations import *
 from .exceptions import (

View file

@ -1,27 +1,31 @@
import json, types, time import json, types, time # noqa: E401
from typing import Callable, Optional, Any, Union, List from contextlib import contextmanager
from typing import Callable, Dict, Optional, Any, Union, List
import httpx import httpx
import requests
import litellm import litellm
from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse from litellm.utils import ModelResponse, get_secret, Usage
from .base import BaseLLM
from .prompt_templates import factory as ptf from .prompt_templates import factory as ptf
class WatsonxError(Exception):
def __init__(self, status_code, message): class WatsonXAIError(Exception):
def __init__(self, status_code, message, url: str = None):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request( url = url or "https://https://us-south.ml.cloud.ibm.com"
method="POST", url="https://https://us-south.ml.cloud.ibm.com" self.request = httpx.Request(method="POST", url=url)
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class IBMWatsonXConfig:
class IBMWatsonXAIConfig:
""" """
Reference: https://cloud.ibm.com/apidocs/watsonx-ai#deployments-text-generation Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
(See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params) (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)
Supported params for all available watsonx.ai foundational models. Supported params for all available watsonx.ai foundational models.
@@ -34,96 +38,64 @@ class IBMWatsonXConfig:
     - `min_new_tokens` (integer): Minimum number of new tokens to be generated.
+    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".
     - `stop_sequences` (string[]): list of strings to use as stop sequences.
-    - `time_limit` (integer): time limit in milliseconds. If the generation is not completed within the time limit, the model will return the generated text up to that point.
-    - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'.
     - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.
+    - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'.
     - `repetition_penalty` (float): token repetition penalty during text generation.
-    - `stream` (bool): If True, the model will return a stream of responses.
-    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks".
     - `truncate_input_tokens` (integer): Truncate input tokens to this length.
-    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".
+    - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match.
+    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean.
     - `random_seed` (integer): Random seed for text generation.
-    - `guardrails` (bool): Enable guardrails for harmful content.
-    - `guardrails_hap_params` (dict): Guardrails for harmful content.
-    - `guardrails_pii_params` (dict): Guardrails for Personally Identifiable Information.
-    - `concurrency_limit` (integer): Maximum number of concurrent requests.
-    - `async_mode` (bool): Enable async mode.
-    - `verify` (bool): Verify the SSL certificate of calls to the watsonx url.
-    - `validate` (bool): Validate the model_id at initialization.
-    - `model_inference` (ibm_watsonx_ai.ModelInference): An instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance.
-    - `watsonx_client` (ibm_watsonx_ai.APIClient): An instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with.
+    - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering.
+    - `stream` (bool): If True, the model will return a stream of responses.
     """

-    decoding_method: Optional[str] = "sample"  # 'sample' or 'greedy'. "sample" follows the default openai API behavior
-    temperature: Optional[float] = None
-    min_new_tokens: Optional[int] = None
-    max_new_tokens: Optional[int] = litellm.max_tokens
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
-    random_seed: Optional[int] = None  # e.g 42
-    repetition_penalty: Optional[float] = None
-    stop_sequences: Optional[List[str]] = None  # e.g ["}", ")", "."]
-    time_limit: Optional[int] = None  # e.g 10000 (timeout in milliseconds)
-    return_options: Optional[dict] = None  # e.g {"input_text": True, "generated_tokens": True, "input_tokens": True, "token_ranks": False}
-    truncate_input_tokens: Optional[int] = None  # e.g 512
-    length_penalty: Optional[dict] = None  # e.g {"decay_factor": 2.5, "start_index": 5}
-    stream: Optional[bool] = False
-    # other inference params
-    guardrails: Optional[bool] = False  # enable guardrails
-    guardrails_hap_params: Optional[dict] = None  # guardrails for harmful content
-    guardrails_pii_params: Optional[dict] = None  # guardrails for Personally Identifiable Information
-    concurrency_limit: Optional[int] = 10  # max number of concurrent requests
-    async_mode: Optional[bool] = False  # enable async mode
-    verify: Optional[Union[bool, str]] = None  # verify the SSL certificate of calls to the watsonx url
-    validate: Optional[bool] = False  # validate the model_id at initialization
-    model_inference: Optional[object] = None  # an instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance
-    watsonx_client: Optional[object] = None  # an instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with
+    decoding_method: Optional[str] = "sample"
+    temperature: Optional[float] = None
+    max_new_tokens: Optional[int] = None  # litellm.max_tokens
+    min_new_tokens: Optional[int] = None
+    length_penalty: Optional[dict] = None  # e.g {"decay_factor": 2.5, "start_index": 5}
+    stop_sequences: Optional[List[str]] = None  # e.g ["}", ")", "."]
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
+    truncate_input_tokens: Optional[int] = None
+    include_stop_sequences: Optional[bool] = False
+    return_options: Optional[Dict[str, bool]] = None
+    random_seed: Optional[int] = None  # e.g 42
+    moderations: Optional[dict] = None
+    stream: Optional[bool] = False

     def __init__(
         self,
         decoding_method: Optional[str] = None,
         temperature: Optional[float] = None,
+        max_new_tokens: Optional[int] = None,
         min_new_tokens: Optional[int] = None,
-        max_new_tokens: Optional[
-            int
-        ] = litellm.max_tokens,  # petals requires max tokens to be set
+        length_penalty: Optional[dict] = None,
+        stop_sequences: Optional[List[str]] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        random_seed: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        stop_sequences: Optional[List[str]] = None,
-        time_limit: Optional[int] = None,
-        return_options: Optional[dict] = None,
         truncate_input_tokens: Optional[int] = None,
-        length_penalty: Optional[dict] = None,
-        stream: Optional[bool] = False,
-        guardrails: Optional[bool] = False,
-        guardrails_hap_params: Optional[dict] = None,
-        guardrails_pii_params: Optional[dict] = None,
-        concurrency_limit: Optional[int] = 10,
-        async_mode: Optional[bool] = False,
-        verify: Optional[Union[bool, str]] = None,
-        validate: Optional[bool] = False,
-        model_inference: Optional[object] = None,
-        watsonx_client: Optional[object] = None,
+        include_stop_sequences: Optional[bool] = None,
+        return_options: Optional[dict] = None,
+        random_seed: Optional[int] = None,
+        moderations: Optional[dict] = None,
+        stream: Optional[bool] = None,
+        **kwargs,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
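For context, a minimal sketch of how these generation parameters might surface through litellm.completion once this refactor lands. The model name, project id, and values are illustrative, and the "watsonx/" prefix assumes litellm's usual provider-prefix routing (see the main.py hunk below); it is not part of this diff.

import litellm

# Illustrative only - model id, project_id and parameter values are placeholders.
response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Summarize watsonx.ai in one sentence."}],
    project_id="my-project-id",  # forwarded into the watsonx.ai request payload
    temperature=0.2,             # OpenAI-style params are assumed to map onto the fields above
    max_tokens=200,
)
print(response.choices[0].message.content)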
@@ -160,133 +132,6 @@ class IBMWatsonXConfig:
         ]

-def init_watsonx_model(
-    model_id: str,
-    url: Optional[str] = None,
-    api_key: Optional[str] = None,
-    project_id: Optional[str] = None,
-    space_id: Optional[str] = None,
-    wx_credentials: Optional[dict] = None,
-    region_name: Optional[str] = None,
-    verify: Optional[Union[bool, str]] = None,
-    validate: Optional[bool] = False,
-    watsonx_client: Optional[object] = None,
-    model_params: Optional[dict] = None,
-):
-    """
-    Initialize a watsonx.ai model for inference.
-
-    Args:
-        model_id (str): The model ID to use for inference. If this is a model deployed in a deployment space, the model_id should be in the format 'deployment/<deployment_id>' and the space_id of the deployment space should be provided.
-        url (str): The URL of the watsonx.ai instance.
-        api_key (str): The API key for the watsonx.ai instance.
-        project_id (str): The project ID for the watsonx.ai instance.
-        space_id (str): The space ID for the deployment space.
-        wx_credentials (dict): A dictionary containing 'apikey' and 'url' keys for the watsonx.ai instance.
-        region_name (str): The region name for the watsonx.ai instance (e.g. 'us-south').
-        verify (bool): Whether to verify the SSL certificate of calls to the watsonx url.
-        validate (bool): Whether to validate the model_id at initialization.
-        watsonx_client (object): An instance of the ibm_watsonx_ai.APIClient class. If this is provided, the model will be initialized using the provided client.
-        model_params (dict): A dictionary containing additional parameters to pass to the model (see IBMWatsonXConfig for a list of supported parameters).
-    """
-    from ibm_watsonx_ai import APIClient
-    from ibm_watsonx_ai.foundation_models import ModelInference
-
-    if wx_credentials is not None:
-        if 'apikey' not in wx_credentials and 'api_key' in wx_credentials:
-            wx_credentials['apikey'] = wx_credentials.pop('api_key')
-        if 'apikey' not in wx_credentials:
-            raise WatsonxError(500, "Error: key 'apikey' expected in wx_credentials")
-    if url is None:
-        url = get_secret("WX_URL") or get_secret("WATSONX_URL") or get_secret("WML_URL")
-    if api_key is None:
-        api_key = get_secret("WX_API_KEY") or get_secret("WML_API_KEY")
-    if project_id is None:
-        project_id = get_secret("WX_PROJECT_ID") or get_secret("PROJECT_ID")
-    if region_name is None:
-        region_name = get_secret("WML_REGION_NAME") or get_secret("WX_REGION_NAME") or get_secret("REGION_NAME")
-    if space_id is None:
-        space_id = get_secret("WX_SPACE_ID") or get_secret("WML_DEPLOYMENT_SPACE_ID") or get_secret("SPACE_ID")
-
-    ## CHECK IF 'os.environ/' passed in
-    # Define the list of parameters to check
-    params_to_check = (url, api_key, project_id, space_id, region_name)
-    # Iterate over parameters and update if needed
-    for i, param in enumerate(params_to_check):
-        if param and param.startswith("os.environ/"):
-            params_to_check[i] = get_secret(param)
-    # Assign updated values back to parameters
-    url, api_key, project_id, space_id, region_name = params_to_check
-
-    ### SET WATSONX URL
-    if url is not None or watsonx_client is not None or wx_credentials is not None:
-        pass
-    elif region_name is not None:
-        url = f"https://{region_name}.ml.cloud.ibm.com"
-    else:
-        raise WatsonxError(
-            message="Watsonx URL not set: set WX_URL env variable or in .env file",
-            status_code=401,
-        )
-    if watsonx_client is not None and project_id is None:
-        project_id = watsonx_client.project_id
-    if model_id.startswith("deployment/"):
-        # deployment models are passed in as 'deployment/<deployment_id>'
-        assert space_id is not None, "space_id is required for deployment models"
-        deployment_id = '/'.join(model_id.split("/")[1:])
-        model_id = None
-    else:
-        deployment_id = None
-
-    if watsonx_client is not None:
-        model = ModelInference(
-            model_id=model_id,
-            params=model_params,
-            api_client=watsonx_client,
-            project_id=project_id,
-            deployment_id=deployment_id,
-            verify=verify,
-            validate=validate,
-            space_id=space_id,
-        )
-    elif wx_credentials is not None:
-        model = ModelInference(
-            model_id=model_id,
-            params=model_params,
-            credentials=wx_credentials,
-            project_id=project_id,
-            deployment_id=deployment_id,
-            verify=verify,
-            validate=validate,
-            space_id=space_id,
-        )
-    elif api_key is not None:
-        model = ModelInference(
-            model_id=model_id,
-            params=model_params,
-            credentials={
-                "apikey": api_key,
-                "url": url,
-            },
-            project_id=project_id,
-            deployment_id=deployment_id,
-            verify=verify,
-            validate=validate,
-            space_id=space_id,
-        )
-    else:
-        raise WatsonxError(500, "WatsonX credentials not passed or could not be found.")
-    return model
-
 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     # handle anthropic prompts and amazon titan prompts
     if model in custom_prompt_dict:
@@ -294,8 +139,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
         model_prompt_dict = custom_prompt_dict[model]
         prompt = ptf.custom_prompt(
             messages=messages,
-            role_dict=model_prompt_dict.get("role_dict", model_prompt_dict.get("roles")),
-            initial_prompt_value=model_prompt_dict.get("initial_prompt_value",""),
+            role_dict=model_prompt_dict.get(
+                "role_dict", model_prompt_dict.get("roles")
+            ),
+            initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
             final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
             bos_token=model_prompt_dict.get("bos_token", ""),
             eos_token=model_prompt_dict.get("eos_token", ""),
@@ -308,18 +155,176 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     elif provider == "ibm-mistralai":
         prompt = ptf.mistral_instruct_pt(messages=messages)
     else:
-        prompt = ptf.prompt_factory(model=model, messages=messages, custom_llm_provider='watsonx')
+        prompt = ptf.prompt_factory(
+            model=model, messages=messages, custom_llm_provider="watsonx"
+        )
     return prompt

-"""
-IBM watsonx.ai AUTH Keys/Vars
-os.environ['WX_URL'] = ""
-os.environ['WX_API_KEY'] = ""
-os.environ['WX_PROJECT_ID'] = ""
-"""
-
-def completion(
+class IBMWatsonXAI(BaseLLM):
+    """
+    Class to interface with IBM Watsonx.ai API for text generation and embeddings.
+
+    Reference: https://cloud.ibm.com/apidocs/watsonx-ai
+    """
+
+    api_version = "2024-03-13"
+
+    _text_gen_endpoint = "/ml/v1/text/generation"
+    _text_gen_stream_endpoint = "/ml/v1/text/generation_stream"
+    _deployment_text_gen_endpoint = "/ml/v1/deployments/{deployment_id}/text/generation"
+    _deployment_text_gen_stream_endpoint = (
+        "/ml/v1/deployments/{deployment_id}/text/generation_stream"
+    )
+    _embeddings_endpoint = "/ml/v1/text/embeddings"
+    _prompts_endpoint = "/ml/v1/prompts"
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def _prepare_text_generation_req(
+        self,
+        model_id: str,
+        prompt: str,
+        stream: bool,
+        optional_params: dict,
+        print_verbose: Callable = None,
+    ) -> httpx.Request:
+        """
+        Get the request parameters for text generation.
+        """
+        api_params = self._get_api_params(optional_params, print_verbose=print_verbose)
+        # build auth headers
+        api_token = api_params.get("token")
+        headers = {
+            "Authorization": f"Bearer {api_token}",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        extra_body_params = optional_params.pop("extra_body", {})
+        optional_params.update(extra_body_params)
+        # init the payload to the text generation call
+        payload = {
+            "input": prompt,
+            "moderations": optional_params.pop("moderations", {}),
+            "parameters": optional_params,
+        }
+        request_params = dict(version=api_params["api_version"])
+        # text generation endpoint deployment or model / stream or not
+        if model_id.startswith("deployment/"):
+            # deployment models are passed in as 'deployment/<deployment_id>'
+            if api_params.get("space_id") is None:
+                raise WatsonXAIError(
+                    status_code=401,
+                    url=api_params["url"],
+                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
+                )
+            deployment_id = "/".join(model_id.split("/")[1:])
+            endpoint = (
+                self._deployment_text_gen_stream_endpoint
+                if stream
+                else self._deployment_text_gen_endpoint
+            )
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        else:
+            payload["model_id"] = model_id
+            payload["project_id"] = api_params["project_id"]
+            endpoint = (
+                self._text_gen_stream_endpoint if stream else self._text_gen_endpoint
+            )
+        url = api_params["url"].rstrip("/") + endpoint
+        return httpx.Request(
+            "POST", url, headers=headers, json=payload, params=request_params
+        )
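As a quick illustration of the endpoint selection above (the deployment id and base URL are made up for the example, everything else follows the code as shown):

# Illustrative only: how a 'deployment/' model id maps onto the deployment endpoint.
model_id = "deployment/abc123"
deployment_id = "/".join(model_id.split("/")[1:])  # -> "abc123"
endpoint = IBMWatsonXAI._deployment_text_gen_endpoint.format(deployment_id=deployment_id)
url = "https://us-south.ml.cloud.ibm.com" + endpoint
# POST {url}?version=2024-03-13 with body {"input": prompt, "parameters": {...}}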
+    def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
+        """
+        Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
+        """
+        # Load auth variables from params
+        url = params.pop("url", None)
+        api_key = params.pop("apikey", None)
+        token = params.pop("token", None)
+        project_id = params.pop("project_id", None)  # watsonx.ai project_id
+        space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
+        region_name = params.pop("region_name", params.pop("region", None))
+        wx_credentials = params.pop("wx_credentials", None)
+        api_version = params.pop("api_version", IBMWatsonXAI.api_version)
+        # Load auth variables from environment variables
+        if url is None:
+            url = (
+                get_secret("WATSONX_URL")
+                or get_secret("WX_URL")
+                or get_secret("WML_URL")
+            )
+        if api_key is None:
+            api_key = get_secret("WATSONX_API_KEY") or get_secret("WX_API_KEY")
+        if token is None:
+            token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN")
+        if project_id is None:
+            project_id = (
+                get_secret("WATSONX_PROJECT_ID")
+                or get_secret("WX_PROJECT_ID")
+                or get_secret("PROJECT_ID")
+            )
+        if region_name is None:
+            region_name = (
+                get_secret("WATSONX_REGION")
+                or get_secret("WX_REGION")
+                or get_secret("REGION")
+            )
+        if space_id is None:
+            space_id = (
+                get_secret("WATSONX_DEPLOYMENT_SPACE_ID")
+                or get_secret("WATSONX_SPACE_ID")
+                or get_secret("WX_SPACE_ID")
+                or get_secret("SPACE_ID")
+            )
+        # credentials parsing
+        if wx_credentials is not None:
+            url = wx_credentials.get("url", url)
+            api_key = wx_credentials.get(
+                "apikey", wx_credentials.get("api_key", api_key)
+            )
+            token = wx_credentials.get("token", token)
+        # verify that all required credentials are present
+        if url is None:
+            raise WatsonXAIError(
+                status_code=401,
+                message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
+            )
+        if token is None and api_key is not None:
+            # generate the auth token
+            if print_verbose:
+                print_verbose("Generating IAM token for Watsonx.ai")
+            token = self.generate_iam_token(api_key)
+        elif token is None and api_key is None:
+            raise WatsonXAIError(
+                status_code=401,
+                url=url,
+                message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
+            )
+        if project_id is None:
+            raise WatsonXAIError(
+                status_code=401,
+                url=url,
+                message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
+            )
+        return {
+            "url": url,
+            "api_key": api_key,
+            "token": token,
+            "project_id": project_id,
+            "space_id": space_id,
+            "region_name": region_name,
+            "api_version": api_version,
+        }
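A minimal environment-based setup consistent with the lookup order above (the values are placeholders):

import os

os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"  # or WX_URL / WML_URL
os.environ["WATSONX_API_KEY"] = "..."      # exchanged for an IAM token via generate_iam_token
os.environ["WATSONX_PROJECT_ID"] = "..."   # or WX_PROJECT_ID / PROJECT_ID
# Optional: WATSONX_TOKEN to skip the IAM exchange, WATSONX_DEPLOYMENT_SPACE_ID for deployment models.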
+    def completion(
+        self,
        model: str,
        messages: list,
        custom_prompt_dict: dict,
@@ -327,154 +332,231 @@ def completion(
        print_verbose: Callable,
        encoding,
        logging_obj,
-    optional_params:Optional[dict]=None,
-    litellm_params:Optional[dict]=None,
+        optional_params: Optional[dict] = None,
+        litellm_params: Optional[dict] = None,
        logger_fn=None,
-    timeout:float=None,
+        timeout: float = None,
    ):
-    from ibm_watsonx_ai.foundation_models import Model, ModelInference
-
-    try:
+        """
+        Send a text generation request to the IBM Watsonx.ai API.
+
+        Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
+        """
        stream = optional_params.pop("stream", False)
-        extra_generate_params = dict(
-            guardrails=optional_params.pop("guardrails", False),
-            guardrails_hap_params=optional_params.pop("guardrails_hap_params", None),
-            guardrails_pii_params=optional_params.pop("guardrails_pii_params", None),
-            concurrency_limit=optional_params.pop("concurrency_limit", 10),
-            async_mode=optional_params.pop("async_mode", False),
-        )
-        if timeout is not None and optional_params.get("time_limit") is None:
-            # the time_limit in watsonx.ai is in milliseconds (as opposed to OpenAI which is in seconds)
-            optional_params['time_limit'] = max(0, int(timeout*1000))
-        extra_body_params = optional_params.pop("extra_body", {})
-        optional_params.update(extra_body_params)
-        # LOAD CONFIG
-        config = IBMWatsonXConfig.get_config()
+        # Load default configs
+        config = IBMWatsonXAIConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v
-        model_inference = optional_params.pop("model_inference", None)
-        if model_inference is None:
-            # INIT MODEL
-            model_client:ModelInference = init_watsonx_model(
-                model_id=model,
-                url=optional_params.pop("url", None),
-                api_key=optional_params.pop("api_key", None),
-                project_id=optional_params.pop("project_id", None),
-                space_id=optional_params.pop("space_id", None),
-                wx_credentials=optional_params.pop("wx_credentials", None),
-                region_name=optional_params.pop("region_name", None),
-                verify=optional_params.pop("verify", None),
-                validate=optional_params.pop("validate", False),
-                watsonx_client=optional_params.pop("watsonx_client", None),
-                model_params=optional_params,
-            )
-        else:
-            model_client:ModelInference = model_inference
-            model = model_client.model_id
-        # MAKE PROMPT
+        # Make prompt to send to model
        provider = model.split("/")[0]
-        model_name = '/'.join(model.split("/")[1:])
+        # model_name = "/".join(model.split("/")[1:])
        prompt = convert_messages_to_prompt(
            model, messages, provider, custom_prompt_dict
        )
-        ## COMPLETION CALL
-        if stream is True:
-            request_str = (
-                "response = model.generate_text_stream(\n"
-                f"\tprompt={prompt},\n"
-                "\traw_response=True\n)"
-            )
-            logging_obj.pre_call(
-                input=prompt,
-                api_key="",
-                additional_args={
-                    "complete_input_dict": optional_params,
-                    "request_str": request_str,
-                },
-            )
-            # remove params that are not needed for streaming
-            del extra_generate_params["async_mode"]
-            del extra_generate_params["concurrency_limit"]
-            # make generate call
-            response = model_client.generate_text_stream(
-                prompt=prompt,
-                raw_response=True,
-                **extra_generate_params
-            )
-            return litellm.CustomStreamWrapper(
-                response,
-                model=model,
-                custom_llm_provider="watsonx",
-                logging_obj=logging_obj,
-            )
-        else:
-            try:
-                ## LOGGING
-                request_str = (
-                    "response = model.generate(\n"
-                    f"\tprompt={prompt},\n"
-                    "\traw_response=True\n)"
-                )
-                logging_obj.pre_call(
-                    input=prompt,
-                    api_key="",
-                    additional_args={
-                        "complete_input_dict": optional_params,
-                        "request_str": request_str,
-                    },
-                )
-                response = model_client.generate(
-                    prompt=prompt,
-                    **extra_generate_params
-                )
-            except Exception as e:
-                raise WatsonxError(status_code=500, message=str(e))
-            ## LOGGING
-            logging_obj.post_call(
-                input=prompt,
-                api_key="",
-                original_response=json.dumps(response),
-                additional_args={"complete_input_dict": optional_params},
-            )
-            print_verbose(f"raw model_response: {response}")
-            ## BUILD RESPONSE OBJECT
-            output_text = response['results'][0]['generated_text']
-            try:
-                if (
-                    len(output_text) > 0
-                    and hasattr(model_response.choices[0], "message")
-                ):
-                    model_response["choices"][0]["message"]["content"] = output_text
-                    model_response["finish_reason"] = response['results'][0]['stop_reason']
-                    prompt_tokens = response['results'][0]['input_token_count']
-                    completion_tokens = response['results'][0]['generated_token_count']
-                else:
-                    raise Exception()
-            except:
-                raise WatsonxError(
-                    message=json.dumps(output_text),
-                    status_code=500,
-                )
-            model_response['created'] = int(time.time())
-            model_response['model'] = model_name
-            usage = Usage(
+        def process_text_request(request: httpx.Request) -> ModelResponse:
+            with self._manage_response(
+                request, logging_obj=logging_obj, input=prompt, timeout=timeout
+            ) as resp:
+                json_resp = resp.json()
+
+            generated_text = json_resp["results"][0]["generated_text"]
+            prompt_tokens = json_resp["results"][0]["input_token_count"]
+            completion_tokens = json_resp["results"][0]["generated_token_count"]
+            model_response["choices"][0]["message"]["content"] = generated_text
+            model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
+            model_response["created"] = int(time.time())
+            model_response["model"] = model
+            model_response.usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
-            model_response.usage = usage
            return model_response
-    except WatsonxError as e:
+
+        def process_stream_request(
+            request: httpx.Request,
+        ) -> litellm.CustomStreamWrapper:
+            # stream the response - generated chunks will be handled
+            # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
+            with self._manage_response(
+                request,
+                logging_obj=logging_obj,
+                stream=True,
+                input=prompt,
+                timeout=timeout,
+            ) as resp:
+                response = litellm.CustomStreamWrapper(
+                    resp.iter_lines(),
+                    model=model,
+                    custom_llm_provider="watsonx",
+                    logging_obj=logging_obj,
+                )
+            return response
+
+        try:
+            ## Get the response from the model
+            request = self._prepare_text_generation_req(
+                model_id=model,
+                prompt=prompt,
+                stream=stream,
+                optional_params=optional_params,
+                print_verbose=print_verbose,
+            )
+            if stream:
+                return process_stream_request(request)
+            else:
+                return process_text_request(request)
+        except WatsonXAIError as e:
            raise e
        except Exception as e:
-        raise WatsonxError(status_code=500, message=str(e))
+            raise WatsonXAIError(status_code=500, message=str(e))
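For reference, a sketch of the non-streaming response shape that process_text_request expects; the field names match the parsing above, the values are illustrative:

# Illustrative response body from /ml/v1/text/generation (values made up).
json_resp = {
    "results": [
        {
            "generated_text": "watsonx.ai is IBM's studio for foundation models.",
            "input_token_count": 12,
            "generated_token_count": 11,
            "stop_reason": "eos_token",
        }
    ]
}
# process_text_request copies generated_text into model_response["choices"][0]["message"]["content"]
# and builds Usage from input_token_count / generated_token_count.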
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
+    def embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        api_key: Optional[str] = None,
+        logging_obj=None,
+        model_response=None,
+        optional_params=None,
+        encoding=None,
+    ):
+        """
+        Send a text embedding request to the IBM Watsonx.ai API.
+        """
+        if optional_params is None:
+            optional_params = {}
+        # Load default configs
+        config = IBMWatsonXAIConfig.get_config()
+        for k, v in config.items():
+            if k not in optional_params:
+                optional_params[k] = v
+        # Load auth variables from environment variables
+        if isinstance(input, str):
+            input = [input]
+        if api_key is not None:
+            optional_params["api_key"] = api_key
+        api_params = self._get_api_params(optional_params)
+        # build auth headers
+        api_token = api_params.get("token")
+        headers = {
+            "Authorization": f"Bearer {api_token}",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        # init the payload to the text embedding call
+        payload = {
+            "inputs": input,
+            "model_id": model,
+            "project_id": api_params["project_id"],
+            "parameters": optional_params,
+        }
+        request_params = dict(version=api_params["api_version"])
+        url = api_params["url"].rstrip("/") + self._embeddings_endpoint
+        request = httpx.Request(
+            "POST", url, headers=headers, json=payload, params=request_params
+        )
+        with self._manage_response(
+            request, logging_obj=logging_obj, input=input
+        ) as resp:
+            json_resp = resp.json()
+
+        results = json_resp.get("results", [])
+        embedding_response = []
+        for idx, result in enumerate(results):
+            embedding_response.append(
+                {"object": "embedding", "index": idx, "embedding": result["embedding"]}
+            )
+        model_response["object"] = "list"
+        model_response["data"] = embedding_response
+        model_response["model"] = model
+        input_tokens = json_resp.get("input_token_count", 0)
+        model_response.usage = Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        )
+        return model_response
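The embedding request body assembled above looks roughly like this; the model id and project id are placeholders, and "parameters" holds whatever optional_params remain after the config defaults are merged in:

# Illustrative payload sent to /ml/v1/text/embeddings?version=2024-03-13
payload = {
    "inputs": ["first sentence to embed", "second sentence to embed"],
    "model_id": "ibm/slate-30m-english-rtrvr",  # placeholder embedding model id
    "project_id": "my-project-id",
    "parameters": {},
}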
+    def generate_iam_token(self, api_key=None, **params):
+        headers = {}
+        headers["Content-Type"] = "application/x-www-form-urlencoded"
+        if api_key is None:
+            api_key = get_secret("WX_API_KEY") or get_secret("WATSONX_API_KEY")
+        if api_key is None:
+            raise ValueError("API key is required")
+        headers["Accept"] = "application/json"
+        data = {
+            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
+            "apikey": api_key,
+        }
+        response = httpx.post(
+            "https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
+        )
+        response.raise_for_status()
+        json_data = response.json()
+        iam_access_token = json_data["access_token"]
+        self.token = iam_access_token
+        return iam_access_token
+    @contextmanager
+    def _manage_response(
+        self,
+        request: httpx.Request,
+        logging_obj: Any,
+        stream: bool = False,
+        input: Optional[Any] = None,
+        timeout: float = None,
+    ):
+        request_str = (
+            f"response = {request.method}(\n"
+            f"\turl={request.url},\n"
+            f"\tjson={request.content.decode()},\n"
+            f")"
+        )
+        json_input = json.loads(request.content.decode())
+        headers = dict(request.headers)
+        logging_obj.pre_call(
+            input=input,
+            api_key=request.headers.get("Authorization"),
+            additional_args={
+                "complete_input_dict": json_input,
+                "request_str": request_str,
+            },
+        )
+        try:
+            if stream:
+                resp = requests.request(
+                    method=request.method,
+                    url=str(request.url),
+                    headers=headers,
+                    json=json_input,
+                    stream=True,
+                    timeout=timeout,
+                )
+                # resp.raise_for_status()
+                yield resp
+            else:
+                resp = requests.request(
+                    method=request.method,
+                    url=str(request.url),
+                    headers=headers,
+                    json=json_input,
+                    timeout=timeout,
+                )
+                resp.raise_for_status()
+                yield resp
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
+        if not stream:
+            logging_obj.post_call(
+                input=input,
+                api_key=request.headers.get("Authorization"),
+                original_response=json.dumps(resp.json()),
+                additional_args={
+                    "status_code": resp.status_code,
+                    "complete_input_dict": request,
+                },
+            )

View file

@@ -1862,7 +1862,7 @@ def completion(
             response = response
         elif custom_llm_provider == "watsonx":
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            response = watsonx.completion(
+            response = watsonx.IBMWatsonXAI().completion(
                 model=model,
                 messages=messages,
                 custom_prompt_dict=custom_prompt_dict,

@@ -2976,6 +2976,15 @@ def embedding(
             client=client,
             aembedding=aembedding,
         )
+    elif custom_llm_provider == "watsonx":
+        response = watsonx.IBMWatsonXAI().embedding(
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=optional_params,
+            model_response=EmbeddingResponse(),
+        )
     else:
         args = locals()
         raise ValueError(f"No valid embedding model args passed in - {args}")

View file

@@ -5771,7 +5771,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "presence_penalty",
         ]
     elif custom_llm_provider == "watsonx":
-        return litellm.IBMWatsonXConfig().get_supported_openai_params()
+        return litellm.IBMWatsonXAIConfig().get_supported_openai_params()

def get_formatted_prompt(
@@ -9682,20 +9682,31 @@ class CustomStreamWrapper:
     def handle_watsonx_stream(self, chunk):
         try:
             if isinstance(chunk, dict):
-                pass
-            elif isinstance(chunk, str):
-                chunk = json.loads(chunk)
-            result = chunk.get("results", [])
-            if len(result) > 0:
-                text = result[0].get("generated_text", "")
-                finish_reason = result[0].get("stop_reason")
+                parsed_response = chunk
+            elif isinstance(chunk, (str, bytes)):
+                if isinstance(chunk, bytes):
+                    chunk = chunk.decode("utf-8")
+                if 'generated_text' in chunk:
+                    response = chunk.replace('data: ', '').strip()
+                    parsed_response = json.loads(response)
+                else:
+                    return {"text": "", "is_finished": False}
+            else:
+                print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
+                raise ValueError(f"Unable to parse response. Original response: {chunk}")
+            results = parsed_response.get("results", [])
+            if len(results) > 0:
+                text = results[0].get("generated_text", "")
+                finish_reason = results[0].get("stop_reason")
                 is_finished = finish_reason != 'not_finished'
                 return {
                     "text": text,
                     "is_finished": is_finished,
                     "finish_reason": finish_reason,
+                    "prompt_tokens": results[0].get("input_token_count", None),
+                    "completion_tokens": results[0].get("generated_token_count", None),
                 }
-            return ""
+            return {"text": "", "is_finished": False}
         except Exception as e:
             raise e
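For reference, a watsonx.ai streaming chunk handled above typically arrives as a server-sent-events line like the following sketch (values illustrative):

# Illustrative SSE line from /ml/v1/text/generation_stream before parsing:
chunk = 'data: {"results": [{"generated_text": "Hello", "generated_token_count": 1, "stop_reason": "not_finished"}]}'
# handle_watsonx_stream strips the "data: " prefix, parses the JSON, and returns
# {"text": "Hello", "is_finished": False, "finish_reason": "not_finished", ...}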
@@ -9957,6 +9968,15 @@ class CustomStreamWrapper:
                 response_obj = self.handle_watsonx_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj.get("prompt_tokens") is not None:
+                    prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
+                    model_response.usage.prompt_tokens = (
+                        prompt_token_count + response_obj["prompt_tokens"]
+                    )
+                if response_obj.get("completion_tokens") is not None:
+                    model_response.usage.completion_tokens = response_obj["completion_tokens"]
+                model_response.usage.total_tokens = (
+                    getattr(model_response.usage, "prompt_tokens", 0)
+                    + getattr(model_response.usage, "completion_tokens", 0)
+                )
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "text-completion-openai":