(feat) add async streaming support to the watsonx provider

Simon Sanchez Viloria 2024-05-06 17:07:21 +02:00
parent 62b3f25398
commit 83a274b54b
3 changed files with 221 additions and 92 deletions
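With this change, watsonx completions can go through litellm's async path, including streamed responses. A minimal usage sketch, assuming IBM watsonx credentials are configured in the environment; the model id and prompt below are placeholders, not taken from this commit:

```python
import asyncio
import litellm

async def main():
    # stream=True with acompletion now returns an async iterator of chunks
    # for custom_llm_provider == "watsonx" (placeholder model id below).
    response = await litellm.acompletion(
        model="watsonx/ibm/granite-13b-chat-v2",
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())
```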

View file

@@ -1,12 +1,13 @@
from enum import Enum
import json, types, time # noqa: E401
from contextlib import contextmanager
from typing import Callable, Dict, Optional, Any, Union, List
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncGenerator, Callable, Dict, Generator, Optional, Any, Union, List
import httpx
import requests
import litellm
from litellm.utils import ModelResponse, get_secret, Usage
from litellm.utils import Logging, ModelResponse, Usage, get_secret
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
from .prompt_templates import factory as ptf
@@ -173,14 +174,13 @@ class WatsonXAIEndpoint(str, Enum):
class IBMWatsonXAI(BaseLLM):
"""
Class to interface with IBM Watsonx.ai API for text generation and embeddings.
Class to interface with IBM watsonx.ai API for text generation and embeddings.
Reference: https://cloud.ibm.com/apidocs/watsonx-ai
"""
api_version = "2024-03-13"
def __init__(self) -> None:
super().__init__()
@@ -239,8 +239,7 @@ class IBMWatsonXAI(BaseLLM):
)
url = api_params["url"].rstrip("/") + endpoint
return dict(
method="POST", url=url, headers=headers,
json=payload, params=request_params
method="POST", url=url, headers=headers, json=payload, params=request_params
)
def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
@@ -307,7 +306,7 @@
)
if token is None and api_key is not None:
# generate the auth token
if print_verbose:
if print_verbose is not None:
print_verbose("Generating IAM token for Watsonx.ai")
token = self.generate_iam_token(api_key)
elif token is None and api_key is None:
@@ -341,8 +340,9 @@
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
logging_obj: Logging,
optional_params: Optional[dict] = None,
acompletion: bool = None,
litellm_params: Optional[dict] = None,
logger_fn=None,
timeout: float = None,
@@ -366,12 +366,14 @@
model, messages, provider, custom_prompt_dict
)
def process_text_request(request_params: dict) -> ModelResponse:
with self._manage_response(
request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
) as resp:
json_resp = resp.json()
manage_response = self._make_response_manager(async_=(acompletion is True), logging_obj=logging_obj)
def process_text_gen_response(json_resp: dict) -> ModelResponse:
if "results" not in json_resp:
raise WatsonXAIError(
status_code=500,
message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
)
generated_text = json_resp["results"][0]["generated_text"]
prompt_tokens = json_resp["results"][0]["input_token_count"]
completion_tokens = json_resp["results"][0]["generated_token_count"]
@@ -386,25 +388,52 @@
)
return model_response
def process_stream_request(
def handle_text_request(request_params: dict) -> ModelResponse:
with manage_response(
request_params, input=prompt, timeout=timeout,
) as resp:
json_resp = resp.json()
return process_text_gen_response(json_resp)
async def handle_text_request_async(request_params: dict) -> ModelResponse:
async with manage_response(
request_params, input=prompt, timeout=timeout,
) as resp:
json_resp = resp.json()
return process_text_gen_response(json_resp)
def handle_stream_request(
request_params: dict,
) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
with self._manage_response(
request_params,
logging_obj=logging_obj,
stream=True,
input=prompt,
timeout=timeout,
with manage_response(
request_params, stream=True, input=prompt, timeout=timeout,
) as resp:
response = litellm.CustomStreamWrapper(
streamwrapper = litellm.CustomStreamWrapper(
resp.iter_lines(),
model=model,
custom_llm_provider="watsonx",
logging_obj=logging_obj,
)
return response
return streamwrapper
async def handle_stream_request_async(
request_params: dict,
) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
async with manage_response(
request_params, stream=True, input=prompt, timeout=timeout,
) as resp:
streamwrapper = litellm.CustomStreamWrapper(
resp.aiter_lines(),
model=model,
custom_llm_provider="watsonx",
logging_obj=logging_obj,
)
return streamwrapper
try:
## Get the response from the model
@@ -415,10 +444,18 @@ class IBMWatsonXAI(BaseLLM):
optional_params=optional_params,
print_verbose=print_verbose,
)
if stream:
return process_stream_request(req_params)
if stream and acompletion:
# stream and async text generation
return handle_stream_request_async(req_params)
elif stream:
# streaming text generation
return handle_stream_request(req_params)
elif acompletion:
# async text generation
return handle_text_request_async(req_params)
else:
return process_text_request(req_params)
# regular text generation
return handle_text_request(req_params)
except WatsonXAIError as e:
raise e
except Exception as e:
@@ -433,6 +470,7 @@ class IBMWatsonXAI(BaseLLM):
model_response=None,
optional_params=None,
encoding=None,
aembedding=None,
):
"""
Send a text embedding request to the IBM Watsonx.ai API.
@@ -467,9 +505,6 @@
}
request_params = dict(version=api_params["api_version"])
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
# request = httpx.Request(
# "POST", url, headers=headers, json=payload, params=request_params
# )
req_params = {
"method": "POST",
"url": url,
@@ -477,11 +512,9 @@
"json": payload,
"params": request_params,
}
with self._manage_response(
req_params, logging_obj=logging_obj, input=input
) as resp:
json_resp = resp.json()
manage_response = self._make_response_manager(async_=(aembedding is True), logging_obj=logging_obj)
def process_embedding_response(json_resp: dict) -> ModelResponse:
results = json_resp.get("results", [])
embedding_response = []
for idx, result in enumerate(results):
@@ -497,6 +530,30 @@
)
return model_response
def handle_embedding_request(request_params: dict) -> ModelResponse:
with manage_response(
request_params, input=input
) as resp:
json_resp = resp.json()
return process_embedding_response(json_resp)
async def handle_embedding_request_async(request_params: dict) -> ModelResponse:
async with manage_response(
request_params, input=input
) as resp:
json_resp = resp.json()
return process_embedding_response(json_resp)
try:
if aembedding:
return handle_embedding_request_async(req_params)
else:
return handle_embedding_request(req_params)
except WatsonXAIError as e:
raise e
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
def generate_iam_token(self, api_key=None, **params):
headers = {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
@@ -518,52 +575,115 @@ class IBMWatsonXAI(BaseLLM):
self.token = iam_access_token
return iam_access_token
@contextmanager
def _manage_response(
def _make_response_manager(
self,
async_: bool,
logging_obj: Logging
) -> Callable[..., Generator[Union[requests.Response, httpx.Response], None, None]]:
"""
Returns a context manager that manages the response from the request.
If async_ is True, an async context manager is returned; otherwise, a regular context manager is returned.
Usage:
```python
manage_response = self._make_response_manager(async_=True, logging_obj=logging_obj)
async with manage_response(request_params) as resp:
...
# or
manage_response = self._make_response_manager(async_=False, logging_obj=logging_obj)
with manage_response(request_params) as resp:
...
```
"""
def pre_call(
request_params: dict,
logging_obj: Any,
stream: bool = False,
input:Optional[Any]=None,
timeout: float = None,
):
request_str = (
f"response = {request_params['method']}(\n"
f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
f"\turl={request_params['url']},\n"
f"\tjson={request_params['json']},\n"
f")"
)
logging_obj.pre_call(
input=input,
api_key=request_params['headers'].get("Authorization"),
api_key=request_params["headers"].get("Authorization"),
additional_args={
"complete_input_dict": request_params['json'],
"complete_input_dict": request_params["json"],
"request_str": request_str,
},
)
if timeout:
request_params['timeout'] = timeout
try:
if stream:
resp = requests.request(
**request_params,
stream=True,
def post_call(resp, request_params):
logging_obj.post_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()),
additional_args={
"status_code": resp.status_code,
"complete_input_dict": request_params.get("data", request_params.get("json")),
},
)
resp.raise_for_status()
yield resp
else:
@contextmanager
def _manage_response(
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout: float = None,
) -> Generator[requests.Response, None, None]:
"""
Returns a context manager that yields the response from the request.
"""
pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
resp = requests.request(**request_params)
resp.raise_for_status()
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
logging_obj.post_call(
input=input,
api_key=request_params['headers'].get("Authorization"),
original_response=json.dumps(resp.json()),
additional_args={
"status_code": resp.status_code,
"complete_input_dict": request_params['json'],
},
post_call(resp, request_params)
@asynccontextmanager
async def _manage_response_async(
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout: float = None,
) -> AsyncGenerator[httpx.Response, None]:
pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
# async with AsyncHTTPHandler(timeout=timeout) as client:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0),
)
# async_handler.client.verify = False
if "json" in request_params:
request_params['data'] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
if method.upper() == "POST":
resp = await self.async_handler.post(**request_params)
else:
resp = await self.async_handler.get(**request_params)
yield resp
# await async_handler.close()
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
post_call(resp, request_params)
if async_:
return _manage_response_async
else:
return _manage_response

View file

@@ -70,6 +70,7 @@ from .llms.azure_text import AzureTextCompletion
from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.watsonx import IBMWatsonXAI
from .llms.prompt_templates.factory import (
prompt_factory,
custom_prompt,
@@ -105,6 +106,7 @@ anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
watsonxai = IBMWatsonXAI()
####### COMPLETION ENDPOINTS ################
@@ -308,6 +310,7 @@ async def acompletion(
or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
or custom_llm_provider == "watsonx"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
@@ -1865,7 +1868,7 @@ def completion(
response = response
elif custom_llm_provider == "watsonx":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = watsonx.IBMWatsonXAI().completion(
response = watsonxai.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
@@ -1876,6 +1879,7 @@
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
)
if (
@@ -2528,6 +2532,7 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "watsonx"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
@@ -2980,13 +2985,14 @@ def embedding(
aembedding=aembedding,
)
elif custom_llm_provider == "watsonx":
response = watsonx.IBMWatsonXAI().embedding(
response = watsonxai.embedding(
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
aembedding=aembedding,
)
else:
args = locals()

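The embedding path gets the same treatment: when `aembedding` is set, the watsonx handler goes through the async context manager instead of the blocking `requests` call. A minimal sketch of the async embedding call; the model id and inputs are placeholders, and watsonx credentials are assumed to be configured in the environment:

```python
import asyncio
import litellm

async def main():
    # aembedding routes to IBMWatsonXAI.embedding(..., aembedding=True)
    # and awaits the response (placeholder model id below).
    response = await litellm.aembedding(
        model="watsonx/ibm/slate-30m-english-rtrvr",
        input=["first document", "second document"],
    )
    print(len(response.data), response.usage)

asyncio.run(main())
```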
View file

@@ -10084,6 +10084,8 @@ class CustomStreamWrapper:
response_obj = self.handle_watsonx_stream(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if getattr(model_response, "usage", None) is None:
model_response.usage = Usage()
if response_obj.get("prompt_tokens") is not None:
prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
@@ -10497,6 +10499,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "sagemaker"
or self.custom_llm_provider == "gemini"
or self.custom_llm_provider == "cached_response"
or self.custom_llm_provider == "watsonx"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
):
async for chunk in self.completion_stream: