diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py
index 34176a23a..6a7a51908 100644
--- a/litellm/llms/watsonx.py
+++ b/litellm/llms/watsonx.py
@@ -1,5 +1,6 @@
 from enum import Enum
 import json, types, time  # noqa: E401
+import asyncio
 from contextlib import asynccontextmanager, contextmanager
 from typing import (
     Callable,
@@ -393,6 +394,35 @@ class IBMWatsonXAI(BaseLLM):
             "api_version": api_version,
         }

+    def _process_text_gen_response(
+        self, json_resp: dict, model_response: Union[ModelResponse, None] = None
+    ) -> ModelResponse:
+        if "results" not in json_resp:
+            raise WatsonXAIError(
+                status_code=500,
+                message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
+            )
+        if model_response is None:
+            model_response = ModelResponse(model=json_resp.get("model_id", None))
+        generated_text = json_resp["results"][0]["generated_text"]
+        prompt_tokens = json_resp["results"][0]["input_token_count"]
+        completion_tokens = json_resp["results"][0]["generated_token_count"]
+        model_response["choices"][0]["message"]["content"] = generated_text
+        model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
+        if json_resp.get("created_at"):
+            model_response["created"] = datetime.fromisoformat(
+                json_resp["created_at"]
+            ).timestamp()
+        else:
+            model_response["created"] = int(time.time())
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
     def completion(
         self,
         model: str,
@@ -530,6 +560,29 @@ class IBMWatsonXAI(BaseLLM):
             raise e
         except Exception as e:
             raise WatsonXAIError(status_code=500, message=str(e))
+
+    def _process_embedding_response(self, json_resp: dict, model_response: Union[ModelResponse, None] = None) -> ModelResponse:
+        if model_response is None:
+            model_response = ModelResponse(model=json_resp.get("model_id", None))
+        results = json_resp.get("results", [])
+        embedding_response = []
+        for idx, result in enumerate(results):
+            embedding_response.append(
+                {
+                    "object": "embedding",
+                    "index": idx,
+                    "embedding": result["embedding"],
+                }
+            )
+        model_response["object"] = "list"
+        model_response["data"] = embedding_response
+        input_tokens = json_resp.get("input_token_count", 0)
+        model_response.usage = Usage(
+            prompt_tokens=input_tokens,
+            completion_tokens=0,
+            total_tokens=input_tokens,
+        )
+        return model_response

     def embedding(
         self,
@@ -664,127 +717,135 @@ class IBMWatsonXAI(BaseLLM):
         return [res["model_id"] for res in json_resp["resources"]]


 class RequestManager:
+    """
+    Returns a context manager that manages the response from the request.
+    if async_ is True, returns an async context manager, otherwise returns a regular context manager.
+
+    Usage:
+    ```python
+    request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
+    request_manager = RequestManager(logging_obj=logging_obj)
+    with request_manager.request(request_params) as resp:
+        ...
+    # or
+    async with request_manager.async_request(request_params) as resp:
+        ...
+    ```
+    """
+
+    def __init__(self, logging_obj=None):
+        self.logging_obj = logging_obj
+
+    def pre_call(
+        self,
+        request_params: dict,
+        input: Optional[Any] = None,
+        is_async: Optional[bool] = False,
+    ):
+        if self.logging_obj is None:
+            return
+        request_str = (
+            f"response = {'await ' if is_async else ''}{request_params['method']}(\n"
+            f"\turl={request_params['url']},\n"
+            f"\tjson={request_params.get('json')},\n"
+            f")"
+        )
+        self.logging_obj.pre_call(
+            input=input,
+            api_key=request_params["headers"].get("Authorization"),
+            additional_args={
+                "complete_input_dict": request_params.get("json"),
+                "request_str": request_str,
+            },
+        )
+
+    def post_call(self, resp, request_params):
+        if self.logging_obj is None:
+            return
+        self.logging_obj.post_call(
+            input=input,
+            api_key=request_params["headers"].get("Authorization"),
+            original_response=json.dumps(resp.json()),
+            additional_args={
+                "status_code": resp.status_code,
+                "complete_input_dict": request_params.get(
+                    "data", request_params.get("json")
+                ),
+            },
+        )
+
+    @contextmanager
+    def request(
+        self,
+        request_params: dict,
+        stream: bool = False,
+        input: Optional[Any] = None,
+        timeout=None,
+    ) -> Generator[requests.Response, None, None]:
         """
-        Returns a context manager that manages the response from the request.
-        if async_ is True, returns an async context manager, otherwise returns a regular context manager.
-
-        Usage:
-        ```python
-        request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
-        request_manager = RequestManager(logging_obj=logging_obj)
-        async with request_manager.request(request_params) as resp:
-            ...
-        # or
-        with request_manager.async_request(request_params) as resp:
-            ...
-        ```
+        Returns a context manager that yields the response from the request.
         """
-
-        def __init__(self, logging_obj=None):
-            self.logging_obj = logging_obj
-
-        def pre_call(
-            self,
-            request_params: dict,
-            input: Optional[Any] = None,
-        ):
-            if self.logging_obj is None:
-                return
-            request_str = (
-                f"response = {request_params['method']}(\n"
-                f"\turl={request_params['url']},\n"
-                f"\tjson={request_params.get('json')},\n"
-                f")"
-            )
-            self.logging_obj.pre_call(
-                input=input,
-                api_key=request_params["headers"].get("Authorization"),
-                additional_args={
-                    "complete_input_dict": request_params.get("json"),
-                    "request_str": request_str,
-                },
-            )
-
-        def post_call(self, resp, request_params):
-            if self.logging_obj is None:
-                return
-            self.logging_obj.post_call(
-                input=input,
-                api_key=request_params["headers"].get("Authorization"),
-                original_response=json.dumps(resp.json()),
-                additional_args={
-                    "status_code": resp.status_code,
-                    "complete_input_dict": request_params.get(
-                        "data", request_params.get("json")
-                    ),
-                },
-            )
-
-        @contextmanager
-        def request(
-            self,
-            request_params: dict,
-            stream: bool = False,
-            input: Optional[Any] = None,
-            timeout=None,
-        ) -> Generator[requests.Response, None, None]:
-            """
-            Returns a context manager that yields the response from the request.
- """ - self.pre_call(request_params, input) - if timeout: - request_params["timeout"] = timeout - if stream: - request_params["stream"] = stream - try: - resp = requests.request(**request_params) - if not resp.ok: - raise WatsonXAIError( - status_code=resp.status_code, - message=f"Error {resp.status_code} ({resp.reason}): {resp.text}", - ) - yield resp - except Exception as e: - raise WatsonXAIError(status_code=500, message=str(e)) - if not stream: - self.post_call(resp, request_params) - - @asynccontextmanager - async def async_request( - self, - request_params: dict, - stream: bool = False, - input: Optional[Any] = None, - timeout=None, - ) -> AsyncGenerator[httpx.Response, None]: - self.pre_call(request_params, input) - if timeout: - request_params["timeout"] = timeout - if stream: - request_params["stream"] = stream - try: - # async with AsyncHTTPHandler(timeout=timeout) as client: - self.async_handler = AsyncHTTPHandler( - timeout=httpx.Timeout( - timeout=request_params.pop("timeout", 600.0), connect=5.0 - ), + self.pre_call(request_params, input) + if timeout: + request_params["timeout"] = timeout + if stream: + request_params["stream"] = stream + try: + resp = requests.request(**request_params) + if not resp.ok: + raise WatsonXAIError( + status_code=resp.status_code, + message=f"Error {resp.status_code} ({resp.reason}): {resp.text}", ) - # async_handler.client.verify = False - if "json" in request_params: - request_params["data"] = json.dumps(request_params.pop("json", {})) - method = request_params.pop("method") + yield resp + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + if not stream: + self.post_call(resp, request_params) + + async def async_request( + self, + request_params: dict, + stream: bool = False, + input: Optional[Any] = None, + timeout=None, + ) -> AsyncGenerator[httpx.Response, None]: + self.pre_call(request_params, input, is_async=True) + if timeout: + request_params["timeout"] = timeout + if stream: + request_params["stream"] = stream + try: + self.async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout( + timeout=request_params.pop("timeout", 600.0), connect=5.0 + ), + ) + if "json" in request_params: + request_params["data"] = json.dumps(request_params.pop("json", {})) + method = request_params.pop("method") + retries = 0 + while retries < 3: if method.upper() == "POST": resp = await self.async_handler.post(**request_params) else: resp = await self.async_handler.get(**request_params) - if resp.status_code not in [200, 201]: - raise WatsonXAIError( - status_code=resp.status_code, - message=f"Error {resp.status_code} ({resp.reason}): {resp.text}", - ) - yield resp - # await async_handler.close() - except Exception as e: - raise WatsonXAIError(status_code=500, message=str(e)) - if not stream: - self.post_call(resp, request_params) \ No newline at end of file + if resp.status_code in [429, 503, 504, 520]: + # to handle rate limiting and service unavailable errors + # see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload + await asyncio.sleep(2**retries) + retries += 1 + else: + break + if resp.is_error: + raise WatsonXAIError( + status_code=resp.status_code, + message=f"Error {resp.status_code} ({resp.reason}): {resp.text}", + ) + yield resp + # await async_handler.close() + except Exception as e: + raise e + raise WatsonXAIError(status_code=500, message=str(e)) + if not stream: + self.post_call(resp, request_params) diff --git a/litellm/main.py b/litellm/main.py index 
diff --git a/litellm/main.py b/litellm/main.py
index 2a7759e8a..d4c87feb4 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -108,6 +108,7 @@ from .llms.databricks import DatabricksChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.predibase import PredibaseChatCompletion
+from .llms.watsonx import IBMWatsonXAI
 from .llms.prompt_templates.factory import (
     custom_prompt,
     function_call_prompt,
@@ -152,6 +153,7 @@ triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
+watsonxai = IBMWatsonXAI()

 ####### COMPLETION ENDPOINTS ################

@@ -369,6 +371,7 @@ async def acompletion(
             or custom_llm_provider == "bedrock"
             or custom_llm_provider == "databricks"
             or custom_llm_provider == "clarifai"
+            or custom_llm_provider == "watsonx"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -2352,7 +2355,7 @@ def completion(
             response = response
     elif custom_llm_provider == "watsonx":
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-        response = watsonx.IBMWatsonXAI().completion(
+        response = watsonxai.completion(
            model=model,
            messages=messages,
            custom_prompt_dict=custom_prompt_dict,
@@ -2364,6 +2367,7 @@ def completion(
            encoding=encoding,
            logging_obj=logging,
            timeout=timeout,  # type: ignore
+           acompletion=acompletion,
        )
         if (
             "stream" in optional_params
@@ -3030,6 +3034,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "vertex_ai"
             or custom_llm_provider == "databricks"
+            or custom_llm_provider == "watsonx"
         ):  # currently implemented aiohttp calls for just azure and openai, soon all.
             # Await normally
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -3537,13 +3542,14 @@ def embedding(
             aembedding=aembedding,
         )
     elif custom_llm_provider == "watsonx":
-        response = watsonx.IBMWatsonXAI().embedding(
+        response = watsonxai.embedding(
             model=model,
             input=input,
             encoding=encoding,
             logging_obj=logging,
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
+            aembedding=aembedding,
         )
     else:
         args = locals()
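With the routing changes in `main.py`, `litellm.acompletion` and `litellm.aembedding` now dispatch `watsonx` models through the async code path, forwarding the `acompletion`/`aembedding` flags to the shared `IBMWatsonXAI` instance. A hypothetical smoke test is sketched below, assuming litellm is installed and watsonx credentials are configured in the environment; the model ids are illustrative examples, not part of the patch.

```python
# Hypothetical smoke test for the async watsonx routing added above.
# Model ids are illustrative; credentials are read from the environment.
import asyncio

import litellm


async def main() -> None:
    # Routed through acompletion's new "watsonx" branch (acompletion=True is forwarded).
    chat = await litellm.acompletion(
        model="watsonx/ibm/granite-13b-chat-v2",
        messages=[{"role": "user", "content": "Say hello from the async path."}],
    )
    print(chat.choices[0].message.content)

    # Routed through aembedding's new "watsonx" branch (aembedding=True is forwarded).
    emb = await litellm.aembedding(
        model="watsonx/ibm/slate-30m-english-rtrvr",
        input=["async embeddings"],
    )
    print(len(emb.data))  # number of embedding vectors returned


if __name__ == "__main__":
    asyncio.run(main())
```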