fix(watsonx.py): use common litellm params for api key, api base, etc.

This commit is contained in:
Krrish Dholakia 2024-04-27 10:15:27 -07:00
parent a76e40df73
commit c9d7437d16
2 changed files with 49 additions and 27 deletions

View file

@@ -13,7 +13,7 @@ from .prompt_templates import factory as ptf
class WatsonXAIError(Exception): class WatsonXAIError(Exception):
def __init__(self, status_code, message, url: str = None): def __init__(self, status_code, message, url: Optional[str] = None):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
url = url or "https://https://us-south.ml.cloud.ibm.com" url = url or "https://https://us-south.ml.cloud.ibm.com"
@@ -73,7 +73,6 @@ class IBMWatsonXAIConfig:
repetition_penalty: Optional[float] = None repetition_penalty: Optional[float] = None
truncate_input_tokens: Optional[int] = None truncate_input_tokens: Optional[int] = None
include_stop_sequences: Optional[bool] = False include_stop_sequences: Optional[bool] = False
return_options: Optional[dict] = None
return_options: Optional[Dict[str, bool]] = None return_options: Optional[Dict[str, bool]] = None
random_seed: Optional[int] = None # e.g 42 random_seed: Optional[int] = None # e.g 42
moderations: Optional[dict] = None moderations: Optional[dict] = None
@@ -161,6 +160,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
) )
return prompt return prompt
class WatsonXAIEndpoint(str, Enum): class WatsonXAIEndpoint(str, Enum):
TEXT_GENERATION = "/ml/v1/text/generation" TEXT_GENERATION = "/ml/v1/text/generation"
TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream" TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
@@ -171,6 +171,7 @@ class WatsonXAIEndpoint(str, Enum):
EMBEDDINGS = "/ml/v1/text/embeddings" EMBEDDINGS = "/ml/v1/text/embeddings"
PROMPTS = "/ml/v1/prompts" PROMPTS = "/ml/v1/prompts"
class IBMWatsonXAI(BaseLLM): class IBMWatsonXAI(BaseLLM):
""" """
Class to interface with IBM Watsonx.ai API for text generation and embeddings. Class to interface with IBM Watsonx.ai API for text generation and embeddings.
@@ -180,7 +181,6 @@ class IBMWatsonXAI(BaseLLM):
api_version = "2024-03-13" api_version = "2024-03-13"
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@@ -190,7 +190,7 @@ class IBMWatsonXAI(BaseLLM):
prompt: str, prompt: str,
stream: bool, stream: bool,
optional_params: dict, optional_params: dict,
print_verbose: Callable = None, print_verbose: Optional[Callable] = None,
) -> dict: ) -> dict:
""" """
Get the request parameters for text generation. Get the request parameters for text generation.
@@ -224,9 +224,9 @@
) )
deployment_id = "/".join(model_id.split("/")[1:]) deployment_id = "/".join(model_id.split("/")[1:])
endpoint = ( endpoint = (
WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
if stream if stream
else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
) )
endpoint = endpoint.format(deployment_id=deployment_id) endpoint = endpoint.format(deployment_id=deployment_id)
else: else:
@@ -239,27 +239,40 @@
) )
url = api_params["url"].rstrip("/") + endpoint url = api_params["url"].rstrip("/") + endpoint
return dict( return dict(
method="POST", url=url, headers=headers, method="POST", url=url, headers=headers, json=payload, params=request_params
json=payload, params=request_params
) )
def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict: def _get_api_params(
self, params: dict, print_verbose: Optional[Callable] = None
) -> dict:
""" """
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication. Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
""" """
# Load auth variables from params # Load auth variables from params
url = params.pop("url", None) url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
api_key = params.pop("apikey", None) api_key = params.pop("apikey", None)
token = params.pop("token", None) token = params.pop("token", None)
project_id = params.pop("project_id", None) # watsonx.ai project_id project_id = params.pop(
"project_id", params.pop("watsonx_project", None)
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
space_id = params.pop("space_id", None) # watsonx.ai deployment space_id space_id = params.pop("space_id", None) # watsonx.ai deployment space_id
region_name = params.pop("region_name", params.pop("region", None)) region_name = params.pop("region_name", params.pop("region", None))
wx_credentials = params.pop("wx_credentials", None) if region_name is None:
region_name = params.pop(
"watsonx_region_name", params.pop("watsonx_region", None)
) # consistent with how vertex ai + aws regions are accepted
wx_credentials = params.pop(
"wx_credentials",
params.pop(
"watsonx_credentials", None
), # follow {provider}_credentials, same as vertex ai
)
api_version = params.pop("api_version", IBMWatsonXAI.api_version) api_version = params.pop("api_version", IBMWatsonXAI.api_version)
# Load auth variables from environment variables # Load auth variables from environment variables
if url is None: if url is None:
url = ( url = (
get_secret("WATSONX_URL") get_secret("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret("WATSONX_URL")
or get_secret("WX_URL") or get_secret("WX_URL")
or get_secret("WML_URL") or get_secret("WML_URL")
) )
@@ -297,7 +310,12 @@ class IBMWatsonXAI(BaseLLM):
api_key = wx_credentials.get( api_key = wx_credentials.get(
"apikey", wx_credentials.get("api_key", api_key) "apikey", wx_credentials.get("api_key", api_key)
) )
token = wx_credentials.get("token", token) token = wx_credentials.get(
"token",
wx_credentials.get(
"watsonx_token", token
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
)
# verify that all required credentials are present # verify that all required credentials are present
if url is None: if url is None:
@@ -342,10 +360,10 @@
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
logging_obj, logging_obj,
optional_params: Optional[dict] = None, optional_params: dict,
litellm_params: Optional[dict] = None, litellm_params: Optional[dict] = None,
logger_fn=None, logger_fn=None,
timeout: float = None, timeout: Optional[float] = None,
): ):
""" """
Send a text generation request to the IBM Watsonx.ai API. Send a text generation request to the IBM Watsonx.ai API.
@@ -379,10 +397,14 @@
model_response["finish_reason"] = json_resp["results"][0]["stop_reason"] model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
model_response.usage = Usage( setattr(
prompt_tokens=prompt_tokens, model_response,
completion_tokens=completion_tokens, "usage",
total_tokens=prompt_tokens + completion_tokens, Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
) )
return model_response return model_response
@@ -525,7 +547,7 @@ class IBMWatsonXAI(BaseLLM):
logging_obj: Any, logging_obj: Any,
stream: bool = False, stream: bool = False,
input: Optional[Any] = None, input: Optional[Any] = None,
timeout: float = None, timeout: Optional[float] = None,
): ):
request_str = ( request_str = (
f"response = {request_params['method']}(\n" f"response = {request_params['method']}(\n"
@@ -535,14 +557,14 @@
) )
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
api_key=request_params['headers'].get("Authorization"), api_key=request_params["headers"].get("Authorization"),
additional_args={ additional_args={
"complete_input_dict": request_params['json'], "complete_input_dict": request_params["json"],
"request_str": request_str, "request_str": request_str,
}, },
) )
if timeout: if timeout:
request_params['timeout'] = timeout request_params["timeout"] = timeout
try: try:
if stream: if stream:
resp = requests.request( resp = requests.request(
@@ -560,10 +582,10 @@
if not stream: if not stream:
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key=request_params['headers'].get("Authorization"), api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()), original_response=json.dumps(resp.json()),
additional_args={ additional_args={
"status_code": resp.status_code, "status_code": resp.status_code,
"complete_input_dict": request_params['json'], "complete_input_dict": request_params["json"],
}, },
) )

View file

@@ -1872,7 +1872,7 @@ def completion(
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params, litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn, logger_fn=logger_fn,
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,