(fix - watsonx) Fixed issues with watsonx embedding/async endpoints

This commit is contained in:
Simon Sanchez Viloria 2024-07-07 17:48:25 +02:00
parent c7338f9798
commit 06e6f52358
2 changed files with 186 additions and 119 deletions

View file

@ -1,5 +1,6 @@
from enum import Enum from enum import Enum
import json, types, time # noqa: E401 import json, types, time # noqa: E401
import asyncio
from contextlib import asynccontextmanager, contextmanager from contextlib import asynccontextmanager, contextmanager
from typing import ( from typing import (
Callable, Callable,
@ -393,6 +394,35 @@ class IBMWatsonXAI(BaseLLM):
"api_version": api_version, "api_version": api_version,
} }
def _process_text_gen_response(
    self, json_resp: dict, model_response: Union[ModelResponse, None] = None
) -> ModelResponse:
    """Convert a watsonx.ai text-generation JSON payload into a ModelResponse.

    Args:
        json_resp: Decoded JSON body returned by the text-generation endpoint.
        model_response: Response object to populate. When omitted, a new
            ``ModelResponse`` is created using the payload's ``model_id``.

    Returns:
        The populated response: generated text, finish reason, creation
        timestamp (integer epoch seconds) and token usage.

    Raises:
        WatsonXAIError: If the payload has no ``results`` entry or the
            ``results`` list is empty.
    """
    results = json_resp.get("results")
    if not results:
        # Covers both a missing key and an empty list, so callers always get
        # a WatsonXAIError rather than a bare IndexError from results[0].
        raise WatsonXAIError(
            status_code=500,
            message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
        )
    if model_response is None:
        model_response = ModelResponse(model=json_resp.get("model_id", None))

    first_result = results[0]
    generated_text = first_result["generated_text"]
    prompt_tokens = first_result["input_token_count"]
    completion_tokens = first_result["generated_token_count"]

    model_response["choices"][0]["message"]["content"] = generated_text
    # NOTE(review): finish_reason is set on the response object itself, not on
    # choices[0] -- confirm this is what ModelResponse consumers expect.
    model_response["finish_reason"] = first_result["stop_reason"]

    created_at = json_resp.get("created_at")
    if created_at:
        # datetime.fromisoformat() only accepts a trailing "Z" from Python
        # 3.11 onwards; rewrite it as an explicit UTC offset so older
        # interpreters can parse the timestamp too.
        if isinstance(created_at, str) and created_at.endswith("Z"):
            created_at = created_at[:-1] + "+00:00"
        # Cast to int so both branches store the same type.
        model_response["created"] = int(
            datetime.fromisoformat(created_at).timestamp()
        )
    else:
        model_response["created"] = int(time.time())

    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    setattr(model_response, "usage", usage)
    return model_response
def completion( def completion(
self, self,
model: str, model: str,
@ -530,6 +560,29 @@ class IBMWatsonXAI(BaseLLM):
raise e raise e
except Exception as e: except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e)) raise WatsonXAIError(status_code=500, message=str(e))
def _process_embedding_response(
    self, json_resp: dict, model_response: Union[ModelResponse, None] = None
) -> ModelResponse:
    """Fill a response object from a watsonx.ai embeddings JSON payload.

    Args:
        json_resp: Decoded JSON body returned by the embeddings endpoint.
        model_response: Response object to populate. When omitted, a new
            ``ModelResponse`` is created using the payload's ``model_id``.

    Returns:
        The response with ``object`` set to ``"list"``, ``data`` holding one
        entry per result (in payload order), and usage derived from
        ``input_token_count`` (0 when absent).
    """
    if model_response is None:
        model_response = ModelResponse(model=json_resp.get("model_id", None))

    # Build the full list first so a malformed result leaves the response
    # object untouched.
    data = [
        {
            "object": "embedding",
            "index": position,
            "embedding": item["embedding"],
        }
        for position, item in enumerate(json_resp.get("results", []))
    ]
    model_response["object"] = "list"
    model_response["data"] = data

    # Embedding calls produce no completion tokens, so total == prompt.
    token_count = json_resp.get("input_token_count", 0)
    model_response.usage = Usage(
        prompt_tokens=token_count,
        completion_tokens=0,
        total_tokens=token_count,
    )
    return model_response
