Complete 'requests' library removal (#7350)

* refactor: initial commit moving watsonx_text to base_llm_http_handler + clarifying new provider directory structure

* refactor(watsonx/completion/handler.py): move to using base llm http handler

removes 'requests' library usage (see the httpx sketch after this commit list)

* fix(watsonx_text/transformation.py): fix result transformation

migrates result transformation to transformation.py for use with the base llm http handler (see the transformation sketch after this commit list)

* fix(streaming_handler.py): migrate watsonx streaming to transformation.py

ensures streaming works with the base llm http handler (see the streaming example after this commit list)

* fix(streaming_handler.py): fix streaming linting errors and remove watsonx conditional logic

* fix(watsonx/): fix chat route after the completion route refactor

* refactor(watsonx/embed): refactor watsonx to use base llm http handler for embedding calls as well

* refactor(base.py): remove requests library usage from litellm

* build(pyproject.toml): remove requests library dependency

* fix: fix linting errors

* fix: fix linting errors

* fix(types/utils.py): fix validation errors for ModelResponseStream

* fix(replicate/handler.py): fix linting errors

* fix(litellm_logging.py): handle ModelResponseStream object

* fix(streaming_handler.py): fix ModelResponseStream args

* fix: remove unused imports

* test: fix test

* fix: fix test

* test: fix test

* test: fix tests

* test: fix test

* test: fix patch target

* test: fix test
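
A minimal sketch of the requests-to-httpx swap this series performs, assuming a plain JSON completion endpoint; the function name and timeout value are illustrative, not litellm's actual handler code:

import httpx

def post_completion(url: str, payload: dict, headers: dict, timeout: float = 600.0) -> dict:
    # httpx mirrors the requests API for sync calls and also provides
    # httpx.AsyncClient, which lets one shared handler back both the
    # completion and acompletion code paths.
    with httpx.Client(timeout=timeout) as client:
        response = client.post(url, json=payload, headers=headers)
        response.raise_for_status()
        return response.json()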
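
A rough sketch of the transformation.py shape, under the assumption that it maps OpenAI-style messages to watsonx's text-generation payload and back; the class name, method names, and signatures here are hypothetical, not the exact litellm interface:

from typing import Any, Dict, List

class WatsonXTextTransformation:  # hypothetical name
    def transform_request(self, model: str, messages: List[Dict[str, str]],
                          optional_params: Dict[str, Any]) -> Dict[str, Any]:
        # Flatten chat messages into the single prompt string the
        # watsonx.ai text-generation endpoint expects.
        prompt = "\n".join(m["content"] for m in messages)
        return {"model_id": model, "input": prompt, "parameters": optional_params}

    def transform_response(self, raw: Dict[str, Any]) -> Dict[str, Any]:
        # watsonx returns {"results": [{"generated_text": ...}]}; lift the
        # text into an OpenAI-style choice for the shared handler.
        text = raw["results"][0]["generated_text"]
        return {"choices": [{"message": {"role": "assistant", "content": text}}]}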
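
An illustrative streaming call after the migration (the model id is an example); chunks now flow through the shared http handler plus transformation.py rather than the old requests-based iterator:

import litellm

stream = litellm.completion(
    model="watsonx_text/ibm/granite-13b-instruct-v2",  # example model id
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    # Each chunk is an OpenAI-style ModelResponseStream with a delta payload.
    print(chunk.choices[0].delta.content or "", end="")
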
Krish Dholakia 2024-12-22 07:21:25 -08:00 committed by GitHub
parent 8b1ea40e7b
commit 3671829e39
39 changed files with 2147 additions and 2279 deletions


@@ -145,7 +145,7 @@ from .llms.vertex_ai.vertex_embeddings.embedding_handler import VertexEmbedding
from .llms.vertex_ai.vertex_model_garden.main import VertexAIModelGardenModels
from .llms.vllm.completion import handler as vllm_handler
from .llms.watsonx.chat.handler import WatsonXChatHandler
from .llms.watsonx.completion.handler import IBMWatsonXAI
from .llms.watsonx.common_utils import IBMWatsonXMixin
from .types.llms.openai import (
ChatCompletionAssistantMessage,
ChatCompletionAudioParam,
@@ -205,7 +205,6 @@ google_batch_embeddings = GoogleBatchEmbeddings()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_model_garden_chat_completion = VertexAIModelGardenModels()
vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
watsonx_chat_completion = WatsonXChatHandler()
openai_like_embedding = OpenAILikeEmbeddingHandler()
@@ -2585,43 +2584,68 @@ def completion( # type: ignore # noqa: PLR0915
custom_llm_provider="watsonx",
)
elif custom_llm_provider == "watsonx_text":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = watsonxai.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
timeout=timeout, # type: ignore
acompletion=acompletion,
api_key = (
api_key
or optional_params.pop("apikey", None)
or get_secret_str("WATSONX_APIKEY")
or get_secret_str("WATSONX_API_KEY")
or get_secret_str("WX_API_KEY")
)
if (
"stream" in optional_params
and optional_params["stream"] is True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
response = CustomStreamWrapper(
iter(response),
model,
custom_llm_provider="watsonx",
logging_obj=logging,
api_base = (
api_base
or optional_params.pop(
"url",
optional_params.pop(
"api_base", optional_params.pop("base_url", None)
),
)
or get_secret_str("WATSONX_API_BASE")
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
wx_credentials = optional_params.pop(
"wx_credentials",
optional_params.pop(
"watsonx_credentials", None
), # follow {provider}_credentials, same as vertex ai
)
token: Optional[str] = None
if wx_credentials is not None:
api_base = wx_credentials.get("url", api_base)
api_key = wx_credentials.get(
"apikey", wx_credentials.get("api_key", api_key)
)
token = wx_credentials.get(
"token",
wx_credentials.get(
"watsonx_token", None
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=None,
original_response=response,
)
## RESPONSE OBJECT
response = response
if token is not None:
optional_params["token"] = token
response = base_llm_http_handler.completion(
model=model,
stream=stream,
messages=messages,
acompletion=acompletion,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
litellm_params=litellm_params,
custom_llm_provider="watsonx_text",
timeout=timeout,
headers=headers,
encoding=encoding,
api_key=api_key,
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
client=client,
)
elif custom_llm_provider == "vllm":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = vllm_handler.completion(
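
For reference, the new watsonx_text branch above resolves credentials in this order: explicit arguments, keys popped from optional_params, then environment variables (WATSONX_APIKEY / WATSONX_API_KEY / WX_API_KEY for the key; WATSONX_API_BASE / WATSONX_URL / WX_URL / WML_URL for the base), with an optional wx_credentials dict overriding url, apikey, and token. A hedged usage sketch; the model id and credential values are placeholders:

import os

import litellm

# Option 1: environment variables (any of the names checked above works).
os.environ["WATSONX_API_KEY"] = "..."  # placeholder
os.environ["WATSONX_API_BASE"] = "https://us-south.ml.cloud.ibm.com"

response = litellm.completion(
    model="watsonx_text/ibm/granite-13b-instruct-v2",  # example model id
    messages=[{"role": "user", "content": "Hi"}],
)

# Option 2: a wx_credentials dict, following the {provider}_credentials
# pattern noted in the diff comments above.
response = litellm.completion(
    model="watsonx_text/ibm/granite-13b-instruct-v2",
    messages=[{"role": "user", "content": "Hi"}],
    wx_credentials={"url": "https://us-south.ml.cloud.ibm.com", "apikey": "..."},
)
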
@@ -3485,6 +3509,7 @@ def embedding( # noqa: PLR0915
optional_params=optional_params,
client=client,
aembedding=aembedding,
litellm_params={},
)
elif custom_llm_provider == "gemini":
gemini_api_key = (
@@ -3661,6 +3686,32 @@ def embedding( # noqa: PLR0915
optional_params=optional_params,
client=client,
aembedding=aembedding,
litellm_params={},
)
elif custom_llm_provider == "watsonx":
credentials = IBMWatsonXMixin.get_watsonx_credentials(
optional_params=optional_params, api_key=api_key, api_base=api_base
)
api_key = credentials["api_key"]
api_base = credentials["api_base"]
if "token" in credentials:
optional_params["token"] = credentials["token"]
response = base_llm_http_handler.embedding(
model=model,
input=input,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
litellm_params={},
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "xinference":
api_key = (
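
With this hunk, watsonx embedding calls also route through base_llm_http_handler, with credentials resolved by IBMWatsonXMixin.get_watsonx_credentials. An illustrative call, assuming credentials are set as above; the model id is an example:

import litellm

response = litellm.embedding(
    model="watsonx/ibm/slate-30m-english-rtrvr",  # example embedding model id
    input=["hello world"],
)
print(len(response.data[0]["embedding"]))
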
@@ -3687,17 +3738,6 @@ def embedding( # noqa: PLR0915
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "watsonx":
response = watsonxai.embedding(
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
aembedding=aembedding,
api_key=api_key,
)
elif custom_llm_provider == "azure_ai":
api_base = (
api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there