diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py
index 26bcf6c06..aa0cb32df 100644
--- a/litellm/llms/watsonx.py
+++ b/litellm/llms/watsonx.py
@@ -1,3 +1,4 @@
+from enum import Enum
 import json, types, time  # noqa: E401
 from contextlib import contextmanager
 from typing import Callable, Dict, Optional, Any, Union, List
@@ -160,6 +161,15 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
         )
     return prompt
 
+class WatsonXAIEndpoint(str, Enum):
+    TEXT_GENERATION = "/ml/v1/text/generation"
+    TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
+    DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"
+    DEPLOYMENT_TEXT_GENERATION_STREAM = (
+        "/ml/v1/deployments/{deployment_id}/text/generation_stream"
+    )
+    EMBEDDINGS = "/ml/v1/text/embeddings"
+    PROMPTS = "/ml/v1/prompts"
 
 class IBMWatsonXAI(BaseLLM):
     """
@@ -169,14 +179,7 @@ class IBMWatsonXAI(BaseLLM):
     """
 
     api_version = "2024-03-13"
-    _text_gen_endpoint = "/ml/v1/text/generation"
-    _text_gen_stream_endpoint = "/ml/v1/text/generation_stream"
-    _deployment_text_gen_endpoint = "/ml/v1/deployments/{deployment_id}/text/generation"
-    _deployment_text_gen_stream_endpoint = (
-        "/ml/v1/deployments/{deployment_id}/text/generation_stream"
-    )
-    _embeddings_endpoint = "/ml/v1/text/embeddings"
-    _prompts_endpoint = "/ml/v1/prompts"
+
 
     def __init__(self) -> None:
         super().__init__()
@@ -188,7 +191,7 @@ class IBMWatsonXAI(BaseLLM):
         stream: bool,
         optional_params: dict,
         print_verbose: Callable = None,
-    ) -> httpx.Request:
+    ) -> dict:
         """
         Get the request parameters for text generation.
         """
@@ -221,20 +224,23 @@ class IBMWatsonXAI(BaseLLM):
             )
             deployment_id = "/".join(model_id.split("/")[1:])
             endpoint = (
-                self._deployment_text_gen_stream_endpoint
+                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM
                 if stream
-                else self._deployment_text_gen_endpoint
+                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION
             )
             endpoint = endpoint.format(deployment_id=deployment_id)
         else:
             payload["model_id"] = model_id
             payload["project_id"] = api_params["project_id"]
             endpoint = (
-                self._text_gen_stream_endpoint if stream else self._text_gen_endpoint
+                WatsonXAIEndpoint.TEXT_GENERATION_STREAM
+                if stream
+                else WatsonXAIEndpoint.TEXT_GENERATION
             )
         url = api_params["url"].rstrip("/") + endpoint
-        return httpx.Request(
-            "POST", url, headers=headers, json=payload, params=request_params
+        return dict(
+            method="POST", url=url, headers=headers,
+            json=payload, params=request_params
         )
 
     def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
@@ -360,9 +366,9 @@ class IBMWatsonXAI(BaseLLM):
                 model, messages, provider, custom_prompt_dict
             )
 
-        def process_text_request(request: httpx.Request) -> ModelResponse:
+        def process_text_request(request_params: dict) -> ModelResponse:
             with self._manage_response(
-                request, logging_obj=logging_obj, input=prompt, timeout=timeout
+                request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
             ) as resp:
                 json_resp = resp.json()
 
@@ -381,12 +387,12 @@ class IBMWatsonXAI(BaseLLM):
             return model_response
 
         def process_stream_request(
-            request: httpx.Request,
+            request_params: dict,
         ) -> litellm.CustomStreamWrapper:
             # stream the response - generated chunks will be handled
             # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
             with self._manage_response(
-                request,
+                request_params,
                 logging_obj=logging_obj,
                 stream=True,
                 input=prompt,
@@ -402,7 +408,7 @@ class IBMWatsonXAI(BaseLLM):
 
         try:
             ## Get the response from the model
-            request = self._prepare_text_generation_req(
+            req_params = self._prepare_text_generation_req(
                 model_id=model,
                 prompt=prompt,
                 stream=stream,
@@ -410,9 +416,9 @@ class IBMWatsonXAI(BaseLLM):
                 print_verbose=print_verbose,
             )
             if stream:
-                return process_stream_request(request)
+                return process_stream_request(req_params)
             else:
-                return process_text_request(request)
+                return process_text_request(req_params)
         except WatsonXAIError as e:
             raise e
         except Exception as e:
@@ -460,12 +466,19 @@ class IBMWatsonXAI(BaseLLM):
             "parameters": optional_params,
         }
         request_params = dict(version=api_params["api_version"])
-        url = api_params["url"].rstrip("/") + self._embeddings_endpoint
-        request = httpx.Request(
-            "POST", url, headers=headers, json=payload, params=request_params
-        )
+        url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
+        # request = httpx.Request(
+        #     "POST", url, headers=headers, json=payload, params=request_params
+        # )
+        req_params = {
+            "method": "POST",
+            "url": url,
+            "headers": headers,
+            "json": payload,
+            "params": request_params,
+        }
         with self._manage_response(
-            request, logging_obj=logging_obj, input=input
+            req_params, logging_obj=logging_obj, input=input
         ) as resp:
             json_resp = resp.json()
 
@@ -508,48 +521,38 @@ class IBMWatsonXAI(BaseLLM):
     @contextmanager
     def _manage_response(
         self,
-        request: httpx.Request,
+        request_params: dict,
         logging_obj: Any,
         stream: bool = False,
         input: Optional[Any] = None,
         timeout: float = None,
     ):
         request_str = (
-            f"response = {request.method}(\n"
-            f"\turl={request.url},\n"
-            f"\tjson={request.content.decode()},\n"
+            f"response = {request_params['method']}(\n"
+            f"\turl={request_params['url']},\n"
+            f"\tjson={request_params['json']},\n"
             f")"
         )
-        json_input = json.loads(request.content.decode())
-        headers = dict(request.headers)
         logging_obj.pre_call(
             input=input,
-            api_key=request.headers.get("Authorization"),
+            api_key=request_params['headers'].get("Authorization"),
            additional_args={
-                "complete_input_dict": json_input,
+                "complete_input_dict": request_params['json'],
                 "request_str": request_str,
             },
         )
+        if timeout:
+            request_params['timeout'] = timeout
         try:
             if stream:
                 resp = requests.request(
-                    method=request.method,
-                    url=str(request.url),
-                    headers=headers,
-                    json=json_input,
+                    **request_params,
                     stream=True,
-                    timeout=timeout,
                 )
-                # resp.raise_for_status()
+                resp.raise_for_status()
                 yield resp
             else:
-                resp = requests.request(
-                    method=request.method,
-                    url=str(request.url),
-                    headers=headers,
-                    json=json_input,
-                    timeout=timeout,
-                )
+                resp = requests.request(**request_params)
                 resp.raise_for_status()
                 yield resp
         except Exception as e:
@@ -557,10 +560,10 @@ class IBMWatsonXAI(BaseLLM):
         if not stream:
             logging_obj.post_call(
                 input=input,
-                api_key=request.headers.get("Authorization"),
+                api_key=request_params['headers'].get("Authorization"),
                 original_response=json.dumps(resp.json()),
                 additional_args={
                     "status_code": resp.status_code,
-                    "complete_input_dict": request,
+                    "complete_input_dict": request_params['json'],
                 },
             )
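
Note: the new WatsonXAIEndpoint(str, Enum) works as a drop-in replacement for the old string attributes because its members subclass str, so they concatenate onto a base URL and support str methods like .format(). A minimal standalone sketch of that pattern follows; the base URL and deployment ID are hypothetical placeholders, not values from this patch.

    # Sketch of the str-based Enum pattern introduced above. Members behave
    # as plain strings, so the call sites in the diff need no conversion.
    from enum import Enum

    class WatsonXAIEndpoint(str, Enum):
        TEXT_GENERATION = "/ml/v1/text/generation"
        DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"

    base_url = "https://us-south.ml.cloud.ibm.com"  # hypothetical region URL

    # str + str-enum works because the member *is* a str instance:
    url = base_url.rstrip("/") + WatsonXAIEndpoint.TEXT_GENERATION
    assert url.endswith("/ml/v1/text/generation")

    # .format() is an ordinary str method; it returns a plain str, not an
    # enum member, which is why the diff can rebind `endpoint` to the result:
    endpoint = WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.format(
        deployment_id="my-deployment"  # hypothetical deployment ID
    )
    assert endpoint == "/ml/v1/deployments/my-deployment/text/generation"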
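The other half of the refactor replaces the prebuilt httpx.Request with a plain dict of request parameters, so _manage_response can add optional keys (timeout, stream) before a single requests.request(**request_params) call instead of re-extracting method, url, headers, and body from a request object. A minimal sketch of that flow, under the assumption of a public echo endpoint and a dummy token purely for illustration:

    # Sketch of the dict-based request-parameter pattern the patch adopts.
    import requests

    request_params = {
        "method": "POST",
        "url": "https://httpbin.org/post",  # stand-in for the watsonx endpoint
        "headers": {"Authorization": "Bearer dummy-token"},  # hypothetical token
        "json": {"input": "hello"},
        "params": {"version": "2024-03-13"},
    }

    # Optional keys are merged in only when set, mirroring _manage_response:
    timeout = 30.0
    if timeout:
        request_params["timeout"] = timeout

    resp = requests.request(**request_params)  # one call site for all variants
    resp.raise_for_status()
    print(resp.json()["json"])  # httpbin echoes the JSON body back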