diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py
index 99f2d18ba..5a14e8133 100644
--- a/litellm/llms/watsonx.py
+++ b/litellm/llms/watsonx.py
@@ -1,12 +1,26 @@
 from enum import Enum
 import json, types, time  # noqa: E401
-from contextlib import contextmanager
-from typing import Callable, Dict, Optional, Any, Union, List
+from contextlib import asynccontextmanager, contextmanager
+from typing import (
+    Callable,
+    Dict,
+    Generator,
+    AsyncGenerator,
+    Iterator,
+    AsyncIterator,
+    Optional,
+    Any,
+    Union,
+    List,
+    ContextManager,
+    AsyncContextManager,
+)
 import httpx  # type: ignore
 import requests  # type: ignore
 import litellm
-from litellm.utils import ModelResponse, get_secret, Usage
+from litellm.utils import ModelResponse, Usage, get_secret
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 
 from .base import BaseLLM
 from .prompt_templates import factory as ptf
 
@@ -188,11 +202,12 @@ class WatsonXAIEndpoint(str, Enum):
     )
     EMBEDDINGS = "/ml/v1/text/embeddings"
     PROMPTS = "/ml/v1/prompts"
+    AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"
 
 
 class IBMWatsonXAI(BaseLLM):
     """
-    Class to interface with IBM Watsonx.ai API for text generation and embeddings.
+    Class to interface with IBM watsonx.ai API for text generation and embeddings.
 
     Reference: https://cloud.ibm.com/apidocs/watsonx-ai
     """
 
@@ -343,7 +358,7 @@ class IBMWatsonXAI(BaseLLM):
         )
         if token is None and api_key is not None:
             # generate the auth token
-            if print_verbose:
+            if print_verbose is not None:
                 print_verbose("Generating IAM token for Watsonx.ai")
             token = self.generate_iam_token(api_key)
         elif token is None and api_key is None:
@@ -378,10 +393,11 @@ class IBMWatsonXAI(BaseLLM):
         print_verbose: Callable,
         encoding,
         logging_obj,
-        optional_params: dict,
-        litellm_params: Optional[dict] = None,
+        optional_params=None,
+        acompletion=None,
+        litellm_params=None,
         logger_fn=None,
-        timeout: Optional[float] = None,
+        timeout=None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -402,12 +418,12 @@ class IBMWatsonXAI(BaseLLM):
             model, messages, provider, custom_prompt_dict
         )
 
-        def process_text_request(request_params: dict) -> ModelResponse:
-            with self._manage_response(
-                request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
-            ) as resp:
-                json_resp = resp.json()
-
+        def process_text_gen_response(json_resp: dict) -> ModelResponse:
+            if "results" not in json_resp:
+                raise WatsonXAIError(
+                    status_code=500,
+                    message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
+                )
             generated_text = json_resp["results"][0]["generated_text"]
             prompt_tokens = json_resp["results"][0]["input_token_count"]
             completion_tokens = json_resp["results"][0]["generated_token_count"]
@@ -415,36 +431,70 @@ class IBMWatsonXAI(BaseLLM):
             model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
             model_response["created"] = int(time.time())
             model_response["model"] = model
-            setattr(
-                model_response,
-                "usage",
-                Usage(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=prompt_tokens + completion_tokens,
-                ),
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             )
+            setattr(model_response, "usage", usage)
             return model_response
 
-        def process_stream_request(
-            request_params: dict,
+        def process_stream_response(
+            stream_resp: Union[Iterator[str], AsyncIterator],
         ) -> litellm.CustomStreamWrapper:
+            streamwrapper = litellm.CustomStreamWrapper(
+                stream_resp,
+                model=model,
+                custom_llm_provider="watsonx",
+                logging_obj=logging_obj,
+            )
+            return streamwrapper
+
+        # create the request manager that handles calls to the watsonx.ai API
+        self.request_manager = RequestManager(logging_obj)
+
+        def handle_text_request(request_params: dict) -> ModelResponse:
+            with self.request_manager.request(
+                request_params,
+                input=prompt,
+                timeout=timeout,
+            ) as resp:
+                json_resp = resp.json()
+
+            return process_text_gen_response(json_resp)
+
+        async def handle_text_request_async(request_params: dict) -> ModelResponse:
+            async with self.request_manager.async_request(
+                request_params,
+                input=prompt,
+                timeout=timeout,
+            ) as resp:
+                json_resp = resp.json()
+            return process_text_gen_response(json_resp)
+
+        def handle_stream_request(request_params: dict) -> litellm.CustomStreamWrapper:
             # stream the response - generated chunks will be handled
             # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            with self._manage_response(
+            with self.request_manager.request(
                 request_params,
-                logging_obj=logging_obj,
                 stream=True,
                 input=prompt,
                 timeout=timeout,
             ) as resp:
-                response = litellm.CustomStreamWrapper(
-                    resp.iter_lines(),
-                    model=model,
-                    custom_llm_provider="watsonx",
-                    logging_obj=logging_obj,
-                )
-                return response
+                streamwrapper = process_stream_response(resp.iter_lines())
+            return streamwrapper
+
+        async def handle_stream_request_async(request_params: dict) -> litellm.CustomStreamWrapper:
+            # stream the response - generated chunks will be handled
+            # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
+            async with self.request_manager.async_request(
+                request_params,
+                stream=True,
+                input=prompt,
+                timeout=timeout,
+            ) as resp:
+                streamwrapper = process_stream_response(resp.aiter_lines())
+            return streamwrapper
 
         try:
             ## Get the response from the model
@@ -455,10 +505,18 @@ class IBMWatsonXAI(BaseLLM):
                 optional_params=optional_params,
                 print_verbose=print_verbose,
             )
-            if stream:
-                return process_stream_request(req_params)
+            if stream and (acompletion is True):
+                # stream and async text generation
+                return handle_stream_request_async(req_params)
+            elif stream:
+                # streaming text generation
+                return handle_stream_request(req_params)
+            elif (acompletion is True):
+                # async text generation
+                return handle_text_request_async(req_params)
             else:
-                return process_text_request(req_params)
+                # regular text generation
+                return handle_text_request(req_params)
         except WatsonXAIError as e:
             raise e
         except Exception as e:
@@ -473,6 +531,7 @@ class IBMWatsonXAI(BaseLLM):
         model_response=None,
         optional_params=None,
         encoding=None,
