From 170fd11c8208f57739da1410d546bd6a9cb996ff Mon Sep 17 00:00:00 2001
From: Simon Sanchez Viloria
Date: Fri, 10 May 2024 11:53:33 +0200
Subject: [PATCH 1/2] (fix) watsonx.py: Fixed linting errors and made sure
 stream chunks always return usage

---
 litellm/llms/watsonx.py | 365 ++++++++++++++++++++++++++++++----------
 litellm/utils.py        |   6 +-
 2 files changed, 279 insertions(+), 92 deletions(-)

diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py
index 082cdb325..ad4aff4b6 100644
--- a/litellm/llms/watsonx.py
+++ b/litellm/llms/watsonx.py
@@ -1,12 +1,25 @@
 from enum import Enum
 import json, types, time  # noqa: E401
 from contextlib import asynccontextmanager, contextmanager
-from typing import AsyncGenerator, Callable, Dict, Generator, Optional, Any, Union, List
+from typing import (
+    Callable,
+    Dict,
+    Generator,
+    AsyncGenerator,
+    Iterator,
+    AsyncIterator,
+    Optional,
+    Any,
+    Union,
+    List,
+    ContextManager,
+    AsyncContextManager,
+)
 
 import httpx  # type: ignore
 import requests  # type: ignore
 
 import litellm
-from litellm.utils import Logging, ModelResponse, Usage, get_secret
+from litellm.utils import ModelResponse, Usage, get_secret
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from .base import BaseLLM
@@ -189,6 +202,7 @@ class WatsonXAIEndpoint(str, Enum):
     )
     EMBEDDINGS = "/ml/v1/text/embeddings"
     PROMPTS = "/ml/v1/prompts"
+    AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"
 
 
 class IBMWatsonXAI(BaseLLM):
@@ -378,12 +392,12 @@ class IBMWatsonXAI(BaseLLM):
         model_response: ModelResponse,
         print_verbose: Callable,
         encoding,
-        logging_obj: Logging,
-        optional_params: Optional[dict] = None,
-        acompletion: bool = None,
-        litellm_params: Optional[dict] = None,
+        logging_obj,
+        optional_params=None,
+        acompletion=None,
+        litellm_params=None,
         logger_fn=None,
-        timeout: Optional[float] = None,
+        timeout=None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -403,8 +417,6 @@ class IBMWatsonXAI(BaseLLM):
         prompt = convert_messages_to_prompt(
             model, messages, provider, custom_prompt_dict
         )
-
-        manage_response = self._make_response_manager(async_=(acompletion is True), logging_obj=logging_obj)
 
         def process_text_gen_response(json_resp: dict) -> ModelResponse:
             if "results" not in json_resp:
@@ -419,62 +431,72 @@ class IBMWatsonXAI(BaseLLM):
             model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
             model_response["created"] = int(time.time())
             model_response["model"] = model
-            setattr(
-                model_response,
-                "usage",
-                Usage(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=prompt_tokens + completion_tokens,
-                ),
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             )
+            setattr(model_response, "usage", usage)
             return model_response
 
+        def process_stream_response(
+            stream_resp: Union[Iterator[str], AsyncIterator],
+        ) -> litellm.CustomStreamWrapper:
+            streamwrapper = litellm.CustomStreamWrapper(
+                stream_resp,
+                model=model,
+                custom_llm_provider="watsonx",
+                logging_obj=logging_obj,
+            )
+            return streamwrapper
+
+        # create the function to manage the request to watsonx.ai
+        # manage_request = self._make_request_manager(
+        #     async_=(acompletion is True), logging_obj=logging_obj
+        # )
+        self.request_manager = RequestManager(logging_obj)
+
         def handle_text_request(request_params: dict) -> ModelResponse:
-            with manage_response(
-                request_params, input=prompt, timeout=timeout,
+            with self.request_manager.request(
+                request_params,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
                 json_resp = resp.json()
             return process_text_gen_response(json_resp)
 
         async def handle_text_request_async(request_params: dict) -> ModelResponse:
-            async with manage_response(
-                request_params, input=prompt, timeout=timeout,
+            async with self.request_manager.async_request(
+                request_params,
+                input=prompt,
+                timeout=timeout,
            ) as resp:
                json_resp = resp.json()
            return process_text_gen_response(json_resp)
 
-        def handle_stream_request(
-            request_params: dict,
-        ) -> litellm.CustomStreamWrapper:
+        def handle_stream_request(request_params: dict) -> litellm.CustomStreamWrapper:
             # stream the response - generated chunks will be handled
             # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            with manage_response(
-                request_params, stream=True, input=prompt, timeout=timeout,
+            with self.request_manager.request(
+                request_params,
+                stream=True,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
-                streamwrapper = litellm.CustomStreamWrapper(
-                    resp.iter_lines(),
-                    model=model,
-                    custom_llm_provider="watsonx",
-                    logging_obj=logging_obj,
-                )
+                streamwrapper = process_stream_response(resp.iter_lines())
             return streamwrapper
 
-        async def handle_stream_request_async(
-            request_params: dict,
-        ) -> litellm.CustomStreamWrapper:
+        async def handle_stream_request_async(request_params: dict) -> litellm.CustomStreamWrapper:
             # stream the response - generated chunks will be handled
             # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            async with manage_response(
-                request_params, stream=True, input=prompt, timeout=timeout,
+            async with self.request_manager.async_request(
+                request_params,
+                stream=True,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
-                streamwrapper = litellm.CustomStreamWrapper(
-                    resp.aiter_lines(),
-                    model=model,
-                    custom_llm_provider="watsonx",
-                    logging_obj=logging_obj,
-                )
+                streamwrapper = process_stream_response(resp.aiter_lines())
             return streamwrapper
 
 
         try:
@@ -486,13 +508,13 @@ class IBMWatsonXAI(BaseLLM):
                 optional_params=optional_params,
                 print_verbose=print_verbose,
             )
-            if stream and acompletion:
+            if stream and (acompletion is True):
                 # stream and async text generation
                 return handle_stream_request_async(req_params)
             elif stream:
                 # streaming text generation
                 return handle_stream_request(req_params)
-            elif acompletion:
+            elif (acompletion is True):
                 # async text generation
                 return handle_text_request_async(req_params)
             else:
@@ -554,43 +576,48 @@ class IBMWatsonXAI(BaseLLM):
             "json": payload,
             "params": request_params,
         }
-        manage_response = self._make_response_manager(async_=(aembedding is True), logging_obj=logging_obj)
-
+        # manage_request = self._make_request_manager(
+        #     async_=(aembedding is True), logging_obj=logging_obj
+        # )
+        request_manager = RequestManager(logging_obj)
+
         def process_embedding_response(json_resp: dict) -> ModelResponse:
             results = json_resp.get("results", [])
             embedding_response = []
             for idx, result in enumerate(results):
                 embedding_response.append(
-                    {"object": "embedding", "index": idx, "embedding": result["embedding"]}
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": result["embedding"],
+                    }
                 )
             model_response["object"] = "list"
             model_response["data"] = embedding_response
             model_response["model"] = model
             input_tokens = json_resp.get("input_token_count", 0)
             model_response.usage = Usage(
-                prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
             )
             return model_response
-
-        def handle_embedding_request(request_params: dict) -> ModelResponse:
-            with manage_response(
-                request_params, input=input
-            ) as resp:
+
+        def handle_embedding(request_params: dict) -> ModelResponse:
+            with request_manager.request(request_params, input=input) as resp:
                 json_resp = resp.json()
             return process_embedding_response(json_resp)
-
-        async def handle_embedding_request_async(request_params: dict) -> ModelResponse:
-            async with manage_response(
-                request_params, input=input
-            ) as resp:
+
+        async def handle_aembedding(request_params: dict) -> ModelResponse:
+            async with request_manager.async_request(request_params, input=input) as resp:
+                json_resp = resp.json()
             return process_embedding_response(json_resp)
-
+
         try:
-            if aembedding:
-                return handle_embedding_request_async(req_params)
+            if aembedding is True:
+                return handle_aembedding(req_params)
             else:
-                return handle_embedding_request(req_params)
+                return handle_embedding(req_params)
         except WatsonXAIError as e:
             raise e
         except Exception as e:
@@ -616,64 +643,88 @@ class IBMWatsonXAI(BaseLLM):
             iam_access_token = json_data["access_token"]
         self.token = iam_access_token
         return iam_access_token
-
-    def _make_response_manager(
-        self,
-        async_: bool,
-        logging_obj: Logging
-    ) -> Callable[..., Generator[Union[requests.Response, httpx.Response], None, None]]:
+
+    def get_available_models(self, *, ids_only: bool = True, **params):
+        api_params = self._get_api_params(params)
+        headers = {
+            "Authorization": f"Bearer {api_params['token']}",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        request_params = dict(version=api_params["api_version"])
+        url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
+        req_params = dict(method="GET", url=url, headers=headers, params=request_params)
+        # manage_request = self._make_request_manager(async_=False, logging_obj=None)
+        with RequestManager(logging_obj=None).request(req_params) as resp:
+            json_resp = resp.json()
+        if not ids_only:
+            return json_resp
+        return [res["model_id"] for res in json_resp["resources"]]
+
+    def _make_request_manager(
+        self, async_: bool, logging_obj=None
+    ) -> Callable[
+        ...,
+        Union[ContextManager[requests.Response], AsyncContextManager[httpx.Response]],
+    ]:
         """
         Returns a context manager that manages the response from the request.
         if async_ is True, returns an async context manager, otherwise returns a regular context manager.
 
         Usage:
         ```python
-        manage_response = self._make_response_manager(async_=True, logging_obj=logging_obj)
-        async with manage_response(request_params) as resp:
+        manage_request = self._make_request_manager(async_=True, logging_obj=logging_obj)
+        async with manage_request(request_params) as resp:
             ...
         # or
-        manage_response = self._make_response_manager(async_=False, logging_obj=logging_obj)
-        with manage_response(request_params) as resp:
+        manage_request = self._make_request_manager(async_=False, logging_obj=logging_obj)
+        with manage_request(request_params) as resp:
             ...
         ```
         """
 
         def pre_call(
             request_params: dict,
-            input:Optional[Any]=None,
+            input: Optional[Any] = None,
         ):
+            if logging_obj is None:
+                return
             request_str = (
                 f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
                 f"\turl={request_params['url']},\n"
-                f"\tjson={request_params['json']},\n"
+                f"\tjson={request_params.get('json')},\n"
                 f")"
             )
             logging_obj.pre_call(
                 input=input,
                 api_key=request_params["headers"].get("Authorization"),
                 additional_args={
-                    "complete_input_dict": request_params["json"],
+                    "complete_input_dict": request_params.get("json"),
                     "request_str": request_str,
                 },
             )
 
        def post_call(resp, request_params):
+            if logging_obj is None:
+                return
            logging_obj.post_call(
                input=input,
                api_key=request_params["headers"].get("Authorization"),
                original_response=json.dumps(resp.json()),
                additional_args={
                    "status_code": resp.status_code,
-                    "complete_input_dict": request_params.get("data", request_params.get("json")),
+                    "complete_input_dict": request_params.get(
+                        "data", request_params.get("json")
+                    ),
                },
            )
-
+
         @contextmanager
-        def _manage_response(
+        def _manage_request(
             request_params: dict,
             stream: bool = False,
             input: Optional[Any] = None,
-            timeout: float = None,
+            timeout=None,
         ) -> Generator[requests.Response, None, None]:
             """
             Returns a context manager that yields the response from the request.
@@ -685,20 +736,23 @@ class IBMWatsonXAI(BaseLLM):
                 request_params["stream"] = stream
             try:
                 resp = requests.request(**request_params)
-                resp.raise_for_status()
+                if not resp.ok:
+                    raise WatsonXAIError(
+                        status_code=resp.status_code,
+                        message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
+                    )
                 yield resp
             except Exception as e:
                 raise WatsonXAIError(status_code=500, message=str(e))
             if not stream:
                 post_call(resp, request_params)
-
-
+
         @asynccontextmanager
-        async def _manage_response_async(
+        async def _manage_request_async(
             request_params: dict,
             stream: bool = False,
             input: Optional[Any] = None,
-            timeout: float = None,
+            timeout=None,
         ) -> AsyncGenerator[httpx.Response, None]:
             pre_call(request_params, input)
             if timeout:
@@ -708,16 +762,23 @@ class IBMWatsonXAI(BaseLLM):
             try:
                 # async with AsyncHTTPHandler(timeout=timeout) as client:
                 self.async_handler = AsyncHTTPHandler(
-                    timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0),
+                    timeout=httpx.Timeout(
+                        timeout=request_params.pop("timeout", 600.0), connect=5.0
+                    ),
                 )
                 # async_handler.client.verify = False
                 if "json" in request_params:
-                    request_params['data'] = json.dumps(request_params.pop("json", {}))
+                    request_params["data"] = json.dumps(request_params.pop("json", {}))
                 method = request_params.pop("method")
                 if method.upper() == "POST":
                     resp = await self.async_handler.post(**request_params)
                 else:
                     resp = await self.async_handler.get(**request_params)
+                if resp.status_code not in [200, 201]:
+                    raise WatsonXAIError(
+                        status_code=resp.status_code,
+                        message=f"Error {resp.status_code} ({resp.reason_phrase}): {resp.text}",
+                    )
                 yield resp
                 # await async_handler.close()
             except Exception as e:
@@ -726,6 +787,132 @@ class IBMWatsonXAI(BaseLLM):
                 post_call(resp, request_params)
 
         if async_:
-            return _manage_response_async
+            return _manage_request_async
         else:
-            return _manage_response
+            return _manage_request
+
+class RequestManager:
+    """
+    Returns a context manager that manages the response from the request.
+    Use `request()` for a regular (sync) context manager and `async_request()` for an async one.
+
+    Usage:
+    ```python
+    request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
+    request_manager = RequestManager(logging_obj=logging_obj)
+    async with request_manager.async_request(request_params) as resp:
+        ...
+    # or
+    with request_manager.request(request_params) as resp:
+        ...
+    ```
+    """
+
+    def __init__(self, logging_obj=None):
+        self.logging_obj = logging_obj
+
+    def pre_call(
+        self,
+        request_params: dict,
+        input: Optional[Any] = None,
+    ):
+        if self.logging_obj is None:
+            return
+        request_str = (
+            f"response = {request_params['method']}(\n"
+            f"\turl={request_params['url']},\n"
+            f"\tjson={request_params.get('json')},\n"
+            f")"
+        )
+        self.logging_obj.pre_call(
+            input=input,
+            api_key=request_params["headers"].get("Authorization"),
+            additional_args={
+                "complete_input_dict": request_params.get("json"),
+                "request_str": request_str,
+            },
+        )
+
+    def post_call(self, resp, request_params):
+        if self.logging_obj is None:
+            return
+        self.logging_obj.post_call(
+            input=input,
+            api_key=request_params["headers"].get("Authorization"),
+            original_response=json.dumps(resp.json()),
+            additional_args={
+                "status_code": resp.status_code,
+                "complete_input_dict": request_params.get(
+                    "data", request_params.get("json")
+                ),
+            },
+        )
+
+    @contextmanager
+    def request(
+        self,
+        request_params: dict,
+        stream: bool = False,
+        input: Optional[Any] = None,
+        timeout=None,
+    ) -> Generator[requests.Response, None, None]:
+        """
+        Returns a context manager that yields the response from the request.
+        """
+        self.pre_call(request_params, input)
+        if timeout:
+            request_params["timeout"] = timeout
+        if stream:
+            request_params["stream"] = stream
+        try:
+            resp = requests.request(**request_params)
+            if not resp.ok:
+                raise WatsonXAIError(
+                    status_code=resp.status_code,
+                    message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
+                )
+            yield resp
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
+        if not stream:
+            self.post_call(resp, request_params)
+
+    @asynccontextmanager
+    async def async_request(
+        self,
+        request_params: dict,
+        stream: bool = False,
+        input: Optional[Any] = None,
+        timeout=None,
+    ) -> AsyncGenerator[httpx.Response, None]:
+        self.pre_call(request_params, input)
+        if timeout:
+            request_params["timeout"] = timeout
+        if stream:
+            request_params["stream"] = stream
+        try:
+            # async with AsyncHTTPHandler(timeout=timeout) as client:
+            self.async_handler = AsyncHTTPHandler(
+                timeout=httpx.Timeout(
+                    timeout=request_params.pop("timeout", 600.0), connect=5.0
+                ),
+            )
+            # async_handler.client.verify = False
+            if "json" in request_params:
+                request_params["data"] = json.dumps(request_params.pop("json", {}))
+            method = request_params.pop("method")
+            if method.upper() == "POST":
+                resp = await self.async_handler.post(**request_params)
+            else:
+                resp = await self.async_handler.get(**request_params)
+            if resp.status_code not in [200, 201]:
+                raise WatsonXAIError(
+                    status_code=resp.status_code,
+                    message=f"Error {resp.status_code} ({resp.reason_phrase}): {resp.text}",
+                )
+            yield resp
+            # await async_handler.close()
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
+        if not stream:
+            self.post_call(resp, request_params)
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index d1af1b44a..3d1b0c1a7 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10285,7 +10285,7 @@ class CustomStreamWrapper:
                     response = chunk.replace("data: ", "").strip()
                     parsed_response = json.loads(response)
                 else:
-                    return {"text": "", "is_finished": False}
+                    return {"text": "", "is_finished": False, "prompt_tokens": 0, "completion_tokens": 0}
             else:
                 print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
                 raise ValueError(
@@ -10300,8 +10300,8 @@ class CustomStreamWrapper:
                 "text": text,
                 "is_finished": is_finished,
                 "finish_reason": finish_reason,
-                "prompt_tokens": results[0].get("input_token_count", None),
-                "completion_tokens": results[0].get("generated_token_count", None),
+                "prompt_tokens": results[0].get("input_token_count", 0),
+                "completion_tokens": results[0].get("generated_token_count", 0),
             }
             return {"text": "", "is_finished": False}
         except Exception as e:

From d3d82827edbe9c5c3840795f1385c915629b957d Mon Sep 17 00:00:00 2001
From: Simon Sanchez Viloria
Date: Fri, 10 May 2024 11:55:58 +0200
Subject: [PATCH 2/2] (test) Add tests for WatsonX completion/acompletion
 streaming

---
 litellm/tests/test_completion.py | 37 ++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 32b65faea..fa3e669f0 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -3089,6 +3089,24 @@ def test_completion_watsonx():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+def test_completion_stream_watsonx():
+    litellm.set_verbose = True
+    model_name = "watsonx/ibm/granite-13b-chat-v2"
+    try:
+        response = completion(
+            model=model_name,
+            messages=messages,
+            stop=["stop"],
+            max_tokens=20,
+            stream=True
+        )
+        for chunk in response:
+            print(chunk)
+    except litellm.APIError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 
 
 @pytest.mark.parametrize(
     "provider, model, project, region_name, token",
@@ -3153,6 +3171,25 @@ async def test_acompletion_watsonx():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+@pytest.mark.asyncio
+async def test_acompletion_stream_watsonx():
+    litellm.set_verbose = True
+    model_name = "watsonx/ibm/granite-13b-chat-v2"
+    print("testing watsonx")
+    try:
+        response = await litellm.acompletion(
+            model=model_name,
+            messages=messages,
+            temperature=0.2,
+            max_tokens=80,
+            stream=True
+        )
+        # Add any assertions here to check the response
+        async for chunk in response:
+            print(chunk)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 
 
 # test_completion_palm_stream()