forked from phoenix/litellm-mirror

fix(watsonx.py): use common litellm params for api key, api base, etc.

parent a76e40df73
commit c9d7437d16

2 changed files with 49 additions and 27 deletions

litellm/llms/watsonx.py
@@ -13,7 +13,7 @@ from .prompt_templates import factory as ptf
 
 
 class WatsonXAIError(Exception):
-    def __init__(self, status_code, message, url: str = None):
+    def __init__(self, status_code, message, url: Optional[str] = None):
         self.status_code = status_code
         self.message = message
         url = url or "https://https://us-south.ml.cloud.ibm.com"
@@ -73,7 +73,6 @@ class IBMWatsonXAIConfig:
     repetition_penalty: Optional[float] = None
     truncate_input_tokens: Optional[int] = None
     include_stop_sequences: Optional[bool] = False
-    return_options: Optional[dict] = None
     return_options: Optional[Dict[str, bool]] = None
     random_seed: Optional[int] = None  # e.g 42
     moderations: Optional[dict] = None
@@ -161,6 +160,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     )
     return prompt
 
 
 class WatsonXAIEndpoint(str, Enum):
     TEXT_GENERATION = "/ml/v1/text/generation"
+    TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
@@ -171,6 +171,7 @@ class WatsonXAIEndpoint(str, Enum):
     EMBEDDINGS = "/ml/v1/text/embeddings"
+    PROMPTS = "/ml/v1/prompts"
 
 
 class IBMWatsonXAI(BaseLLM):
     """
     Class to interface with IBM Watsonx.ai API for text generation and embeddings.
@@ -180,7 +181,6 @@ class IBMWatsonXAI(BaseLLM):
 
     api_version = "2024-03-13"
 
-
     def __init__(self) -> None:
         super().__init__()
 
@@ -190,7 +190,7 @@ class IBMWatsonXAI(BaseLLM):
         prompt: str,
         stream: bool,
         optional_params: dict,
-        print_verbose: Callable = None,
+        print_verbose: Optional[Callable] = None,
     ) -> dict:
         """
         Get the request parameters for text generation.
@@ -224,9 +224,9 @@
             )
             deployment_id = "/".join(model_id.split("/")[1:])
             endpoint = (
-                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM
+                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
                 if stream
-                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION
+                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
             )
             endpoint = endpoint.format(deployment_id=deployment_id)
         else:
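Reviewer note: switching to `.value` makes the endpoint a plain `str` before `.format()` is called, side-stepping differences in how `str`-mixin Enum members stringify across Python versions. A minimal standalone illustration; only the member name appears in this hunk, so the path template and deployment id below are assumed placeholders:

from enum import Enum


class WatsonXAIEndpoint(str, Enum):
    # path template assumed for illustration; not shown in the diff
    DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"


# .value is unambiguously the raw string, independent of how a given
# interpreter renders str()/format() for str-mixin Enum members
endpoint = WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
print(endpoint.format(deployment_id="my-deployment"))
# -> /ml/v1/deployments/my-deployment/text/generation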
@@ -239,27 +239,40 @@
         )
         url = api_params["url"].rstrip("/") + endpoint
         return dict(
-            method="POST", url=url, headers=headers,
-            json=payload, params=request_params
+            method="POST", url=url, headers=headers, json=payload, params=request_params
         )
 
-    def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
+    def _get_api_params(
+        self, params: dict, print_verbose: Optional[Callable] = None
+    ) -> dict:
         """
         Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
         """
         # Load auth variables from params
-        url = params.pop("url", None)
+        url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
         api_key = params.pop("apikey", None)
         token = params.pop("token", None)
-        project_id = params.pop("project_id", None)  # watsonx.ai project_id
+        project_id = params.pop(
+            "project_id", params.pop("watsonx_project", None)
+        )  # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
         space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
         region_name = params.pop("region_name", params.pop("region", None))
-        wx_credentials = params.pop("wx_credentials", None)
+        if region_name is None:
+            region_name = params.pop(
+                "watsonx_region_name", params.pop("watsonx_region", None)
+            )  # consistent with how vertex ai + aws regions are accepted
+        wx_credentials = params.pop(
+            "wx_credentials",
+            params.pop(
+                "watsonx_credentials", None
+            ),  # follow {provider}_credentials, same as vertex ai
+        )
         api_version = params.pop("api_version", IBMWatsonXAI.api_version)
         # Load auth variables from environment variables
         if url is None:
             url = (
-                get_secret("WATSONX_URL")
+                get_secret("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
+                or get_secret("WATSONX_URL")
                 or get_secret("WX_URL")
                 or get_secret("WML_URL")
             )
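Reviewer note: the net effect of this hunk is an alias chain for the API base. A minimal standalone sketch of that resolution order, with `os.getenv` standing in for litellm's `get_secret` and a function name invented for illustration:

import os
from typing import Optional


def resolve_watsonx_api_base(params: dict) -> Optional[str]:
    # explicit params win, in order: url > api_base > base_url
    # (the inner pops run eagerly, so all three aliases are consumed)
    url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
    if url is None:
        # env var fallbacks, generic litellm-style name first
        url = (
            os.getenv("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
            or os.getenv("WATSONX_URL")
            or os.getenv("WX_URL")
            or os.getenv("WML_URL")
        )
    return url


print(resolve_watsonx_api_base({"api_base": "https://us-south.ml.cloud.ibm.com"}))
# -> https://us-south.ml.cloud.ibm.com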
@@ -297,7 +310,12 @@
             api_key = wx_credentials.get(
                 "apikey", wx_credentials.get("api_key", api_key)
             )
-            token = wx_credentials.get("token", token)
+            token = wx_credentials.get(
+                "token",
+                wx_credentials.get(
+                    "watsonx_token", token
+                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
+            )
 
         # verify that all required credentials are present
         if url is None:
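Reviewer note: with the nested `.get` fallbacks, a credentials dict may use either naming scheme. A hypothetical dict (key names from the diff, values placeholders) showing the lookup order:

wx_credentials = {"api_key": "<ibm-cloud-api-key>", "watsonx_token": None}

# mirrors the diff: provider-neutral key first, then the watsonx-prefixed
# alias, then any previously resolved value (None here)
api_key = wx_credentials.get("apikey", wx_credentials.get("api_key", None))
token = wx_credentials.get("token", wx_credentials.get("watsonx_token", None))
print(api_key, token)  # -> <ibm-cloud-api-key> None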
@@ -342,10 +360,10 @@
         print_verbose: Callable,
         encoding,
         logging_obj,
-        optional_params: Optional[dict] = None,
+        optional_params: dict,
         litellm_params: Optional[dict] = None,
         logger_fn=None,
-        timeout: float = None,
+        timeout: Optional[float] = None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -379,10 +397,14 @@
             model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
             model_response["created"] = int(time.time())
             model_response["model"] = model
-            model_response.usage = Usage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=prompt_tokens + completion_tokens,
-            )
+            setattr(
+                model_response,
+                "usage",
+                Usage(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                ),
+            )
             return model_response
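Reviewer note: `setattr(obj, "usage", ...)` is behaviorally identical to direct attribute assignment; the indirection mainly keeps static type checkers quiet when the response class does not declare a `usage` attribute. A runnable sketch with stand-in types, since the real `ModelResponse` and `Usage` live in litellm:

from dataclasses import dataclass


@dataclass
class Usage:  # stand-in for litellm's Usage
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ModelResponse:  # stand-in for litellm's ModelResponse
    pass


resp = ModelResponse()
setattr(resp, "usage", Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15))
print(resp.usage.total_tokens)  # -> 15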
@@ -525,7 +547,7 @@
         logging_obj: Any,
         stream: bool = False,
         input: Optional[Any] = None,
-        timeout: float = None,
+        timeout: Optional[float] = None,
     ):
         request_str = (
             f"response = {request_params['method']}(\n"
@@ -535,14 +557,14 @@
         )
         logging_obj.pre_call(
             input=input,
-            api_key=request_params['headers'].get("Authorization"),
+            api_key=request_params["headers"].get("Authorization"),
             additional_args={
-                "complete_input_dict": request_params['json'],
+                "complete_input_dict": request_params["json"],
                 "request_str": request_str,
             },
         )
         if timeout:
-            request_params['timeout'] = timeout
+            request_params["timeout"] = timeout
         try:
             if stream:
                 resp = requests.request(
@@ -560,10 +582,10 @@
             if not stream:
                 logging_obj.post_call(
                     input=input,
-                    api_key=request_params['headers'].get("Authorization"),
+                    api_key=request_params["headers"].get("Authorization"),
                     original_response=json.dumps(resp.json()),
                     additional_args={
                         "status_code": resp.status_code,
-                        "complete_input_dict": request_params['json'],
+                        "complete_input_dict": request_params["json"],
                     },
                 )
 
litellm/main.py

@@ -1872,7 +1872,7 @@ def completion(
             model_response=model_response,
             print_verbose=print_verbose,
             optional_params=optional_params,
-            litellm_params=litellm_params,
+            litellm_params=litellm_params,  # type: ignore
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging,
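Reviewer note: taken together, the commit lets watsonx calls use the same generic params as other litellm providers. A hedged end-to-end sketch; the model id and credentials are placeholders, and exact kwarg pass-through depends on the litellm version:

import litellm

# `api_base` now works as it does for azure et al.; previously the
# provider-specific `url` param / WATSONX_URL env var was required
response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Hello"}],
    api_base="https://us-south.ml.cloud.ibm.com",
    project_id="<watsonx-project-id>",
)
print(response.choices[0].message.content)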