def embedding( def embedding(
self, self,
@ -664,127 +717,135 @@ class IBMWatsonXAI(BaseLLM):
return [res["model_id"] for res in json_resp["resources"]] return [res["model_id"] for res in json_resp["resources"]]
class RequestManager: class RequestManager:
"""
Sends HTTP requests and reports them to the optional logging object.
Use `request` for a regular context manager or `async_request` for an async context manager; both yield the response.
Usage:
```python
request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
request_manager = RequestManager(logging_obj=logging_obj)
with request_manager.request(request_params) as resp:
...
# or
async with request_manager.async_request(request_params) as resp:
...
```
"""
def __init__(self, logging_obj=None):
self.logging_obj = logging_obj
def pre_call(
    self,
    request_params: dict,
    input: Optional[Any] = None,
    is_async: Optional[bool] = False,
):
    """Report an outgoing request to the logging object, if one is set.

    Args:
        request_params: Request keyword arguments; must contain ``method``,
            ``url`` and ``headers``, and may contain ``json``.
        input: Caller-supplied input forwarded to the logger unchanged.
        is_async: When truthy, the logged request string is prefixed with
            ``await``.
    """
    logger = self.logging_obj
    if logger is None:
        return

    await_prefix = "await " if is_async else ""
    # Human-readable rendering of the call, for logs only.
    request_str = (
        f"response = {await_prefix}{request_params['method']}(\n"
        f"\turl={request_params['url']},\n"
        f"\tjson={request_params.get('json')},\n"
        f")"
    )
    extra = {
        "complete_input_dict": request_params.get("json"),
        "request_str": request_str,
    }
    logger.pre_call(
        input=input,
        api_key=request_params["headers"].get("Authorization"),
        additional_args=extra,
    )
def post_call(self, resp, request_params, input: Optional[Any] = None):
    """Report a completed request to the logging object, if one is set.

    Args:
        resp: Response object exposing ``json()`` and ``status_code``.
        request_params: Request keyword arguments; must contain ``headers``
            and may contain ``data`` or ``json``.
        input: Caller-supplied input forwarded to the logger. Defaults to
            ``None``; existing two-argument callers keep working.
    """
    if self.logging_obj is None:
        return
    self.logging_obj.post_call(
        # Previously this method had no ``input`` parameter, so ``input``
        # resolved to the Python builtin function and that was logged.
        input=input,
        api_key=request_params["headers"].get("Authorization"),
        original_response=json.dumps(resp.json()),
        additional_args={
            "status_code": resp.status_code,
            # Prefer the serialized body ("data") when present, else "json".
            "complete_input_dict": request_params.get(
                "data", request_params.get("json")
            ),
        },
    )
@contextmanager
def request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> Generator[requests.Response, None, None]:
""" """
Returns a context manager that manages the response from the request. Returns a context manager that yields the response from the request.
if async_ is True, returns an async context manager, otherwise returns a regular context manager.
Usage:
```python
request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
request_manager = RequestManager(logging_obj=logging_obj)
async with request_manager.request(request_params) as resp:
...
# or
with request_manager.async_request(request_params) as resp:
...
```
""" """
self.pre_call(request_params, input)
def __init__(self, logging_obj=None): if timeout:
self.logging_obj = logging_obj request_params["timeout"] = timeout
if stream:
def pre_call( request_params["stream"] = stream
self, try:
request_params: dict, resp = requests.request(**request_params)
input: Optional[Any] = None, if not resp.ok:
): raise WatsonXAIError(
if self.logging_obj is None: status_code=resp.status_code,
return message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
request_str = (
f"response = {request_params['method']}(\n"
f"\turl={request_params['url']},\n"
f"\tjson={request_params.get('json')},\n"
f")"
)
self.logging_obj.pre_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
additional_args={
"complete_input_dict": request_params.get("json"),
"request_str": request_str,
},
)
def post_call(self, resp, request_params):
if self.logging_obj is None:
return
self.logging_obj.post_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()),
additional_args={
"status_code": resp.status_code,
"complete_input_dict": request_params.get(
"data", request_params.get("json")
),
},
)
@contextmanager
def request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> Generator[requests.Response, None, None]:
"""
Returns a context manager that yields the response from the request.
"""
self.pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
resp = requests.request(**request_params)
if not resp.ok:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)
@asynccontextmanager
async def async_request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> AsyncGenerator[httpx.Response, None]:
self.pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
# async with AsyncHTTPHandler(timeout=timeout) as client:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(
timeout=request_params.pop("timeout", 600.0), connect=5.0
),
) )
# async_handler.client.verify = False yield resp
if "json" in request_params: except Exception as e:
request_params["data"] = json.dumps(request_params.pop("json", {})) raise WatsonXAIError(status_code=500, message=str(e))
method = request_params.pop("method") if not stream:
self.post_call(resp, request_params)
async def async_request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> AsyncGenerator[httpx.Response, None]:
self.pre_call(request_params, input, is_async=True)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(
timeout=request_params.pop("timeout", 600.0), connect=5.0
),
)
if "json" in request_params:
request_params["data"] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
retries = 0
while retries < 3:
if method.upper() == "POST": if method.upper() == "POST":
resp = await self.async_handler.post(**request_params) resp = await self.async_handler.post(**request_params)
else: else:
resp = await self.async_handler.get(**request_params) resp = await self.async_handler.get(**request_params)
if resp.status_code not in [200, 201]: if resp.status_code in [429, 503, 504, 520]:
raise WatsonXAIError( # to handle rate limiting and service unavailable errors
status_code=resp.status_code, # see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}", await asyncio.sleep(2**retries)
) retries += 1
yield resp else:
# await async_handler.close() break
except Exception as e: if resp.is_error:
raise WatsonXAIError(status_code=500, message=str(e)) raise WatsonXAIError(
if not stream: status_code=resp.status_code,
self.post_call(resp, request_params) message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
# await async_handler.close()
except Exception as e:
raise e
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)

View file

@ -108,6 +108,7 @@ from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion from .llms.predibase import PredibaseChatCompletion
from .llms.watsonx import IBMWatsonXAI
from .llms.prompt_templates.factory import ( from .llms.prompt_templates.factory import (
custom_prompt, custom_prompt,
function_call_prompt, function_call_prompt,
@ -152,6 +153,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM() bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM() bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM() vertex_chat_completion = VertexLLM()
watsonxai = IBMWatsonXAI()
####### COMPLETION ENDPOINTS ################ ####### COMPLETION ENDPOINTS ################
@ -369,6 +371,7 @@ async def acompletion(
or custom_llm_provider == "bedrock" or custom_llm_provider == "bedrock"
or custom_llm_provider == "databricks" or custom_llm_provider == "databricks"
or custom_llm_provider == "clarifai" or custom_llm_provider == "clarifai"
or custom_llm_provider == "watsonx"
or custom_llm_provider in litellm.openai_compatible_providers or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context) init_response = await loop.run_in_executor(None, func_with_context)
@ -2352,7 +2355,7 @@ def completion(
response = response response = response
elif custom_llm_provider == "watsonx": elif custom_llm_provider == "watsonx":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = watsonx.IBMWatsonXAI().completion( response = watsonxai.completion(
model=model, model=model,
messages=messages, messages=messages,
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
@ -2364,6 +2367,7 @@ def completion(
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
timeout=timeout, # type: ignore timeout=timeout, # type: ignore
acompletion=acompletion,
) )
if ( if (
"stream" in optional_params "stream" in optional_params
@ -3030,6 +3034,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
or custom_llm_provider == "ollama" or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "databricks" or custom_llm_provider == "databricks"
or custom_llm_provider == "watsonx"
): # currently implemented aiohttp calls for just azure and openai, soon all. ): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally # Await normally
init_response = await loop.run_in_executor(None, func_with_context) init_response = await loop.run_in_executor(None, func_with_context)
@ -3537,13 +3542,14 @@ def embedding(
aembedding=aembedding, aembedding=aembedding,
) )
elif custom_llm_provider == "watsonx": elif custom_llm_provider == "watsonx":
response = watsonx.IBMWatsonXAI().embedding( response = watsonxai.embedding(
model=model, model=model,
input=input, input=input,
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
optional_params=optional_params, optional_params=optional_params,
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
aembedding=aembedding,
) )
else: else:
args = locals() args = locals()