fix(watsonx.py): use common litellm params for api key, api base, etc.

Krrish Dholakia 2024-04-27 10:15:27 -07:00
parent a76e40df73
commit c9d7437d16
2 changed files with 49 additions and 27 deletions
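The change replaces watsonx-specific credential plumbing with the common parameter names litellm uses across providers (api_base, project/region aliases, an AZURE_API_BASE-style env var). A minimal sketch of the call style this enables, assuming litellm's usual kwarg pass-through to the provider; the model id and all credential values below are placeholders, not from this commit:

    import os
    import litellm

    # WATSONX_API_BASE is the env var this commit adds, mirroring AZURE_API_BASE.
    os.environ["WATSONX_API_BASE"] = "https://us-south.ml.cloud.ibm.com"

    response = litellm.completion(
        model="watsonx/ibm/granite-13b-chat-v2",  # hypothetical model id
        messages=[{"role": "user", "content": "Hello"}],
        project_id="my-project-id",      # read via params.pop("project_id", ...) below
        apikey="my-ibm-cloud-api-key",   # read via params.pop("apikey", ...) below
    )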

View file

@@ -13,7 +13,7 @@ from .prompt_templates import factory as ptf
 class WatsonXAIError(Exception):
-    def __init__(self, status_code, message, url: str = None):
+    def __init__(self, status_code, message, url: Optional[str] = None):
         self.status_code = status_code
         self.message = message
         url = url or "https://https://us-south.ml.cloud.ibm.com"
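The `url: str = None` to `Optional[str] = None` change (repeated in several hunks below) addresses PEP 484's implicit-Optional rule: a `None` default does not make an annotation Optional on its own, and mypy rejects the old spelling by default since 0.990. A small illustration with made-up names:

    from typing import Optional

    def old_style(url: str = None):  # flagged by strict type checkers
        ...

    def new_style(url: Optional[str] = None):  # explicit, accepted
        ...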
@@ -73,7 +73,6 @@ class IBMWatsonXAIConfig:
     repetition_penalty: Optional[float] = None
     truncate_input_tokens: Optional[int] = None
     include_stop_sequences: Optional[bool] = False
-    return_options: Optional[dict] = None
-
+    return_options: Optional[Dict[str, bool]] = None
     random_seed: Optional[int] = None  # e.g 42
     moderations: Optional[dict] = None
@@ -161,6 +160,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     )
     return prompt
 
+
 class WatsonXAIEndpoint(str, Enum):
     TEXT_GENERATION = "/ml/v1/text/generation"
     TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
@@ -171,6 +171,7 @@ class WatsonXAIEndpoint(str, Enum):
     EMBEDDINGS = "/ml/v1/text/embeddings"
     PROMPTS = "/ml/v1/prompts"
 
+
 class IBMWatsonXAI(BaseLLM):
     """
     Class to interface with IBM Watsonx.ai API for text generation and embeddings.
@@ -180,7 +181,6 @@ class IBMWatsonXAI(BaseLLM):
     api_version = "2024-03-13"
 
-
     def __init__(self) -> None:
         super().__init__()
@@ -190,7 +190,7 @@ class IBMWatsonXAI(BaseLLM):
         prompt: str,
         stream: bool,
         optional_params: dict,
-        print_verbose: Callable = None,
+        print_verbose: Optional[Callable] = None,
     ) -> dict:
         """
         Get the request parameters for text generation.
@@ -224,9 +224,9 @@
             )
             deployment_id = "/".join(model_id.split("/")[1:])
             endpoint = (
-                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM
+                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
                 if stream
-                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION
+                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
             )
             endpoint = endpoint.format(deployment_id=deployment_id)
         else:
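Appending `.value` makes the endpoint a plain `str` rather than an enum member. Because `WatsonXAIEndpoint` mixes in `str`, both work with `.format()`, but the member's `str()` form is the enum label, so `.value` is the unambiguous raw string. A standalone illustration; the path template is a stand-in, not copied from the source:

    from enum import Enum

    class Endpoint(str, Enum):
        DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"

    print(str(Endpoint.DEPLOYMENT_TEXT_GENERATION))   # the enum label, e.g. 'Endpoint.DEPLOYMENT_TEXT_GENERATION'
    print(Endpoint.DEPLOYMENT_TEXT_GENERATION.value)  # the raw path template
    print(Endpoint.DEPLOYMENT_TEXT_GENERATION.value.format(deployment_id="abc123"))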
@@ -239,27 +239,40 @@
             )
         url = api_params["url"].rstrip("/") + endpoint
         return dict(
-            method="POST", url=url, headers=headers,
-            json=payload, params=request_params
+            method="POST", url=url, headers=headers, json=payload, params=request_params
         )
 
-    def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
+    def _get_api_params(
+        self, params: dict, print_verbose: Optional[Callable] = None
+    ) -> dict:
         """
         Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
         """
         # Load auth variables from params
-        url = params.pop("url", None)
+        url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
         api_key = params.pop("apikey", None)
         token = params.pop("token", None)
-        project_id = params.pop("project_id", None)  # watsonx.ai project_id
+        project_id = params.pop(
+            "project_id", params.pop("watsonx_project", None)
+        )  # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
         space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
         region_name = params.pop("region_name", params.pop("region", None))
-        wx_credentials = params.pop("wx_credentials", None)
+        if region_name is None:
+            region_name = params.pop(
+                "watsonx_region_name", params.pop("watsonx_region", None)
+            )  # consistent with how vertex ai + aws regions are accepted
+        wx_credentials = params.pop(
+            "wx_credentials",
+            params.pop(
+                "watsonx_credentials", None
+            ),  # follow {provider}_credentials, same as vertex ai
+        )
         api_version = params.pop("api_version", IBMWatsonXAI.api_version)
         # Load auth variables from environment variables
         if url is None:
             url = (
-                get_secret("WATSONX_URL")
+                get_secret("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
+                or get_secret("WATSONX_URL")
                 or get_secret("WX_URL")
                 or get_secret("WML_URL")
             )
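The nested `pop` chain is the pattern doing the aliasing work here: the innermost pop is evaluated first (as the default for the next one), so every alias is removed from `params`, while the leftmost name takes precedence. A toy demonstration, not from the commit:

    # Alias resolution: 'url' wins if present, then 'api_base', then 'base_url'.
    params = {"base_url": "https://a.example", "api_base": "https://b.example"}
    url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
    print(url)     # https://b.example
    print(params)  # {} -- every alias is consumed, so none leak downstream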
@@ -297,7 +310,12 @@
                 api_key = wx_credentials.get(
                     "apikey", wx_credentials.get("api_key", api_key)
                 )
-                token = wx_credentials.get("token", token)
+                token = wx_credentials.get(
+                    "token",
+                    wx_credentials.get(
+                        "watsonx_token", token
+                    ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
+                )
 
         # verify that all required credentials are present
         if url is None:
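For reference, the credentials-dict shape the lookup above accepts; the key names come from the `get` calls in this hunk, and the values are placeholders:

    # Either spelling works for each field after this change:
    wx_credentials = {
        "apikey": "my-api-key",      # or "api_key"
        "token": "my-bearer-token",  # or "watsonx_token"
    }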
@@ -342,10 +360,10 @@
         print_verbose: Callable,
         encoding,
         logging_obj,
-        optional_params: Optional[dict] = None,
+        optional_params: dict,
         litellm_params: Optional[dict] = None,
         logger_fn=None,
-        timeout: float = None,
+        timeout: Optional[float] = None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -379,10 +397,14 @@
             model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
             model_response["created"] = int(time.time())
             model_response["model"] = model
-            model_response.usage = Usage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=prompt_tokens + completion_tokens,
-            )
+            setattr(
+                model_response,
+                "usage",
+                Usage(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                ),
+            )
 
             return model_response
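Switching from attribute assignment to `setattr` is runtime-equivalent; the usual reason for this pattern is to sidestep a static type-check error when the attribute is not declared on the class, since `setattr` is opaque to the checker. A minimal illustration; the class here is a stand-in, not litellm's ModelResponse:

    class Response:
        model: str = ""

    r = Response()
    # r.usage = {"total_tokens": 42}          # mypy: "Response" has no attribute "usage"
    setattr(r, "usage", {"total_tokens": 42})  # same runtime effect, no static error
    print(getattr(r, "usage"))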
@@ -525,7 +547,7 @@
         logging_obj: Any,
         stream: bool = False,
         input: Optional[Any] = None,
-        timeout: float = None,
+        timeout: Optional[float] = None,
     ):
         request_str = (
             f"response = {request_params['method']}(\n"
@@ -535,14 +557,14 @@
         )
         logging_obj.pre_call(
             input=input,
-            api_key=request_params['headers'].get("Authorization"),
+            api_key=request_params["headers"].get("Authorization"),
             additional_args={
-                "complete_input_dict": request_params['json'],
+                "complete_input_dict": request_params["json"],
                 "request_str": request_str,
             },
         )
         if timeout:
-            request_params['timeout'] = timeout
+            request_params["timeout"] = timeout
         try:
             if stream:
                 resp = requests.request(
@@ -560,10 +582,10 @@
         if not stream:
             logging_obj.post_call(
                 input=input,
-                api_key=request_params['headers'].get("Authorization"),
+                api_key=request_params["headers"].get("Authorization"),
                 original_response=json.dumps(resp.json()),
                 additional_args={
                     "status_code": resp.status_code,
-                    "complete_input_dict": request_params['json'],
+                    "complete_input_dict": request_params["json"],
                 },
             )

View file

@@ -1872,7 +1872,7 @@ def completion(
             model_response=model_response,
             print_verbose=print_verbose,
             optional_params=optional_params,
-            litellm_params=litellm_params,
+            litellm_params=litellm_params,  # type: ignore
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging,