Mirror of https://github.com/BerriAI/litellm.git
fix(watsonx.py): use common litellm params for api key, api base, etc.
parent a76e40df73
commit c9d7437d16

2 changed files with 49 additions and 27 deletions
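In practical terms, watsonx calls can now be configured with the same generic params and env vars litellm uses elsewhere (`api_base`/`WATSONX_API_BASE` like `AZURE_API_BASE`, `watsonx_project` like vertex's project param, `watsonx_region_name`/`watsonx_region` like vertex ai and aws regions) instead of watsonx-only names such as `url` and `apikey`. A minimal sketch, assuming a litellm version containing this commit; the model id, project id, and API-key env var below are illustrative assumptions, not values taken from the diff:

import os
import litellm

# Assumed setup: WATSONX_API_BASE is the env var added by this commit;
# the API-key variable name follows litellm's watsonx docs and is an
# assumption here, not part of this diff.
os.environ["WATSONX_APIKEY"] = "<your-ibm-cloud-api-key>"
os.environ["WATSONX_API_BASE"] = "https://us-south.ml.cloud.ibm.com"

response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # hypothetical model id
    messages=[{"role": "user", "content": "Hello"}],
    project_id="<your-project-id>",  # or watsonx_project=..., per this commit
)
print(response.choices[0].message.content)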
litellm/llms/watsonx.py

@@ -13,7 +13,7 @@ from .prompt_templates import factory as ptf


 class WatsonXAIError(Exception):
-    def __init__(self, status_code, message, url: str = None):
+    def __init__(self, status_code, message, url: Optional[str] = None):
         self.status_code = status_code
         self.message = message
         url = url or "https://https://us-south.ml.cloud.ibm.com"
@@ -73,7 +73,6 @@ class IBMWatsonXAIConfig:
     repetition_penalty: Optional[float] = None
     truncate_input_tokens: Optional[int] = None
     include_stop_sequences: Optional[bool] = False
-    return_options: Optional[dict] = None
     return_options: Optional[Dict[str, bool]] = None
     random_seed: Optional[int] = None  # e.g 42
     moderations: Optional[dict] = None
@@ -161,6 +160,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     )
     return prompt

+
 class WatsonXAIEndpoint(str, Enum):
     TEXT_GENERATION = "/ml/v1/text/generation"
     TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
@@ -171,6 +171,7 @@ class WatsonXAIEndpoint(str, Enum):
     EMBEDDINGS = "/ml/v1/text/embeddings"
     PROMPTS = "/ml/v1/prompts"

+
 class IBMWatsonXAI(BaseLLM):
     """
     Class to interface with IBM Watsonx.ai API for text generation and embeddings.
@@ -180,7 +181,6 @@ class IBMWatsonXAI(BaseLLM):

     api_version = "2024-03-13"

-
     def __init__(self) -> None:
         super().__init__()

@@ -190,7 +190,7 @@ class IBMWatsonXAI(BaseLLM):
         prompt: str,
         stream: bool,
         optional_params: dict,
-        print_verbose: Callable = None,
+        print_verbose: Optional[Callable] = None,
     ) -> dict:
         """
         Get the request parameters for text generation.
@@ -224,9 +224,9 @@ class IBMWatsonXAI(BaseLLM):
             )
             deployment_id = "/".join(model_id.split("/")[1:])
             endpoint = (
-                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM
+                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
                 if stream
-                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION
+                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
             )
             endpoint = endpoint.format(deployment_id=deployment_id)
         else:
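The `.value` additions in the hunk above matter because `WatsonXAIEndpoint` mixes `str` into `Enum`: string methods on the member mostly work, but `str()` and `format()` of the member can produce the member name instead of the path, and that behavior changed across Python versions. A small standalone illustration (the class and member names here are made up):

from enum import Enum

class E(str, Enum):
    X = "/ml/v1/text/generation"

print(E.X.value)  # '/ml/v1/text/generation' - always the plain string
print(str(E.X))   # 'E.X' - Enum.__str__ wins over str.__str__ here
# f"{E.X}" yields the path on Python < 3.11 but 'E.X' on 3.11+,
# so .value is the version-stable way to get the endpoint string.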
@@ -239,27 +239,40 @@ class IBMWatsonXAI(BaseLLM):
             )
         url = api_params["url"].rstrip("/") + endpoint
         return dict(
-            method="POST", url=url, headers=headers,
-            json=payload, params=request_params
+            method="POST", url=url, headers=headers, json=payload, params=request_params
         )

-    def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
+    def _get_api_params(
+        self, params: dict, print_verbose: Optional[Callable] = None
+    ) -> dict:
         """
         Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
         """
         # Load auth variables from params
-        url = params.pop("url", None)
+        url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
         api_key = params.pop("apikey", None)
         token = params.pop("token", None)
-        project_id = params.pop("project_id", None)  # watsonx.ai project_id
+        project_id = params.pop(
+            "project_id", params.pop("watsonx_project", None)
+        )  # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
         space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
         region_name = params.pop("region_name", params.pop("region", None))
-        wx_credentials = params.pop("wx_credentials", None)
+        if region_name is None:
+            region_name = params.pop(
+                "watsonx_region_name", params.pop("watsonx_region", None)
+            )  # consistent with how vertex ai + aws regions are accepted
+        wx_credentials = params.pop(
+            "wx_credentials",
+            params.pop(
+                "watsonx_credentials", None
+            ),  # follow {provider}_credentials, same as vertex ai
+        )
         api_version = params.pop("api_version", IBMWatsonXAI.api_version)
         # Load auth variables from environment variables
         if url is None:
             url = (
-                get_secret("WATSONX_URL")
+                get_secret("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
+                or get_secret("WATSONX_URL")
                 or get_secret("WX_URL")
                 or get_secret("WML_URL")
             )
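The nested `params.pop(...)` chains above implement alias fallback: the first name found wins, and every alias is removed from the dict. Note that Python evaluates the inner pops eagerly, so alias keys are consumed even when an outer key is present. A standalone sketch of the pattern (the dict values are illustrative):

# Alias-fallback pop, as used above: prefer "url", then "api_base", then "base_url".
params = {"api_base": "https://us-south.ml.cloud.ibm.com", "base_url": "unused"}
url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
print(url)     # https://us-south.ml.cloud.ibm.com
print(params)  # {} - inner pops run eagerly, so losing aliases are consumed too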
@@ -297,7 +310,12 @@ class IBMWatsonXAI(BaseLLM):
             api_key = wx_credentials.get(
                 "apikey", wx_credentials.get("api_key", api_key)
             )
-            token = wx_credentials.get("token", token)
+            token = wx_credentials.get(
+                "token",
+                wx_credentials.get(
+                    "watsonx_token", token
+                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
+            )

         # verify that all required credentials are present
         if url is None:
@@ -342,10 +360,10 @@ class IBMWatsonXAI(BaseLLM):
         print_verbose: Callable,
         encoding,
         logging_obj,
-        optional_params: Optional[dict] = None,
+        optional_params: dict,
         litellm_params: Optional[dict] = None,
         logger_fn=None,
-        timeout: float = None,
+        timeout: Optional[float] = None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -379,10 +397,14 @@ class IBMWatsonXAI(BaseLLM):
             model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
             model_response["created"] = int(time.time())
             model_response["model"] = model
-            model_response.usage = Usage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=prompt_tokens + completion_tokens,
+            setattr(
+                model_response,
+                "usage",
+                Usage(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                ),
             )
             return model_response

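On the `setattr` hunk above: at runtime `setattr(obj, "usage", value)` is identical to `obj.usage = value`, so the indirection is presumably there to quiet static type checkers (the same commit adds a `# type: ignore` in main.py, below). A trivial demonstration; the class name is a stand-in, not litellm's:

class Box:  # stand-in for litellm's ModelResponse, just for illustration
    pass

resp = Box()
setattr(resp, "usage", {"total_tokens": 3})  # same runtime effect as resp.usage = {...}
print(resp.usage)  # {'total_tokens': 3}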
@@ -525,7 +547,7 @@ class IBMWatsonXAI(BaseLLM):
         logging_obj: Any,
         stream: bool = False,
         input: Optional[Any] = None,
-        timeout: float = None,
+        timeout: Optional[float] = None,
     ):
         request_str = (
             f"response = {request_params['method']}(\n"
@@ -535,14 +557,14 @@ class IBMWatsonXAI(BaseLLM):
         )
         logging_obj.pre_call(
             input=input,
-            api_key=request_params['headers'].get("Authorization"),
+            api_key=request_params["headers"].get("Authorization"),
             additional_args={
-                "complete_input_dict": request_params['json'],
+                "complete_input_dict": request_params["json"],
                 "request_str": request_str,
             },
         )
         if timeout:
-            request_params['timeout'] = timeout
+            request_params["timeout"] = timeout
         try:
             if stream:
                 resp = requests.request(
@@ -560,10 +582,10 @@ class IBMWatsonXAI(BaseLLM):
             if not stream:
                 logging_obj.post_call(
                     input=input,
-                    api_key=request_params['headers'].get("Authorization"),
+                    api_key=request_params["headers"].get("Authorization"),
                     original_response=json.dumps(resp.json()),
                     additional_args={
                         "status_code": resp.status_code,
-                        "complete_input_dict": request_params['json'],
+                        "complete_input_dict": request_params["json"],
                     },
                 )
litellm/main.py

@@ -1872,7 +1872,7 @@ def completion(
             model_response=model_response,
             print_verbose=print_verbose,
             optional_params=optional_params,
-            litellm_params=litellm_params,
+            litellm_params=litellm_params,  # type: ignore
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging,