diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py
index aa0cb32df..08d7933d9 100644
--- a/litellm/llms/watsonx.py
+++ b/litellm/llms/watsonx.py
@@ -1,12 +1,13 @@
 from enum import Enum
 import json, types, time  # noqa: E401
-from contextlib import contextmanager
-from typing import Callable, Dict, Optional, Any, Union, List
+from contextlib import asynccontextmanager, contextmanager
+from typing import AsyncGenerator, Callable, Dict, Generator, Optional, Any, Union, List
 
 import httpx
 import requests
 import litellm
-from litellm.utils import ModelResponse, get_secret, Usage
+from litellm.utils import Logging, ModelResponse, Usage, get_secret
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 
 from .base import BaseLLM
 from .prompt_templates import factory as ptf
@@ -173,14 +174,13 @@ class WatsonXAIEndpoint(str, Enum):
 
 class IBMWatsonXAI(BaseLLM):
     """
-    Class to interface with IBM Watsonx.ai API for text generation and embeddings.
+    Class to interface with IBM watsonx.ai API for text generation and embeddings.
 
     Reference: https://cloud.ibm.com/apidocs/watsonx-ai
     """
 
     api_version = "2024-03-13"
 
-
     def __init__(self) -> None:
         super().__init__()
@@ -239,8 +239,7 @@ class IBMWatsonXAI(BaseLLM):
         )
         url = api_params["url"].rstrip("/") + endpoint
         return dict(
-            method="POST", url=url, headers=headers,
-            json=payload, params=request_params
+            method="POST", url=url, headers=headers, json=payload, params=request_params
         )
 
     def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
@@ -307,7 +306,7 @@ class IBMWatsonXAI(BaseLLM):
             )
         if token is None and api_key is not None:
             # generate the auth token
-            if print_verbose:
+            if print_verbose is not None:
                 print_verbose("Generating IAM token for Watsonx.ai")
             token = self.generate_iam_token(api_key)
         elif token is None and api_key is None:
@@ -341,8 +340,9 @@ class IBMWatsonXAI(BaseLLM):
         model_response: ModelResponse,
         print_verbose: Callable,
         encoding,
-        logging_obj,
+        logging_obj: Logging,
         optional_params: Optional[dict] = None,
+        acompletion: Optional[bool] = None,
         litellm_params: Optional[dict] = None,
         logger_fn=None,
         timeout: float = None,
@@ -365,13 +365,15 @@ class IBMWatsonXAI(BaseLLM):
             prompt = convert_messages_to_prompt(
                 model, messages, provider, custom_prompt_dict
             )
+
+        manage_response = self._make_response_manager(async_=(acompletion is True), logging_obj=logging_obj)
 
-        def process_text_request(request_params: dict) -> ModelResponse:
-            with self._manage_response(
-                request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
-            ) as resp:
-                json_resp = resp.json()
-
+        def process_text_gen_response(json_resp: dict) -> ModelResponse:
+            if "results" not in json_resp:
+                raise WatsonXAIError(
+                    status_code=500,
+                    message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
+                )
             generated_text = json_resp["results"][0]["generated_text"]
             prompt_tokens = json_resp["results"][0]["input_token_count"]
             completion_tokens = json_resp["results"][0]["generated_token_count"]
@@ -386,25 +388,52 @@ class IBMWatsonXAI(BaseLLM):
             )
             return model_response
 
-        def process_stream_request(
+        def handle_text_request(request_params: dict) -> ModelResponse:
+            with manage_response(
+                request_params, input=prompt, timeout=timeout,
+            ) as resp:
+                json_resp = resp.json()
+
+            return process_text_gen_response(json_resp)
+
+        async def handle_text_request_async(request_params: dict) -> ModelResponse:
+            async with manage_response(
+                request_params, input=prompt, timeout=timeout,
+            ) as resp:
+                json_resp = resp.json()
+            return process_text_gen_response(json_resp)
+
+        def handle_stream_request(
             request_params: dict,
         ) -> litellm.CustomStreamWrapper:
             # stream the response - generated chunks will be handled
             # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            with self._manage_response(
-                request_params,
-                logging_obj=logging_obj,
-                stream=True,
-                input=prompt,
-                timeout=timeout,
+            with manage_response(
+                request_params, stream=True, input=prompt, timeout=timeout,
             ) as resp:
-                response = litellm.CustomStreamWrapper(
+                streamwrapper = litellm.CustomStreamWrapper(
                     resp.iter_lines(),
                     model=model,
                     custom_llm_provider="watsonx",
                     logging_obj=logging_obj,
                 )
-                return response
+                return streamwrapper
+
+        async def handle_stream_request_async(
+            request_params: dict,
+        ) -> litellm.CustomStreamWrapper:
+            # stream the response - generated chunks will be handled
+            # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
+            async with manage_response(
+                request_params, stream=True, input=prompt, timeout=timeout,
+            ) as resp:
+                streamwrapper = litellm.CustomStreamWrapper(
+                    resp.aiter_lines(),
+                    model=model,
+                    custom_llm_provider="watsonx",
+                    logging_obj=logging_obj,
+                )
+                return streamwrapper
 
         try:
             ## Get the response from the model
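The hunk above factors response parsing (`process_text_gen_response`) out of the transport layer so the sync and async request handlers share it. A minimal standalone sketch of that pattern — the function names and URL handling here are illustrative, not litellm APIs:

```python
import requests
import httpx

def parse_generation(json_resp: dict) -> str:
    # shared parsing logic, identical for sync and async transports
    if "results" not in json_resp:
        raise ValueError(f"invalid response: {json_resp}")
    return json_resp["results"][0]["generated_text"]

def generate_sync(url: str, payload: dict) -> str:
    resp = requests.post(url, json=payload)
    resp.raise_for_status()
    return parse_generation(resp.json())

async def generate_async(url: str, payload: dict) -> str:
    async with httpx.AsyncClient() as client:
        resp = await client.post(url, json=payload)
        resp.raise_for_status()
        return parse_generation(resp.json())
```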
@@ -415,10 +444,18 @@ class IBMWatsonXAI(BaseLLM):
                 optional_params=optional_params,
                 print_verbose=print_verbose,
             )
-            if stream:
-                return process_stream_request(req_params)
+            if stream and acompletion:
+                # stream and async text generation
+                return handle_stream_request_async(req_params)
+            elif stream:
+                # streaming text generation
+                return handle_stream_request(req_params)
+            elif acompletion:
+                # async text generation
+                return handle_text_request_async(req_params)
             else:
-                return process_text_request(req_params)
+                # regular text generation
+                return handle_text_request(req_params)
         except WatsonXAIError as e:
             raise e
         except Exception as e:
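The dispatch above covers all four combinations of `stream` and `acompletion`. A hedged caller-side sketch of the four modes — the model id is illustrative, and watsonx credentials are assumed to be configured:

```python
import asyncio
import litellm

model = "watsonx/ibm/granite-13b-chat-v2"  # illustrative model id
messages = [{"role": "user", "content": "Hello"}]

# sync, non-streaming -> handle_text_request
print(litellm.completion(model=model, messages=messages))

# sync, streaming -> handle_stream_request
for chunk in litellm.completion(model=model, messages=messages, stream=True):
    print(chunk)

async def main():
    # async, non-streaming -> handle_text_request_async
    print(await litellm.acompletion(model=model, messages=messages))
    # async, streaming -> handle_stream_request_async
    stream = await litellm.acompletion(model=model, messages=messages, stream=True)
    async for chunk in stream:
        print(chunk)

asyncio.run(main())
```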
@@ -433,6 +470,7 @@ class IBMWatsonXAI(BaseLLM):
         model_response=None,
         optional_params=None,
         encoding=None,
+        aembedding=None,
     ):
         """
         Send a text embedding request to the IBM Watsonx.ai API.
@@ -467,9 +505,6 @@ class IBMWatsonXAI(BaseLLM):
         }
         request_params = dict(version=api_params["api_version"])
         url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
-        # request = httpx.Request(
-        #     "POST", url, headers=headers, json=payload, params=request_params
-        # )
         req_params = {
             "method": "POST",
             "url": url,
@@ -477,25 +512,47 @@ class IBMWatsonXAI(BaseLLM):
             "json": payload,
             "params": request_params,
         }
-        with self._manage_response(
-            req_params, logging_obj=logging_obj, input=input
-        ) as resp:
-            json_resp = resp.json()
-
-        results = json_resp.get("results", [])
-        embedding_response = []
-        for idx, result in enumerate(results):
-            embedding_response.append(
-                {"object": "embedding", "index": idx, "embedding": result["embedding"]}
+        manage_response = self._make_response_manager(async_=(aembedding is True), logging_obj=logging_obj)
+
+        def process_embedding_response(json_resp: dict) -> ModelResponse:
+            results = json_resp.get("results", [])
+            embedding_response = []
+            for idx, result in enumerate(results):
+                embedding_response.append(
+                    {"object": "embedding", "index": idx, "embedding": result["embedding"]}
+                )
+            model_response["object"] = "list"
+            model_response["data"] = embedding_response
+            model_response["model"] = model
+            input_tokens = json_resp.get("input_token_count", 0)
+            model_response.usage = Usage(
+                prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
             )
-        model_response["object"] = "list"
-        model_response["data"] = embedding_response
-        model_response["model"] = model
-        input_tokens = json_resp.get("input_token_count", 0)
-        model_response.usage = Usage(
-            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-        )
-        return model_response
+            return model_response
+
+        def handle_embedding_request(request_params: dict) -> ModelResponse:
+            with manage_response(
+                request_params, input=input
+            ) as resp:
+                json_resp = resp.json()
+            return process_embedding_response(json_resp)
+
+        async def handle_embedding_request_async(request_params: dict) -> ModelResponse:
+            async with manage_response(
+                request_params, input=input
+            ) as resp:
+                json_resp = resp.json()
+            return process_embedding_response(json_resp)
+
+        try:
+            if aembedding:
+                return handle_embedding_request_async(req_params)
+            else:
+                return handle_embedding_request(req_params)
+        except WatsonXAIError as e:
+            raise e
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
 
     def generate_iam_token(self, api_key=None, **params):
         headers = {}
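With `aembedding` threaded through, the async embedding path can be exercised end to end. A hedged usage sketch — the model id is illustrative and watsonx credentials are assumed:

```python
import asyncio
import litellm

async def embed():
    # routed to handle_embedding_request_async via aembedding=True
    resp = await litellm.aembedding(
        model="watsonx/ibm/slate-30m-english-rtrvr",  # illustrative model id
        input=["first document", "second document"],
    )
    print(len(resp.data), resp.usage.prompt_tokens)

asyncio.run(embed())
```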
@@ -517,53 +574,116 @@ class IBMWatsonXAI(BaseLLM):
         iam_access_token = json_data["access_token"]
         self.token = iam_access_token
         return iam_access_token
+
+    def _make_response_manager(
+        self,
+        async_: bool,
+        logging_obj: Logging,
+    ) -> Callable[..., Generator[Union[requests.Response, httpx.Response], None, None]]:
+        """
+        Returns a context manager that manages the response from the request.
+        If async_ is True, returns an async context manager; otherwise returns a regular context manager.
-    @contextmanager
-    def _manage_response(
-        self,
-        request_params: dict,
-        logging_obj: Any,
-        stream: bool = False,
-        input: Optional[Any] = None,
-        timeout: float = None,
-    ):
-        request_str = (
-            f"response = {request_params['method']}(\n"
-            f"\turl={request_params['url']},\n"
-            f"\tjson={request_params['json']},\n"
-            f")"
-        )
-        logging_obj.pre_call(
-            input=input,
-            api_key=request_params['headers'].get("Authorization"),
-            additional_args={
-                "complete_input_dict": request_params['json'],
-                "request_str": request_str,
-            },
-        )
-        if timeout:
-            request_params['timeout'] = timeout
-        try:
-            if stream:
-                resp = requests.request(
-                    **request_params,
-                    stream=True,
-                )
-                resp.raise_for_status()
-                yield resp
-            else:
-                resp = requests.request(**request_params)
-                resp.raise_for_status()
-                yield resp
-        except Exception as e:
-            raise WatsonXAIError(status_code=500, message=str(e))
-        if not stream:
+        Usage:
+        ```python
+        manage_response = self._make_response_manager(async_=True, logging_obj=logging_obj)
+        async with manage_response(request_params) as resp:
+            ...
+        # or
+        manage_response = self._make_response_manager(async_=False, logging_obj=logging_obj)
+        with manage_response(request_params) as resp:
+            ...
+        ```
+        """
+
+        def pre_call(
+            request_params: dict,
+            input: Optional[Any] = None,
+        ):
+            request_str = (
+                f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
+                f"\turl={request_params['url']},\n"
+                f"\tjson={request_params['json']},\n"
+                f")"
+            )
+            logging_obj.pre_call(
+                input=input,
+                api_key=request_params["headers"].get("Authorization"),
+                additional_args={
+                    "complete_input_dict": request_params["json"],
+                    "request_str": request_str,
+                },
+            )
+
+        # `input` is passed explicitly: the enclosing scope has no `input` binding,
+        # so a bare reference here would resolve to the builtin of the same name.
+        def post_call(resp, request_params, input: Optional[Any] = None):
             logging_obj.post_call(
                 input=input,
-                api_key=request_params['headers'].get("Authorization"),
+                api_key=request_params["headers"].get("Authorization"),
                 original_response=json.dumps(resp.json()),
                 additional_args={
                     "status_code": resp.status_code,
-                    "complete_input_dict": request_params['json'],
+                    "complete_input_dict": request_params.get("data", request_params.get("json")),
                 },
             )
+
+        @contextmanager
+        def _manage_response(
+            request_params: dict,
+            stream: bool = False,
+            input: Optional[Any] = None,
+            timeout: float = None,
+        ) -> Generator[requests.Response, None, None]:
+            """
+            Returns a context manager that yields the response from the request.
+            """
+            pre_call(request_params, input)
+            if timeout:
+                request_params["timeout"] = timeout
+            if stream:
+                request_params["stream"] = stream
+            try:
+                resp = requests.request(**request_params)
+                resp.raise_for_status()
+                yield resp
+            except Exception as e:
+                raise WatsonXAIError(status_code=500, message=str(e))
+            if not stream:
+                post_call(resp, request_params, input=input)
+
+        @asynccontextmanager
+        async def _manage_response_async(
+            request_params: dict,
+            stream: bool = False,
+            input: Optional[Any] = None,
+            timeout: float = None,
+        ) -> AsyncGenerator[httpx.Response, None]:
+            pre_call(request_params, input)
+            if timeout:
+                request_params["timeout"] = timeout
+            if stream:
+                request_params["stream"] = stream
+            try:
+                self.async_handler = AsyncHTTPHandler(
+                    timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0),
+                )
+                if "json" in request_params:
+                    request_params["data"] = json.dumps(request_params.pop("json", {}))
+                method = request_params.pop("method")
+                if method.upper() == "POST":
+                    resp = await self.async_handler.post(**request_params)
+                else:
+                    resp = await self.async_handler.get(**request_params)
+                yield resp
+            except Exception as e:
+                raise WatsonXAIError(status_code=500, message=str(e))
+            if not stream:
+                post_call(resp, request_params, input=input)
+
+        if async_:
+            return _manage_response_async
+        else:
+            return _manage_response
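The factory above returns either a sync or an async context manager, with both variants closing over the same pre/post logging hooks. A standalone sketch of that pattern, with all names illustrative:

```python
from contextlib import asynccontextmanager, contextmanager

def make_manager(async_: bool):
    # shared hooks, closed over by both variants
    def pre(params):
        print("pre_call", params)

    def post(result):
        print("post_call", result)

    @contextmanager
    def sync_manager(params):
        pre(params)
        result = {"params": params}  # stand-in for a blocking HTTP call
        yield result
        post(result)

    @asynccontextmanager
    async def async_manager(params):
        pre(params)
        result = {"params": params}  # stand-in for an awaited HTTP call
        yield result
        post(result)

    return async_manager if async_ else sync_manager
```

The caller picks the variant once (`make_manager(async_=...)`) and the call sites stay symmetrical: `with manager(...)` on the sync path, `async with manager(...)` on the async path.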
+ """ + pre_call(request_params, input) + if timeout: + request_params["timeout"] = timeout + if stream: + request_params["stream"] = stream + try: + resp = requests.request(**request_params) + resp.raise_for_status() + yield resp + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + if not stream: + post_call(resp, request_params) + + + @asynccontextmanager + async def _manage_response_async( + request_params: dict, + stream: bool = False, + input: Optional[Any] = None, + timeout: float = None, + ) -> AsyncGenerator[httpx.Response, None]: + pre_call(request_params, input) + if timeout: + request_params["timeout"] = timeout + if stream: + request_params["stream"] = stream + try: + # async with AsyncHTTPHandler(timeout=timeout) as client: + self.async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0), + ) + # async_handler.client.verify = False + if "json" in request_params: + request_params['data'] = json.dumps(request_params.pop("json", {})) + method = request_params.pop("method") + if method.upper() == "POST": + resp = await self.async_handler.post(**request_params) + else: + resp = await self.async_handler.get(**request_params) + yield resp + # await async_handler.close() + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + if not stream: + post_call(resp, request_params) + + if async_: + return _manage_response_async + else: + return _manage_response diff --git a/litellm/main.py b/litellm/main.py index 41794ccd5..37106348c 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -70,6 +70,7 @@ from .llms.azure_text import AzureTextCompletion from .llms.anthropic import AnthropicChatCompletion from .llms.anthropic_text import AnthropicTextCompletion from .llms.huggingface_restapi import Huggingface +from .llms.watsonx import IBMWatsonXAI from .llms.prompt_templates.factory import ( prompt_factory, custom_prompt, @@ -105,6 +106,7 @@ anthropic_text_completions = AnthropicTextCompletion() azure_chat_completions = AzureChatCompletion() azure_text_completions = AzureTextCompletion() huggingface = Huggingface() +watsonxai = IBMWatsonXAI() ####### COMPLETION ENDPOINTS ################ @@ -308,6 +310,7 @@ async def acompletion( or custom_llm_provider == "gemini" or custom_llm_provider == "sagemaker" or custom_llm_provider == "anthropic" + or custom_llm_provider == "watsonx" or custom_llm_provider in litellm.openai_compatible_providers ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. init_response = await loop.run_in_executor(None, func_with_context) @@ -1865,7 +1868,7 @@ def completion( response = response elif custom_llm_provider == "watsonx": custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - response = watsonx.IBMWatsonXAI().completion( + response = watsonxai.completion( model=model, messages=messages, custom_prompt_dict=custom_prompt_dict, @@ -1876,6 +1879,7 @@ def completion( logger_fn=logger_fn, encoding=encoding, logging_obj=logging, + acompletion=acompletion, timeout=timeout, ) if ( @@ -2528,6 +2532,7 @@ async def aembedding(*args, **kwargs): or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "ollama" or custom_llm_provider == "vertex_ai" + or custom_llm_provider == "watsonx" ): # currently implemented aiohttp calls for just azure and openai, soon all. 
diff --git a/litellm/utils.py b/litellm/utils.py
index 9f176c194..136df1da1 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10084,6 +10084,8 @@ class CustomStreamWrapper:
                 response_obj = self.handle_watsonx_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
+                if getattr(model_response, "usage", None) is None:
+                    model_response.usage = Usage()
                 if response_obj.get("prompt_tokens") is not None:
                     prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
                     model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
@@ -10497,6 +10499,7 @@ class CustomStreamWrapper:
             or self.custom_llm_provider == "sagemaker"
             or self.custom_llm_provider == "gemini"
             or self.custom_llm_provider == "cached_response"
+            or self.custom_llm_provider == "watsonx"
             or self.custom_llm_provider in litellm.openai_compatible_endpoints
         ):
             async for chunk in self.completion_stream:
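The `usage` guard above matters because watsonx streams report token counts incrementally, and the accumulator must not assume `model_response.usage` already exists when the first counted chunk arrives. A standalone sketch of that accumulation — class shapes are simplified here, and the field names follow the hunk above:

```python
from dataclasses import dataclass
from types import SimpleNamespace

@dataclass
class Usage:
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

def accumulate_usage(model_response, response_obj: dict) -> None:
    # initialize lazily: early chunks may arrive before usage exists
    if getattr(model_response, "usage", None) is None:
        model_response.usage = Usage()
    if response_obj.get("prompt_tokens") is not None:
        model_response.usage.prompt_tokens += response_obj["prompt_tokens"]

resp = SimpleNamespace(usage=None)
accumulate_usage(resp, {"text": "Hello", "prompt_tokens": 12})
print(resp.usage.prompt_tokens)  # 12
```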