(feat) support for async stream to watsonx provider

Simon Sanchez Viloria 2024-05-06 17:07:21 +02:00
parent 62b3f25398
commit 83a274b54b
3 changed files with 221 additions and 92 deletions
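A hedged usage sketch of what this change enables through litellm's public API (the model id and environment variables below are illustrative, not taken from the diff):

    # Hedged sketch: exercising the new async streaming path for the watsonx provider.
    # The model id is only an example; credentials such as WATSONX_URL / WATSONX_APIKEY /
    # WATSONX_PROJECT_ID are assumed to be configured in the environment.
    import asyncio
    import litellm

    async def main():
        response = await litellm.acompletion(
            model="watsonx/ibm/granite-13b-chat-v2",  # example model id
            messages=[{"role": "user", "content": "Hello"}],
            stream=True,  # returns an async iterator of chunks
        )
        async for chunk in response:
            print(chunk.choices[0].delta.content or "", end="")

    asyncio.run(main())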


@@ -1,12 +1,13 @@
 from enum import Enum
 import json, types, time # noqa: E401
-from contextlib import contextmanager
-from typing import Callable, Dict, Optional, Any, Union, List
+from contextlib import asynccontextmanager, contextmanager
+from typing import AsyncGenerator, Callable, Dict, Generator, Optional, Any, Union, List
 import httpx
 import requests
 import litellm
-from litellm.utils import ModelResponse, get_secret, Usage
+from litellm.utils import Logging, ModelResponse, Usage, get_secret
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from .base import BaseLLM
 from .prompt_templates import factory as ptf
@@ -173,14 +174,13 @@ class WatsonXAIEndpoint(str, Enum):
 class IBMWatsonXAI(BaseLLM):
     """
-    Class to interface with IBM Watsonx.ai API for text generation and embeddings.
+    Class to interface with IBM watsonx.ai API for text generation and embeddings.
     Reference: https://cloud.ibm.com/apidocs/watsonx-ai
     """

     api_version = "2024-03-13"

     def __init__(self) -> None:
         super().__init__()
@@ -239,8 +239,7 @@ class IBMWatsonXAI(BaseLLM):
         )
         url = api_params["url"].rstrip("/") + endpoint
         return dict(
-            method="POST", url=url, headers=headers,
-            json=payload, params=request_params
+            method="POST", url=url, headers=headers, json=payload, params=request_params
         )

     def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
@@ -307,7 +306,7 @@ class IBMWatsonXAI(BaseLLM):
             )
         if token is None and api_key is not None:
             # generate the auth token
-            if print_verbose:
+            if print_verbose is not None:
                 print_verbose("Generating IAM token for Watsonx.ai")
             token = self.generate_iam_token(api_key)
         elif token is None and api_key is None:
@@ -341,8 +340,9 @@ class IBMWatsonXAI(BaseLLM):
         model_response: ModelResponse,
         print_verbose: Callable,
         encoding,
-        logging_obj,
+        logging_obj: Logging,
         optional_params: Optional[dict] = None,
+        acompletion: bool = None,
         litellm_params: Optional[dict] = None,
         logger_fn=None,
         timeout: float = None,
@@ -366,12 +366,14 @@ class IBMWatsonXAI(BaseLLM):
             model, messages, provider, custom_prompt_dict
         )

-        def process_text_request(request_params: dict) -> ModelResponse:
-            with self._manage_response(
-                request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
-            ) as resp:
-                json_resp = resp.json()
+        manage_response = self._make_response_manager(async_=(acompletion is True), logging_obj=logging_obj)
+
+        def process_text_gen_response(json_resp: dict) -> ModelResponse:
+            if "results" not in json_resp:
+                raise WatsonXAIError(
+                    status_code=500,
+                    message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
+                )
             generated_text = json_resp["results"][0]["generated_text"]
             prompt_tokens = json_resp["results"][0]["input_token_count"]
             completion_tokens = json_resp["results"][0]["generated_token_count"]
@@ -386,25 +388,52 @@ class IBMWatsonXAI(BaseLLM):
             )
             return model_response

-        def process_stream_request(
-            request_params: dict,
-        ) -> litellm.CustomStreamWrapper:
-            # stream the response - generated chunks will be handled
-            # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            with self._manage_response(
-                request_params,
-                logging_obj=logging_obj,
-                stream=True,
-                input=prompt,
-                timeout=timeout,
-            ) as resp:
-                response = litellm.CustomStreamWrapper(
-                    resp.iter_lines(),
-                    model=model,
-                    custom_llm_provider="watsonx",
-                    logging_obj=logging_obj,
-                )
-            return response
+        def handle_text_request(request_params: dict) -> ModelResponse:
+            with manage_response(
+                request_params, input=prompt, timeout=timeout,
+            ) as resp:
+                json_resp = resp.json()
+            return process_text_gen_response(json_resp)
+
+        async def handle_text_request_async(request_params: dict) -> ModelResponse:
+            async with manage_response(
+                request_params, input=prompt, timeout=timeout,
+            ) as resp:
+                json_resp = resp.json()
+            return process_text_gen_response(json_resp)
+
+        def handle_stream_request(
+            request_params: dict,
+        ) -> litellm.CustomStreamWrapper:
+            # stream the response - generated chunks will be handled
+            # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
+            with manage_response(
+                request_params, stream=True, input=prompt, timeout=timeout,
+            ) as resp:
+                streamwrapper = litellm.CustomStreamWrapper(
+                    resp.iter_lines(),
+                    model=model,
+                    custom_llm_provider="watsonx",
+                    logging_obj=logging_obj,
+                )
+            return streamwrapper
+
+        async def handle_stream_request_async(
+            request_params: dict,
+        ) -> litellm.CustomStreamWrapper:
+            # stream the response - generated chunks will be handled
+            # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
+            async with manage_response(
+                request_params, stream=True, input=prompt, timeout=timeout,
+            ) as resp:
+                streamwrapper = litellm.CustomStreamWrapper(
+                    resp.aiter_lines(),
+                    model=model,
+                    custom_llm_provider="watsonx",
+                    logging_obj=logging_obj,
+                )
+            return streamwrapper

         try:
             ## Get the response from the model
@@ -415,10 +444,18 @@ class IBMWatsonXAI(BaseLLM):
                 optional_params=optional_params,
                 print_verbose=print_verbose,
             )
-            if stream:
-                return process_stream_request(req_params)
+            if stream and acompletion:
+                # stream and async text generation
+                return handle_stream_request_async(req_params)
+            elif stream:
+                # streaming text generation
+                return handle_stream_request(req_params)
+            elif acompletion:
+                # async text generation
+                return handle_text_request_async(req_params)
             else:
-                return process_text_request(req_params)
+                # regular text generation
+                return handle_text_request(req_params)
         except WatsonXAIError as e:
             raise e
         except Exception as e:
@@ -433,6 +470,7 @@ class IBMWatsonXAI(BaseLLM):
         model_response=None,
         optional_params=None,
         encoding=None,
+        aembedding=None,
     ):
         """
         Send a text embedding request to the IBM Watsonx.ai API.
@@ -467,9 +505,6 @@ class IBMWatsonXAI(BaseLLM):
         }
         request_params = dict(version=api_params["api_version"])
         url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
-        # request = httpx.Request(
-        #     "POST", url, headers=headers, json=payload, params=request_params
-        # )
         req_params = {
             "method": "POST",
             "url": url,
@@ -477,11 +512,9 @@ class IBMWatsonXAI(BaseLLM):
             "json": payload,
             "params": request_params,
         }
-        with self._manage_response(
-            req_params, logging_obj=logging_obj, input=input
-        ) as resp:
-            json_resp = resp.json()
-
-        results = json_resp.get("results", [])
-        embedding_response = []
-        for idx, result in enumerate(results):
+        manage_response = self._make_response_manager(async_=(aembedding is True), logging_obj=logging_obj)
+
+        def process_embedding_response(json_resp: dict) -> ModelResponse:
+            results = json_resp.get("results", [])
+            embedding_response = []
+            for idx, result in enumerate(results):
@@ -497,6 +530,30 @@ class IBMWatsonXAI(BaseLLM):
             )
             return model_response

+        def handle_embedding_request(request_params: dict) -> ModelResponse:
+            with manage_response(
+                request_params, input=input
+            ) as resp:
+                json_resp = resp.json()
+            return process_embedding_response(json_resp)
+
+        async def handle_embedding_request_async(request_params: dict) -> ModelResponse:
+            async with manage_response(
+                request_params, input=input
+            ) as resp:
+                json_resp = resp.json()
+            return process_embedding_response(json_resp)
+
+        try:
+            if aembedding:
+                return handle_embedding_request_async(req_params)
+            else:
+                return handle_embedding_request(req_params)
+        except WatsonXAIError as e:
+            raise e
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
+
     def generate_iam_token(self, api_key=None, **params):
         headers = {}
         headers["Content-Type"] = "application/x-www-form-urlencoded"
@@ -518,52 +575,115 @@ class IBMWatsonXAI(BaseLLM):
         self.token = iam_access_token
         return iam_access_token

-    @contextmanager
-    def _manage_response(
-        self,
-        request_params: dict,
-        logging_obj: Any,
-        stream: bool = False,
-        input: Optional[Any] = None,
-        timeout: float = None,
-    ):
-        request_str = (
-            f"response = {request_params['method']}(\n"
-            f"\turl={request_params['url']},\n"
-            f"\tjson={request_params['json']},\n"
-            f")"
-        )
-        logging_obj.pre_call(
-            input=input,
-            api_key=request_params['headers'].get("Authorization"),
-            additional_args={
-                "complete_input_dict": request_params['json'],
-                "request_str": request_str,
-            },
-        )
-        if timeout:
-            request_params['timeout'] = timeout
-        try:
-            if stream:
-                resp = requests.request(
-                    **request_params,
-                    stream=True,
-                )
-                resp.raise_for_status()
-                yield resp
-            else:
-                resp = requests.request(**request_params)
-                resp.raise_for_status()
-                yield resp
-        except Exception as e:
-            raise WatsonXAIError(status_code=500, message=str(e))
-        if not stream:
-            logging_obj.post_call(
-                input=input,
-                api_key=request_params['headers'].get("Authorization"),
-                original_response=json.dumps(resp.json()),
-                additional_args={
-                    "status_code": resp.status_code,
-                    "complete_input_dict": request_params['json'],
-                },
-            )
+    def _make_response_manager(
+        self,
+        async_: bool,
+        logging_obj: Logging
+    ) -> Callable[..., Generator[Union[requests.Response, httpx.Response], None, None]]:
+        """
+        Returns a context manager that manages the response from the request.
+        if async_ is True, returns an async context manager, otherwise returns a regular context manager.
+
+        Usage:
+        ```python
+        manage_response = self._make_response_manager(async_=True, logging_obj=logging_obj)
+        async with manage_response(request_params) as resp:
+            ...
+        # or
+        manage_response = self._make_response_manager(async_=False, logging_obj=logging_obj)
+        with manage_response(request_params) as resp:
+            ...
+        ```
+        """
+
+        def pre_call(
+            request_params: dict,
+            input: Optional[Any] = None,
+        ):
+            request_str = (
+                f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
+                f"\turl={request_params['url']},\n"
+                f"\tjson={request_params['json']},\n"
+                f")"
+            )
+            logging_obj.pre_call(
+                input=input,
+                api_key=request_params["headers"].get("Authorization"),
+                additional_args={
+                    "complete_input_dict": request_params["json"],
+                    "request_str": request_str,
+                },
+            )
+
+        def post_call(resp, request_params):
+            logging_obj.post_call(
+                input=input,
+                api_key=request_params["headers"].get("Authorization"),
+                original_response=json.dumps(resp.json()),
+                additional_args={
+                    "status_code": resp.status_code,
+                    "complete_input_dict": request_params.get("data", request_params.get("json")),
+                },
+            )
+
+        @contextmanager
+        def _manage_response(
+            request_params: dict,
+            stream: bool = False,
+            input: Optional[Any] = None,
+            timeout: float = None,
+        ) -> Generator[requests.Response, None, None]:
+            """
+            Returns a context manager that yields the response from the request.
+            """
+            pre_call(request_params, input)
+            if timeout:
+                request_params["timeout"] = timeout
+            if stream:
+                request_params["stream"] = stream
+            try:
+                resp = requests.request(**request_params)
+                resp.raise_for_status()
+                yield resp
+            except Exception as e:
+                raise WatsonXAIError(status_code=500, message=str(e))
+            if not stream:
+                post_call(resp, request_params)
+
+        @asynccontextmanager
+        async def _manage_response_async(
+            request_params: dict,
+            stream: bool = False,
+            input: Optional[Any] = None,
+            timeout: float = None,
+        ) -> AsyncGenerator[httpx.Response, None]:
+            pre_call(request_params, input)
+            if timeout:
+                request_params["timeout"] = timeout
+            if stream:
+                request_params["stream"] = stream
+            try:
+                # async with AsyncHTTPHandler(timeout=timeout) as client:
+                self.async_handler = AsyncHTTPHandler(
+                    timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0),
+                )
+                # async_handler.client.verify = False
+                if "json" in request_params:
+                    request_params['data'] = json.dumps(request_params.pop("json", {}))
+                method = request_params.pop("method")
+                if method.upper() == "POST":
+                    resp = await self.async_handler.post(**request_params)
+                else:
+                    resp = await self.async_handler.get(**request_params)
+                yield resp
+                # await async_handler.close()
+            except Exception as e:
+                raise WatsonXAIError(status_code=500, message=str(e))
+            if not stream:
+                post_call(resp, request_params)
+
+        if async_:
+            return _manage_response_async
+        else:
+            return _manage_response


@@ -70,6 +70,7 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
+from .llms.watsonx import IBMWatsonXAI
 from .llms.prompt_templates.factory import (
     prompt_factory,
     custom_prompt,
@@ -105,6 +106,7 @@ anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()
 azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
+watsonxai = IBMWatsonXAI()

 ####### COMPLETION ENDPOINTS ################
@@ -308,6 +310,7 @@ async def acompletion(
         or custom_llm_provider == "gemini"
         or custom_llm_provider == "sagemaker"
         or custom_llm_provider == "anthropic"
+        or custom_llm_provider == "watsonx"
         or custom_llm_provider in litellm.openai_compatible_providers
     ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
         init_response = await loop.run_in_executor(None, func_with_context)
@@ -1865,7 +1868,7 @@ def completion(
             response = response
     elif custom_llm_provider == "watsonx":
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-        response = watsonx.IBMWatsonXAI().completion(
+        response = watsonxai.completion(
             model=model,
             messages=messages,
             custom_prompt_dict=custom_prompt_dict,
@@ -1876,6 +1879,7 @@ def completion(
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging,
+            acompletion=acompletion,
             timeout=timeout,
         )
         if (
@@ -2528,6 +2532,7 @@ async def aembedding(*args, **kwargs):
         or custom_llm_provider == "fireworks_ai"
         or custom_llm_provider == "ollama"
         or custom_llm_provider == "vertex_ai"
+        or custom_llm_provider == "watsonx"
     ):  # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
@@ -2980,13 +2985,14 @@ def embedding(
             aembedding=aembedding,
         )
     elif custom_llm_provider == "watsonx":
-        response = watsonx.IBMWatsonXAI().embedding(
+        response = watsonxai.embedding(
             model=model,
             input=input,
             encoding=encoding,
             logging_obj=logging,
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
+            aembedding=aembedding,
         )
     else:
         args = locals()
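With aembedding now threaded through to the provider, the async embedding path could be exercised roughly like this (a hedged sketch; the embedding model id is illustrative):

    import asyncio
    import litellm

    async def main():
        resp = await litellm.aembedding(
            model="watsonx/ibm/slate-30m-english-rtrvr",  # example model id
            input=["hello world"],
        )
        # EmbeddingResponse.data is a list of {"index", "object", "embedding"} entries.
        print(len(resp.data[0]["embedding"]))

    asyncio.run(main())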


@@ -10084,6 +10084,8 @@ class CustomStreamWrapper:
                 response_obj = self.handle_watsonx_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
+                if getattr(model_response, "usage", None) is None:
+                    model_response.usage = Usage()
                 if response_obj.get("prompt_tokens") is not None:
                     prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
                     model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
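The change above makes sure model_response.usage exists before per-chunk token counts are accumulated onto it. A small sketch of that accumulation logic (assumed names, not litellm code):

    class Usage:
        def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0):
            self.prompt_tokens = prompt_tokens
            self.completion_tokens = completion_tokens

    def accumulate_usage(usage: Usage, chunk: dict) -> Usage:
        # Each streamed chunk may report partial counts; add them to the running totals
        # instead of overwriting them.
        if chunk.get("prompt_tokens") is not None:
            usage.prompt_tokens += chunk["prompt_tokens"]
        if chunk.get("completion_tokens") is not None:
            usage.completion_tokens += chunk["completion_tokens"]
        return usage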
@@ -10497,6 +10499,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "sagemaker"
                 or self.custom_llm_provider == "gemini"
                 or self.custom_llm_provider == "cached_response"
+                or self.custom_llm_provider == "watsonx"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream: