(fix) watsonx.py: Fixed linting errors and made sure stream chunks always return usage

Simon Sanchez Viloria 2024-05-10 11:53:33 +02:00
parent 66a1b581e5
commit 170fd11c82
2 changed files with 279 additions and 92 deletions
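
Context for the fix: the usage guarantee matters when streaming watsonx completions through litellm. A minimal sketch of such a call follows; the model id, credential environment variables, and chunk access pattern are illustrative assumptions, not part of this commit.

```python
# Sketch only: streamed watsonx completion routed through litellm.
# The model id and env var names below are examples, not taken from this commit.
import os
import litellm

os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"  # assumed credential setup
os.environ["WATSONX_APIKEY"] = "..."
os.environ["WATSONX_PROJECT_ID"] = "..."

response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # example model id
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in response:
    # Each raw watsonx stream line is parsed by CustomStreamWrapper.handle_watsonx_stream;
    # after this commit that parser always returns prompt_tokens/completion_tokens (0 if absent).
    print(chunk.choices[0].delta.content or "", end="")
```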

litellm/llms/watsonx.py

@@ -1,12 +1,25 @@
 from enum import Enum
 import json, types, time  # noqa: E401
 from contextlib import asynccontextmanager, contextmanager
-from typing import AsyncGenerator, Callable, Dict, Generator, Optional, Any, Union, List
+from typing import (
+    Callable,
+    Dict,
+    Generator,
+    AsyncGenerator,
+    Iterator,
+    AsyncIterator,
+    Optional,
+    Any,
+    Union,
+    List,
+    ContextManager,
+    AsyncContextManager,
+)
 import httpx  # type: ignore
 import requests  # type: ignore
 import litellm
-from litellm.utils import Logging, ModelResponse, Usage, get_secret
+from litellm.utils import ModelResponse, Usage, get_secret
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from .base import BaseLLM
@@ -189,6 +202,7 @@ class WatsonXAIEndpoint(str, Enum):
     )
     EMBEDDINGS = "/ml/v1/text/embeddings"
     PROMPTS = "/ml/v1/prompts"
+    AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"


 class IBMWatsonXAI(BaseLLM):
@@ -378,12 +392,12 @@ class IBMWatsonXAI(BaseLLM):
         model_response: ModelResponse,
         print_verbose: Callable,
         encoding,
-        logging_obj: Logging,
-        optional_params: Optional[dict] = None,
-        acompletion: bool = None,
-        litellm_params: Optional[dict] = None,
+        logging_obj,
+        optional_params=None,
+        acompletion=None,
+        litellm_params=None,
         logger_fn=None,
-        timeout: Optional[float] = None,
+        timeout=None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -403,8 +417,6 @@ class IBMWatsonXAI(BaseLLM):
         prompt = convert_messages_to_prompt(
             model, messages, provider, custom_prompt_dict
         )
-        manage_response = self._make_response_manager(async_=(acompletion is True), logging_obj=logging_obj)

         def process_text_gen_response(json_resp: dict) -> ModelResponse:
             if "results" not in json_resp:
@@ -419,62 +431,72 @@ class IBMWatsonXAI(BaseLLM):
             model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
             model_response["created"] = int(time.time())
             model_response["model"] = model
-            setattr(
-                model_response,
-                "usage",
-                Usage(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=prompt_tokens + completion_tokens,
-                ),
-            )
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            )
+            setattr(model_response, "usage", usage)
             return model_response

+        def process_stream_response(
+            stream_resp: Union[Iterator[str], AsyncIterator],
+        ) -> litellm.CustomStreamWrapper:
+            streamwrapper = litellm.CustomStreamWrapper(
+                stream_resp,
+                model=model,
+                custom_llm_provider="watsonx",
+                logging_obj=logging_obj,
+            )
+            return streamwrapper
+
+        # create the function to manage the request to watsonx.ai
+        # manage_request = self._make_request_manager(
+        #     async_=(acompletion is True), logging_obj=logging_obj
+        # )
+        self.request_manager = RequestManager(logging_obj)
+
         def handle_text_request(request_params: dict) -> ModelResponse:
-            with manage_response(
-                request_params, input=prompt, timeout=timeout,
+            with self.request_manager.request(
+                request_params,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
                 json_resp = resp.json()
             return process_text_gen_response(json_resp)

         async def handle_text_request_async(request_params: dict) -> ModelResponse:
-            async with manage_response(
-                request_params, input=prompt, timeout=timeout,
+            async with self.request_manager.async_request(
+                request_params,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
                 json_resp = resp.json()
             return process_text_gen_response(json_resp)

-        def handle_stream_request(
-            request_params: dict,
-        ) -> litellm.CustomStreamWrapper:
+        def handle_stream_request(request_params: dict) -> litellm.CustomStreamWrapper:
             # stream the response - generated chunks will be handled
             # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            with manage_response(
-                request_params, stream=True, input=prompt, timeout=timeout,
+            with self.request_manager.request(
+                request_params,
+                stream=True,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
-                streamwrapper = litellm.CustomStreamWrapper(
-                    resp.iter_lines(),
-                    model=model,
-                    custom_llm_provider="watsonx",
-                    logging_obj=logging_obj,
-                )
+                streamwrapper = process_stream_response(resp.iter_lines())
             return streamwrapper

-        async def handle_stream_request_async(
-            request_params: dict,
-        ) -> litellm.CustomStreamWrapper:
+        async def handle_stream_request_async(request_params: dict) -> litellm.CustomStreamWrapper:
             # stream the response - generated chunks will be handled
             # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            async with manage_response(
-                request_params, stream=True, input=prompt, timeout=timeout,
+            async with self.request_manager.async_request(
+                request_params,
+                stream=True,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
-                streamwrapper = litellm.CustomStreamWrapper(
-                    resp.aiter_lines(),
-                    model=model,
-                    custom_llm_provider="watsonx",
-                    logging_obj=logging_obj,
-                )
+                streamwrapper = process_stream_response(resp.aiter_lines())
             return streamwrapper

         try:
@@ -486,13 +508,13 @@ class IBMWatsonXAI(BaseLLM):
                 optional_params=optional_params,
                 print_verbose=print_verbose,
             )
-            if stream and acompletion:
+            if stream and (acompletion is True):
                 # stream and async text generation
                 return handle_stream_request_async(req_params)
             elif stream:
                 # streaming text generation
                 return handle_stream_request(req_params)
-            elif acompletion:
+            elif (acompletion is True):
                 # async text generation
                 return handle_text_request_async(req_params)
             else:
@@ -554,43 +576,48 @@ class IBMWatsonXAI(BaseLLM):
             "json": payload,
             "params": request_params,
         }
-        manage_response = self._make_response_manager(async_=(aembedding is True), logging_obj=logging_obj)
+        # manage_request = self._make_request_manager(
+        #     async_=(aembedding is True), logging_obj=logging_obj
+        # )
+        request_manager = RequestManager(logging_obj)

         def process_embedding_response(json_resp: dict) -> ModelResponse:
             results = json_resp.get("results", [])
             embedding_response = []
             for idx, result in enumerate(results):
                 embedding_response.append(
-                    {"object": "embedding", "index": idx, "embedding": result["embedding"]}
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": result["embedding"],
+                    }
                 )
             model_response["object"] = "list"
             model_response["data"] = embedding_response
             model_response["model"] = model
             input_tokens = json_resp.get("input_token_count", 0)
             model_response.usage = Usage(
-                prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
             )
             return model_response

-        def handle_embedding_request(request_params: dict) -> ModelResponse:
-            with manage_response(
-                request_params, input=input
-            ) as resp:
+        def handle_embedding(request_params: dict) -> ModelResponse:
+            with request_manager.request(request_params, input=input) as resp:
                 json_resp = resp.json()
             return process_embedding_response(json_resp)

-        async def handle_embedding_request_async(request_params: dict) -> ModelResponse:
-            async with manage_response(
-                request_params, input=input
-            ) as resp:
+        async def handle_aembedding(request_params: dict) -> ModelResponse:
+            async with request_manager.async_request(request_params, input=input) as resp:
                 json_resp = resp.json()
             return process_embedding_response(json_resp)

         try:
-            if aembedding:
-                return handle_embedding_request_async(req_params)
+            if aembedding is True:
+                return handle_embedding(req_params)
             else:
-                return handle_embedding_request(req_params)
+                return handle_aembedding(req_params)
         except WatsonXAIError as e:
             raise e
         except Exception as e:
@@ -616,64 +643,88 @@ class IBMWatsonXAI(BaseLLM):
         iam_access_token = json_data["access_token"]
         self.token = iam_access_token
         return iam_access_token

-    def _make_response_manager(
-        self,
-        async_: bool,
-        logging_obj: Logging
-    ) -> Callable[..., Generator[Union[requests.Response, httpx.Response], None, None]]:
+    def get_available_models(self, *, ids_only: bool = True, **params):
+        api_params = self._get_api_params(params)
+        headers = {
+            "Authorization": f"Bearer {api_params['token']}",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        request_params = dict(version=api_params["api_version"])
+        url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
+        req_params = dict(method="GET", url=url, headers=headers, params=request_params)
+        # manage_request = self._make_request_manager(async_=False, logging_obj=None)
+        with RequestManager(logging_obj=None).request(req_params) as resp:
+            json_resp = resp.json()
+        if not ids_only:
+            return json_resp
+        return [res["model_id"] for res in json_resp["resources"]]
+
+    def _make_request_manager(
+        self, async_: bool, logging_obj=None
+    ) -> Callable[
+        ...,
+        Union[ContextManager[requests.Response], AsyncContextManager[httpx.Response]],
+    ]:
         """
         Returns a context manager that manages the response from the request.
         if async_ is True, returns an async context manager, otherwise returns a regular context manager.

         Usage:
         ```python
-        manage_response = self._make_response_manager(async_=True, logging_obj=logging_obj)
-        async with manage_response(request_params) as resp:
+        manage_request = self._make_request_manager(async_=True, logging_obj=logging_obj)
+        async with manage_request(request_params) as resp:
             ...
         # or
-        manage_response = self._make_response_manager(async_=False, logging_obj=logging_obj)
-        with manage_response(request_params) as resp:
+        manage_request = self._make_request_manager(async_=False, logging_obj=logging_obj)
+        with manage_request(request_params) as resp:
             ...
         ```
         """

         def pre_call(
             request_params: dict,
-            input:Optional[Any]=None,
+            input: Optional[Any] = None,
         ):
+            if logging_obj is None:
+                return
             request_str = (
                 f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
                 f"\turl={request_params['url']},\n"
-                f"\tjson={request_params['json']},\n"
+                f"\tjson={request_params.get('json')},\n"
                 f")"
             )
             logging_obj.pre_call(
                 input=input,
                 api_key=request_params["headers"].get("Authorization"),
                 additional_args={
-                    "complete_input_dict": request_params["json"],
+                    "complete_input_dict": request_params.get("json"),
                     "request_str": request_str,
                 },
             )

         def post_call(resp, request_params):
+            if logging_obj is None:
+                return
             logging_obj.post_call(
                 input=input,
                 api_key=request_params["headers"].get("Authorization"),
                 original_response=json.dumps(resp.json()),
                 additional_args={
                     "status_code": resp.status_code,
-                    "complete_input_dict": request_params.get("data", request_params.get("json")),
+                    "complete_input_dict": request_params.get(
+                        "data", request_params.get("json")
+                    ),
                 },
             )

         @contextmanager
-        def _manage_response(
+        def _manage_request(
             request_params: dict,
             stream: bool = False,
             input: Optional[Any] = None,
-            timeout: float = None,
+            timeout=None,
         ) -> Generator[requests.Response, None, None]:
             """
             Returns a context manager that yields the response from the request.
@@ -685,20 +736,23 @@ class IBMWatsonXAI(BaseLLM):
                 request_params["stream"] = stream
             try:
                 resp = requests.request(**request_params)
-                resp.raise_for_status()
+                if not resp.ok:
+                    raise WatsonXAIError(
+                        status_code=resp.status_code,
+                        message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
+                    )
                 yield resp
             except Exception as e:
                 raise WatsonXAIError(status_code=500, message=str(e))
             if not stream:
                 post_call(resp, request_params)

         @asynccontextmanager
-        async def _manage_response_async(
+        async def _manage_request_async(
             request_params: dict,
             stream: bool = False,
             input: Optional[Any] = None,
-            timeout: float = None,
+            timeout=None,
         ) -> AsyncGenerator[httpx.Response, None]:
             pre_call(request_params, input)
             if timeout:
@@ -708,16 +762,23 @@ class IBMWatsonXAI(BaseLLM):
             try:
                 # async with AsyncHTTPHandler(timeout=timeout) as client:
                 self.async_handler = AsyncHTTPHandler(
-                    timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0),
+                    timeout=httpx.Timeout(
+                        timeout=request_params.pop("timeout", 600.0), connect=5.0
+                    ),
                 )
                 # async_handler.client.verify = False
                 if "json" in request_params:
-                    request_params['data'] = json.dumps(request_params.pop("json", {}))
+                    request_params["data"] = json.dumps(request_params.pop("json", {}))
                 method = request_params.pop("method")
                 if method.upper() == "POST":
                     resp = await self.async_handler.post(**request_params)
                 else:
                     resp = await self.async_handler.get(**request_params)
+                if resp.status_code not in [200, 201]:
+                    raise WatsonXAIError(
+                        status_code=resp.status_code,
+                        message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
+                    )
                 yield resp
                 # await async_handler.close()
             except Exception as e:
@@ -726,6 +787,132 @@ class IBMWatsonXAI(BaseLLM):
                 post_call(resp, request_params)

         if async_:
-            return _manage_response_async
+            return _manage_request_async
         else:
-            return _manage_response
+            return _manage_request
+
+
+class RequestManager:
+    """
+    Returns a context manager that manages the response from the request.
+    if async_ is True, returns an async context manager, otherwise returns a regular context manager.
+
+    Usage:
+    ```python
+    request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
+    request_manager = RequestManager(logging_obj=logging_obj)
+    async with request_manager.request(request_params) as resp:
+        ...
+    # or
+    with request_manager.async_request(request_params) as resp:
+        ...
+    ```
+    """
+
+    def __init__(self, logging_obj=None):
+        self.logging_obj = logging_obj
+
+    def pre_call(
+        self,
+        request_params: dict,
+        input: Optional[Any] = None,
+    ):
+        if self.logging_obj is None:
+            return
+        request_str = (
+            f"response = {request_params['method']}(\n"
+            f"\turl={request_params['url']},\n"
+            f"\tjson={request_params.get('json')},\n"
+            f")"
+        )
+        self.logging_obj.pre_call(
+            input=input,
+            api_key=request_params["headers"].get("Authorization"),
+            additional_args={
+                "complete_input_dict": request_params.get("json"),
+                "request_str": request_str,
+            },
+        )
+
+    def post_call(self, resp, request_params):
+        if self.logging_obj is None:
+            return
+        self.logging_obj.post_call(
+            input=input,
+            api_key=request_params["headers"].get("Authorization"),
+            original_response=json.dumps(resp.json()),
+            additional_args={
+                "status_code": resp.status_code,
+                "complete_input_dict": request_params.get(
+                    "data", request_params.get("json")
+                ),
+            },
+        )
+
+    @contextmanager
+    def request(
+        self,
+        request_params: dict,
+        stream: bool = False,
+        input: Optional[Any] = None,
+        timeout=None,
+    ) -> Generator[requests.Response, None, None]:
+        """
+        Returns a context manager that yields the response from the request.
+        """
+        self.pre_call(request_params, input)
+        if timeout:
+            request_params["timeout"] = timeout
+        if stream:
+            request_params["stream"] = stream
+        try:
+            resp = requests.request(**request_params)
+            if not resp.ok:
+                raise WatsonXAIError(
+                    status_code=resp.status_code,
+                    message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
+                )
+            yield resp
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
+        if not stream:
+            self.post_call(resp, request_params)
+
+    @asynccontextmanager
+    async def async_request(
+        self,
+        request_params: dict,
+        stream: bool = False,
+        input: Optional[Any] = None,
+        timeout=None,
+    ) -> AsyncGenerator[httpx.Response, None]:
+        self.pre_call(request_params, input)
+        if timeout:
+            request_params["timeout"] = timeout
+        if stream:
+            request_params["stream"] = stream
+        try:
+            # async with AsyncHTTPHandler(timeout=timeout) as client:
+            self.async_handler = AsyncHTTPHandler(
+                timeout=httpx.Timeout(
+                    timeout=request_params.pop("timeout", 600.0), connect=5.0
+                ),
+            )
+            # async_handler.client.verify = False
+            if "json" in request_params:
+                request_params["data"] = json.dumps(request_params.pop("json", {}))
+            method = request_params.pop("method")
+            if method.upper() == "POST":
+                resp = await self.async_handler.post(**request_params)
+            else:
+                resp = await self.async_handler.get(**request_params)
+            if resp.status_code not in [200, 201]:
+                raise WatsonXAIError(
+                    status_code=resp.status_code,
+                    message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
+                )
+            yield resp
+            # await async_handler.close()
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
+        if not stream:
+            self.post_call(resp, request_params)
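
The get_available_models helper and RequestManager class added above can be exercised on their own. A rough usage sketch follows, assuming watsonx credentials are already configured so that _get_api_params can resolve a token and URL (the credential setup itself is not shown and is an assumption):

```python
# Rough usage sketch for the additions above; credential/env setup is assumed, not shown.
from litellm.llms.watsonx import IBMWatsonXAI

llm = IBMWatsonXAI()

# Issues a GET to /ml/v1/foundation_model_specs through RequestManager
# and returns the list of model ids by default.
model_ids = llm.get_available_models(ids_only=True)
print(model_ids)
```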

litellm/utils.py

@@ -10285,7 +10285,7 @@ class CustomStreamWrapper:
                     response = chunk.replace("data: ", "").strip()
                     parsed_response = json.loads(response)
                 else:
-                    return {"text": "", "is_finished": False}
+                    return {"text": "", "is_finished": False, "prompt_tokens": 0, "completion_tokens": 0}
             else:
                 print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
                 raise ValueError(
@@ -10300,8 +10300,8 @@ class CustomStreamWrapper:
                     "text": text,
                     "is_finished": is_finished,
                     "finish_reason": finish_reason,
-                    "prompt_tokens": results[0].get("input_token_count", None),
-                    "completion_tokens": results[0].get("generated_token_count", None),
+                    "prompt_tokens": results[0].get("input_token_count", 0),
+                    "completion_tokens": results[0].get("generated_token_count", 0),
                 }
             return {"text": "", "is_finished": False}
         except Exception as e:
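
The net effect of the utils.py change is that every parsed watsonx stream chunk reports integer token counts, defaulting to 0 when the API omits them. Below is a standalone sketch of that contract: the returned keys mirror the diff, while the sample payload fields (generated_text in particular) are assumptions based on the non-streaming handler in watsonx.py, not something this commit defines.

```python
import json

def parse_watsonx_chunk(chunk: str) -> dict:
    # Mirrors the contract enforced above: prompt_tokens/completion_tokens are always present, defaulting to 0.
    if "generated_text" not in chunk:
        return {"text": "", "is_finished": False, "prompt_tokens": 0, "completion_tokens": 0}
    parsed = json.loads(chunk.replace("data: ", "").strip())
    result = parsed["results"][0]
    return {
        "text": result.get("generated_text", ""),
        "is_finished": result.get("stop_reason") not in (None, "not_finished"),
        "finish_reason": result.get("stop_reason"),
        "prompt_tokens": result.get("input_token_count", 0),
        "completion_tokens": result.get("generated_token_count", 0),
    }

# Example SSE-style line (payload shape assumed for illustration only).
line = 'data: {"results": [{"generated_text": "Hi", "stop_reason": "not_finished", "generated_token_count": 1}]}'
print(parse_watsonx_chunk(line))  # prompt_tokens falls back to 0 because the sample omits input_token_count
```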