Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
Complete 'requests' library removal (#7350)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s
* refactor: initial commit moving watsonx_text to base_llm_http_handler + clarifying new provider directory structure
* refactor(watsonx/completion/handler.py): move to using base llm http handler; removes 'requests' library usage
* fix(watsonx_text/transformation.py): fix result transformation; migrates to transformation.py for usage with base llm http handler
* fix(streaming_handler.py): migrate watsonx streaming to transformation.py; ensures streaming works with base llm http handler
* fix(streaming_handler.py): fix streaming linting errors and remove watsonx conditional logic
* fix(watsonx/): fix chat route post completion route refactor
* refactor(watsonx/embed): refactor watsonx to use base llm http handler for embedding calls as well
* refactor(base.py): remove requests library usage from litellm
* build(pyproject.toml): remove requests library usage
* fix: fix linting errors
* fix: fix linting errors
* fix(types/utils.py): fix validation errors for modelresponsestream
* fix(replicate/handler.py): fix linting errors
* fix(litellm_logging.py): handle modelresponsestream object
* fix(streaming_handler.py): fix modelresponsestream args
* fix: remove unused imports
* test: fix test
* fix: fix test
* test: fix test
* test: fix tests
* test: fix test
* test: fix patch target
* test: fix test
This commit is contained in:
parent
8b1ea40e7b
commit
3671829e39
39 changed files with 2147 additions and 2279 deletions
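The commit message above describes moving the watsonx handlers onto the shared base HTTP handler and dropping the 'requests' dependency. As a rough illustration of that swap (not the commit's actual handler code), here is a minimal sketch contrasting a blocking 'requests' call with its httpx equivalent; the endpoint, payload, and token are hypothetical placeholders.

import httpx

# Illustrative endpoint, payload, and token only -- not taken from this commit.
api_base = "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation?version=2024-03-13"
payload = {"model_id": "ibm/granite-13b-chat-v2", "input": "Hello"}
headers = {"Authorization": "Bearer <iam-token>", "Content-Type": "application/json"}

# Before: blocking call through the 'requests' library (now removed from litellm).
#   import requests
#   resp = requests.post(api_base, json=payload, headers=headers, timeout=600)

# After: the equivalent call through httpx, which also provides httpx.AsyncClient,
# so a single base HTTP handler can cover both sync and async code paths.
with httpx.Client(timeout=600.0) as client:
    resp = client.post(api_base, json=payload, headers=headers)
    resp.raise_for_status()
    print(resp.json())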
@@ -1,13 +1,15 @@
from typing import Callable, Dict, Optional, Union, cast
from typing import Dict, List, Optional, Union, cast

import httpx

import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAPIParams
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials


class WatsonXAIError(BaseLLMException):
@@ -65,18 +67,20 @@ def generate_iam_token(api_key=None, **params) -> str:
    return cast(str, result)


def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str:
    if token is not None:
        return token
    token = generate_iam_token(api_key)
    return token


def _get_api_params(
    params: dict,
    print_verbose: Optional[Callable] = None,
    generate_token: Optional[bool] = True,
) -> WatsonXAPIParams:
    """
    Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
    """
    # Load auth variables from params
    url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
    api_key = params.pop("apikey", None)
    token = params.pop("token", None)
    project_id = params.pop(
        "project_id", params.pop("watsonx_project", None)
    )  # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
@@ -86,29 +90,8 @@ def _get_api_params(
    region_name = params.pop(
        "watsonx_region_name", params.pop("watsonx_region", None)
    )  # consistent with how vertex ai + aws regions are accepted
    wx_credentials = params.pop(
        "wx_credentials",
        params.pop(
            "watsonx_credentials", None
        ),  # follow {provider}_credentials, same as vertex ai
    )
    api_version = params.pop("api_version", litellm.WATSONX_DEFAULT_API_VERSION)

    # Load auth variables from environment variables
    if url is None:
        url = (
            get_secret_str("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
            or get_secret_str("WATSONX_URL")
            or get_secret_str("WX_URL")
            or get_secret_str("WML_URL")
        )
    if api_key is None:
        api_key = (
            get_secret_str("WATSONX_APIKEY")
            or get_secret_str("WATSONX_API_KEY")
            or get_secret_str("WX_API_KEY")
        )
    if token is None:
        token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
    if project_id is None:
        project_id = (
            get_secret_str("WATSONX_PROJECT_ID")
@@ -129,34 +112,6 @@ def _get_api_params(
            or get_secret_str("SPACE_ID")
        )

    # credentials parsing
    if wx_credentials is not None:
        url = wx_credentials.get("url", url)
        api_key = wx_credentials.get("apikey", wx_credentials.get("api_key", api_key))
        token = wx_credentials.get(
            "token",
            wx_credentials.get(
                "watsonx_token", token
            ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
        )

    # verify that all required credentials are present
    if url is None:
        raise WatsonXAIError(
            status_code=401,
            message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
        )

    if token is None and api_key is not None and generate_token:
        # generate the auth token
        if print_verbose is not None:
            print_verbose("Generating IAM token for Watsonx.ai")
        token = generate_iam_token(api_key)
    elif token is None and api_key is None:
        raise WatsonXAIError(
            status_code=401,
            message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
        )
    if project_id is None:
        raise WatsonXAIError(
            status_code=401,
@@ -164,11 +119,147 @@ def _get_api_params(
        )

    return WatsonXAPIParams(
        url=url,
        api_key=api_key,
        token=cast(str, token),
        project_id=project_id,
        space_id=space_id,
        region_name=region_name,
        api_version=api_version,
    )


def convert_watsonx_messages_to_prompt(
    model: str,
    messages: List[AllMessageValues],
    provider: str,
    custom_prompt_dict: Dict,
) -> str:
    # handle anthropic prompts and amazon titan prompts
    if model in custom_prompt_dict:
        # check if the model has a registered custom prompt
        model_prompt_dict = custom_prompt_dict[model]
        prompt = ptf.custom_prompt(
            messages=messages,
            role_dict=model_prompt_dict.get(
                "role_dict", model_prompt_dict.get("roles")
            ),
            initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
            final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
            bos_token=model_prompt_dict.get("bos_token", ""),
            eos_token=model_prompt_dict.get("eos_token", ""),
        )
        return prompt
    elif provider == "ibm-mistralai":
        prompt = ptf.mistral_instruct_pt(messages=messages)
    else:
        prompt: str = ptf.prompt_factory(  # type: ignore
            model=model, messages=messages, custom_llm_provider="watsonx"
        )
    return prompt


# Mixin class for shared IBM Watson X functionality
class IBMWatsonXMixin:
    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        api_key: Optional[str] = None,
    ) -> Dict:
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        token = cast(Optional[str], optional_params.get("token"))
        if token:
            headers["Authorization"] = f"Bearer {token}"
        else:
            token = _generate_watsonx_token(api_key=api_key, token=token)
            # build auth headers
            headers["Authorization"] = f"Bearer {token}"
        return headers

    def _get_base_url(self, api_base: Optional[str]) -> str:
        url = (
            api_base
            or get_secret_str("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
            or get_secret_str("WATSONX_URL")
            or get_secret_str("WX_URL")
            or get_secret_str("WML_URL")
        )

        if url is None:
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
            )
        return url

    def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str:
        api_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION
        url = url + f"?version={api_version}"

        return url

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
        return WatsonXAIError(
            status_code=status_code, message=error_message, headers=headers
        )

    @staticmethod
    def get_watsonx_credentials(
        optional_params: dict, api_key: Optional[str], api_base: Optional[str]
    ) -> WatsonXCredentials:
        api_key = (
            api_key
            or optional_params.pop("apikey", None)
            or get_secret_str("WATSONX_APIKEY")
            or get_secret_str("WATSONX_API_KEY")
            or get_secret_str("WX_API_KEY")
        )

        api_base = (
            api_base
            or optional_params.pop(
                "url",
                optional_params.pop("api_base", optional_params.pop("base_url", None)),
            )
            or get_secret_str("WATSONX_API_BASE")
            or get_secret_str("WATSONX_URL")
            or get_secret_str("WX_URL")
            or get_secret_str("WML_URL")
        )

        wx_credentials = optional_params.pop(
            "wx_credentials",
            optional_params.pop(
                "watsonx_credentials", None
            ),  # follow {provider}_credentials, same as vertex ai
        )

        token: Optional[str] = None
        if wx_credentials is not None:
            api_base = wx_credentials.get("url", api_base)
            api_key = wx_credentials.get(
                "apikey", wx_credentials.get("api_key", api_key)
            )
            token = wx_credentials.get(
                "token",
                wx_credentials.get(
                    "watsonx_token", None
                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
            )
        if api_key is None or not isinstance(api_key, str):
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.",
            )
        if api_base is None or not isinstance(api_base, str):
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
            )
        return WatsonXCredentials(
            api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
        )
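As a usage illustration of the mixin introduced above, the sketch below calls IBMWatsonXMixin.get_watsonx_credentials the way the diff defines it: explicit api_key/api_base win, then values popped from optional_params, then the WATSONX_* environment variables. The import path is an assumption (the changed file's path is not visible in this excerpt); the signature matches the staticmethod shown above.

import os

# Assumed module path for the mixin -- the actual file path is not shown in this excerpt.
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin

# Environment-variable fallbacks, read via get_secret_str in the code above.
os.environ["WATSONX_API_KEY"] = "dummy-api-key"
os.environ["WATSONX_API_BASE"] = "https://us-south.ml.cloud.ibm.com"

# Passing api_key=None / api_base=None exercises the optional_params and
# environment-variable fallbacks; a 'wx_credentials' dict would override both.
creds = IBMWatsonXMixin.get_watsonx_credentials(
    optional_params={},
    api_key=None,
    api_base=None,
)
print(creds)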