+        aembedding=None,
     ):
         """
         Send a text embedding request to the IBM Watsonx.ai API.
@@ -507,9 +566,6 @@ class IBMWatsonXAI(BaseLLM):
         }
         request_params = dict(version=api_params["api_version"])
         url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
-        # request = httpx.Request(
-        #     "POST", url, headers=headers, json=payload, params=request_params
-        # )
         req_params = {
             "method": "POST",
             "url": url,
@@ -517,25 +573,49 @@ class IBMWatsonXAI(BaseLLM):
             "json": payload,
             "params": request_params,
         }
-        with self._manage_response(
-            req_params, logging_obj=logging_obj, input=input
-        ) as resp:
-            json_resp = resp.json()
+        request_manager = RequestManager(logging_obj)
 
-        results = json_resp.get("results", [])
-        embedding_response = []
-        for idx, result in enumerate(results):
-            embedding_response.append(
-                {"object": "embedding", "index": idx, "embedding": result["embedding"]}
+        def process_embedding_response(json_resp: dict) -> ModelResponse:
+            results = json_resp.get("results", [])
+            embedding_response = []
+            for idx, result in enumerate(results):
+                embedding_response.append(
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": result["embedding"],
+                    }
+                )
+            model_response["object"] = "list"
+            model_response["data"] = embedding_response
+            model_response["model"] = model
+            input_tokens = json_resp.get("input_token_count", 0)
+            model_response.usage = Usage(
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
            )
-        model_response["object"] = "list"
-        model_response["data"] = embedding_response
-        model_response["model"] = model
-        input_tokens = json_resp.get("input_token_count", 0)
-        model_response.usage = Usage(
-            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-        )
-        return model_response
+            return model_response
+
+        def handle_embedding(request_params: dict) -> ModelResponse:
+            with request_manager.request(request_params, input=input) as resp:
+                json_resp = resp.json()
+            return process_embedding_response(json_resp)
+
+        async def handle_aembedding(request_params: dict) -> ModelResponse:
+            async with request_manager.async_request(request_params, input=input) as resp:
+                json_resp = resp.json()
+            return process_embedding_response(json_resp)
+
+        try:
+            if aembedding is True:
+                return handle_aembedding(req_params)
+            else:
+                return handle_embedding(req_params)
+        except WatsonXAIError as e:
+            raise e
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
 
     def generate_iam_token(self, api_key=None, **params):
         headers = {}
@@ -558,52 +638,144 @@ class IBMWatsonXAI(BaseLLM):
         self.token = iam_access_token
         return iam_access_token
 
-    @contextmanager
-    def _manage_response(
-        self,
-        request_params: dict,
-        logging_obj: Any,
-        stream: bool = False,
-        input: Optional[Any] = None,
-        timeout: Optional[float] = None,
-    ):
-        request_str = (
-            f"response = {request_params['method']}(\n"
-            f"\turl={request_params['url']},\n"
-            f"\tjson={request_params['json']},\n"
-            f")"
-        )
-        logging_obj.pre_call(
-            input=input,
-            api_key=request_params["headers"].get("Authorization"),
-            additional_args={
-                "complete_input_dict": request_params["json"],
-                "request_str": request_str,
-            },
-        )
-        if timeout:
-            request_params["timeout"] = timeout
-        try:
-            if stream:
-                resp = requests.request(
-                    **request_params,
-                    stream=True,
-                )
-                resp.raise_for_status()
-                yield resp
-            else:
-                resp = requests.request(**request_params)
-                resp.raise_for_status()
-                yield resp
-        except Exception as e:
-            raise WatsonXAIError(status_code=500, message=str(e))
-        if not stream:
-            logging_obj.post_call(
+    def get_available_models(self, *, ids_only: bool = True, **params):
+        api_params = self._get_api_params(params)
+        headers = {
+            "Authorization": f"Bearer {api_params['token']}",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        request_params = dict(version=api_params["api_version"])
+        url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
+        req_params = dict(method="GET", url=url, headers=headers, params=request_params)
+        with RequestManager(logging_obj=None).request(req_params) as resp:
+            json_resp = resp.json()
+        if not ids_only:
+            return json_resp
+        return [res["model_id"] for res in json_resp["resources"]]
+
+class RequestManager:
+    """
+    Manages requests to the watsonx.ai API and the associated pre/post-call logging.
+    Use `request()` for a regular (synchronous) context manager and `async_request()` for an async one.
+
+    Usage:
+    ```python
+    request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization": "Bearer token"}, json={"key": "value"})
+    request_manager = RequestManager(logging_obj=logging_obj)
+    with request_manager.request(request_params) as resp:
+        ...
+    # or
+    async with request_manager.async_request(request_params) as resp:
+        ...
+    ```
+    """
+
+    def __init__(self, logging_obj=None):
+        self.logging_obj = logging_obj
+
+    def pre_call(
+        self,
+        request_params: dict,
+        input: Optional[Any] = None,
+    ):
+        if self.logging_obj is None:
+            return
+        request_str = (
+            f"response = {request_params['method']}(\n"
+            f"\turl={request_params['url']},\n"
+            f"\tjson={request_params.get('json')},\n"
+            f")"
+        )
+        self.logging_obj.pre_call(
+            input=input,
+            api_key=request_params["headers"].get("Authorization"),
+            additional_args={
+                "complete_input_dict": request_params.get("json"),
+                "request_str": request_str,
+            },
+        )
+
+    def post_call(self, resp, request_params, input: Optional[Any] = None):
+        if self.logging_obj is None:
+            return
+        self.logging_obj.post_call(
             input=input,
             api_key=request_params["headers"].get("Authorization"),
             original_response=json.dumps(resp.json()),
             additional_args={
                 "status_code": resp.status_code,
-                "complete_input_dict": request_params["json"],
+                "complete_input_dict": request_params.get(
+                    "data", request_params.get("json")
+                ),
             },
         )
+
+    @contextmanager
+    def request(
+        self,
+        request_params: dict,
+        stream: bool = False,
+        input: Optional[Any] = None,
+        timeout=None,
+    ) -> Generator[requests.Response, None, None]:
+        """
+        Returns a context manager that yields the response from the request.
+ """ + self.pre_call(request_params, input) + if timeout: + request_params["timeout"] = timeout + if stream: + request_params["stream"] = stream + try: + resp = requests.request(**request_params) + if not resp.ok: + raise WatsonXAIError( + status_code=resp.status_code, + message=f"Error {resp.status_code} ({resp.reason}): {resp.text}", + ) + yield resp + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + if not stream: + self.post_call(resp, request_params) + + @asynccontextmanager + async def async_request( + self, + request_params: dict, + stream: bool = False, + input: Optional[Any] = None, + timeout=None, + ) -> AsyncGenerator[httpx.Response, None]: + self.pre_call(request_params, input) + if timeout: + request_params["timeout"] = timeout + if stream: + request_params["stream"] = stream + try: + # async with AsyncHTTPHandler(timeout=timeout) as client: + self.async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout( + timeout=request_params.pop("timeout", 600.0), connect=5.0 + ), + ) + # async_handler.client.verify = False + if "json" in request_params: + request_params["data"] = json.dumps(request_params.pop("json", {})) + method = request_params.pop("method") + if method.upper() == "POST": + resp = await self.async_handler.post(**request_params) + else: + resp = await self.async_handler.get(**request_params) + if resp.status_code not in [200, 201]: + raise WatsonXAIError( + status_code=resp.status_code, + message=f"Error {resp.status_code} ({resp.reason}): {resp.text}", + ) + yield resp + # await async_handler.close() + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + if not stream: + self.post_call(resp, request_params) \ No newline at end of file diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 04f4cc511..7628f0daf 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -3236,6 +3236,24 @@ def test_completion_watsonx(): except Exception as e: pytest.fail(f"Error occurred: {e}") +def test_completion_stream_watsonx(): + litellm.set_verbose = True + model_name = "watsonx/ibm/granite-13b-chat-v2" + try: + response = completion( + model=model_name, + messages=messages, + stop=["stop"], + max_tokens=20, + stream=True + ) + for chunk in response: + print(chunk) + except litellm.APIError as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + @pytest.mark.parametrize( "provider, model, project, region_name, token", @@ -3300,6 +3318,25 @@ async def test_acompletion_watsonx(): except Exception as e: pytest.fail(f"Error occurred: {e}") +@pytest.mark.asyncio +async def test_acompletion_stream_watsonx(): + litellm.set_verbose = True + model_name = "watsonx/ibm/granite-13b-chat-v2" + print("testing watsonx") + try: + response = await litellm.acompletion( + model=model_name, + messages=messages, + temperature=0.2, + max_tokens=80, + stream=True + ) + # Add any assertions here to check the response + async for chunk in response: + print(chunk) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + # test_completion_palm_stream() diff --git a/litellm/utils.py b/litellm/utils.py index 4a442542a..5f252331a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -10430,7 +10430,7 @@ class CustomStreamWrapper: response = chunk.replace("data: ", "").strip() parsed_response = json.loads(response) else: - return {"text": "", "is_finished": False} + return {"text": "", "is_finished": False, "prompt_tokens": 0, "completion_tokens": 0} else: 
print_verbose(f"chunk: {chunk} (Type: {type(chunk)})") raise ValueError( @@ -10445,8 +10445,8 @@ class CustomStreamWrapper: "text": text, "is_finished": is_finished, "finish_reason": finish_reason, - "prompt_tokens": results[0].get("input_token_count", None), - "completion_tokens": results[0].get("generated_token_count", None), + "prompt_tokens": results[0].get("input_token_count", 0), + "completion_tokens": results[0].get("generated_token_count", 0), } return {"text": "", "is_finished": False} except Exception as e: