Merge pull request #4586 from simonsanvil/main

Fix bugs with watsonx embedding/async endpoints
Krish Dholakia 2024-07-07 20:20:39 -07:00 committed by GitHub
commit 40a045cb72
4 changed files with 272 additions and 181 deletions
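For orientation, here is a minimal usage sketch (not part of the diff) of the async watsonx paths this commit repairs. The model ids, prompt, and input are placeholders, and watsonx credentials are assumed to be configured via environment variables:

```python
# Minimal sketch of the async embedding/completion calls fixed by this PR.
# Model ids and inputs are placeholders, not values taken from the diff.
import asyncio
import litellm

async def main():
    # async embedding: previously fell through to the sync code path
    emb = await litellm.aembedding(
        model="watsonx/ibm/slate-30m-english-rtrvr",
        input=["good morning from litellm"],
    )
    print(emb.usage)

    # async completion: acompletion is now forwarded to the watsonx handler
    resp = await litellm.acompletion(
        model="watsonx/ibm/granite-13b-chat-v2",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(resp.choices[0].message.content)

asyncio.run(main())
```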

litellm/llms/watsonx.py

@@ -1,5 +1,7 @@
-from enum import Enum
import json, types, time # noqa: E401
+import asyncio
+from datetime import datetime
+from enum import Enum
from contextlib import asynccontextmanager, contextmanager
from typing import (
Callable,
@@ -285,7 +287,10 @@ class IBMWatsonXAI(BaseLLM):
)
def _get_api_params(
-self, params: dict, print_verbose: Optional[Callable] = None
+self,
+params: dict,
+print_verbose: Optional[Callable] = None,
+generate_token: Optional[bool] = True,
) -> dict:
"""
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
@@ -365,7 +370,7 @@ class IBMWatsonXAI(BaseLLM):
status_code=401,
message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
)
-if token is None and api_key is not None:
+if token is None and api_key is not None and generate_token:
# generate the auth token
if print_verbose is not None:
print_verbose("Generating IAM token for Watsonx.ai")
@@ -393,6 +398,35 @@
"api_version": api_version,
}
def _process_text_gen_response(
self, json_resp: dict, model_response: Union[ModelResponse, None] = None
) -> ModelResponse:
if "results" not in json_resp:
raise WatsonXAIError(
status_code=500,
message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
)
if model_response is None:
model_response = ModelResponse(model=json_resp.get("model_id", None))
generated_text = json_resp["results"][0]["generated_text"]
prompt_tokens = json_resp["results"][0]["input_token_count"]
completion_tokens = json_resp["results"][0]["generated_token_count"]
model_response["choices"][0]["message"]["content"] = generated_text
model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
if json_resp.get("created_at"):
model_response["created"] = datetime.fromisoformat(
json_resp["created_at"]
).timestamp()
else:
model_response["created"] = int(time.time())
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def completion(
self,
model: str,
@@ -406,7 +440,7 @@ class IBMWatsonXAI(BaseLLM):
acompletion=None,
litellm_params=None,
logger_fn=None,
timeout=None,
timeout=None
):
"""
Send a text generation request to the IBM Watsonx.ai API.
@@ -426,27 +460,7 @@ class IBMWatsonXAI(BaseLLM):
prompt = convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
-def process_text_gen_response(json_resp: dict) -> ModelResponse:
-if "results" not in json_resp:
-raise WatsonXAIError(
-status_code=500,
-message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
-)
-generated_text = json_resp["results"][0]["generated_text"]
-prompt_tokens = json_resp["results"][0]["input_token_count"]
-completion_tokens = json_resp["results"][0]["generated_token_count"]
-model_response["choices"][0]["message"]["content"] = generated_text
-model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
-model_response["created"] = int(time.time())
-model_response["model"] = model
-usage = Usage(
-prompt_tokens=prompt_tokens,
-completion_tokens=completion_tokens,
-total_tokens=prompt_tokens + completion_tokens,
-)
-setattr(model_response, "usage", usage)
-return model_response
+model_response["model"] = model
def process_stream_response(
stream_resp: Union[Iterator[str], AsyncIterator],
@@ -470,7 +484,7 @@
) as resp:
json_resp = resp.json()
-return process_text_gen_response(json_resp)
+return self._process_text_gen_response(json_resp, model_response)
async def handle_text_request_async(request_params: dict) -> ModelResponse:
async with self.request_manager.async_request(
@@ -479,7 +493,7 @@
timeout=timeout,
) as resp:
json_resp = resp.json()
-return process_text_gen_response(json_resp)
+return self._process_text_gen_response(json_resp, model_response)
def handle_stream_request(request_params: dict) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled
@@ -493,7 +507,9 @@
streamwrapper = process_stream_response(resp.iter_lines())
return streamwrapper
-async def handle_stream_request_async(request_params: dict) -> litellm.CustomStreamWrapper:
+async def handle_stream_request_async(
+request_params: dict,
+) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
async with self.request_manager.async_request(
@@ -520,7 +536,7 @@
elif stream:
# streaming text generation
return handle_stream_request(req_params)
-elif (acompletion is True):
+elif acompletion is True:
# async text generation
return handle_text_request_async(req_params)
else:
@@ -531,6 +547,29 @@
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
def _process_embedding_response(self, json_resp: dict, model_response:Union[ModelResponse,None]=None) -> ModelResponse:
if model_response is None:
model_response = ModelResponse(model=json_resp.get("model_id", None))
results = json_resp.get("results", [])
embedding_response = []
for idx, result in enumerate(results):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": result["embedding"],
}
)
model_response["object"] = "list"
model_response["data"] = embedding_response
input_tokens = json_resp.get("input_token_count", 0)
model_response.usage = Usage(
prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens,
)
return model_response
def embedding(
self,
model: str,
@@ -540,7 +579,8 @@ class IBMWatsonXAI(BaseLLM):
model_response=None,
optional_params=None,
encoding=None,
-aembedding=None,
+print_verbose=None,
+aembedding=None
):
"""
Send a text embedding request to the IBM Watsonx.ai API.
@@ -553,6 +593,8 @@
if k not in optional_params:
optional_params[k] = v
model_response['model'] = model
# Load auth variables from environment variables
if isinstance(input, str):
input = [input]
@@ -584,43 +626,23 @@
}
request_manager = RequestManager(logging_obj)
-def process_embedding_response(json_resp: dict) -> ModelResponse:
-results = json_resp.get("results", [])
-embedding_response = []
-for idx, result in enumerate(results):
-embedding_response.append(
-{
-"object": "embedding",
-"index": idx,
-"embedding": result["embedding"],
-}
-)
-model_response["object"] = "list"
-model_response["data"] = embedding_response
-model_response["model"] = model
-input_tokens = json_resp.get("input_token_count", 0)
-model_response.usage = Usage(
-prompt_tokens=input_tokens,
-completion_tokens=0,
-total_tokens=input_tokens,
-)
-return model_response
def handle_embedding(request_params: dict) -> ModelResponse:
with request_manager.request(request_params, input=input) as resp:
json_resp = resp.json()
-return process_embedding_response(json_resp)
+return self._process_embedding_response(json_resp, model_response)
async def handle_aembedding(request_params: dict) -> ModelResponse:
-async with request_manager.async_request(request_params, input=input) as resp:
+async with request_manager.async_request(
+request_params, input=input
+) as resp:
json_resp = resp.json()
-return process_embedding_response(json_resp)
+return self._process_embedding_response(json_resp, model_response)
try:
if aembedding is True:
-return handle_embedding(req_params)
-else:
return handle_aembedding(req_params)
+else:
+return handle_embedding(req_params)
except WatsonXAIError as e:
raise e
except Exception as e:
@@ -664,127 +686,135 @@ class IBMWatsonXAI(BaseLLM):
return [res["model_id"] for res in json_resp["resources"]]
class RequestManager:
"""
A class to handle sync/async HTTP requests to the IBM Watsonx.ai API.
Usage:
```python
request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
request_manager = RequestManager(logging_obj=logging_obj)
with request_manager.request(request_params) as resp:
...
# or
async with request_manager.async_request(request_params) as resp:
...
```
"""
def __init__(self, logging_obj=None):
self.logging_obj = logging_obj
def pre_call(
self,
request_params: dict,
input: Optional[Any] = None,
is_async: Optional[bool] = False,
):
if self.logging_obj is None:
return
request_str = (
f"response = {'await ' if is_async else ''}{request_params['method']}(\n"
f"\turl={request_params['url']},\n"
f"\tjson={request_params.get('json')},\n"
f")"
)
self.logging_obj.pre_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
additional_args={
"complete_input_dict": request_params.get("json"),
"request_str": request_str,
},
)
def post_call(self, resp, request_params):
if self.logging_obj is None:
return
self.logging_obj.post_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()),
additional_args={
"status_code": resp.status_code,
"complete_input_dict": request_params.get(
"data", request_params.get("json")
),
},
)
@contextmanager
def request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> Generator[requests.Response, None, None]:
"""
Returns a context manager that manages the response from the request.
if async_ is True, returns an async context manager, otherwise returns a regular context manager.
Usage:
```python
request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
request_manager = RequestManager(logging_obj=logging_obj)
async with request_manager.request(request_params) as resp:
...
# or
with request_manager.async_request(request_params) as resp:
...
```
Returns a context manager that yields the response from the request.
"""
def __init__(self, logging_obj=None):
self.logging_obj = logging_obj
def pre_call(
self,
request_params: dict,
input: Optional[Any] = None,
):
if self.logging_obj is None:
return
request_str = (
f"response = {request_params['method']}(\n"
f"\turl={request_params['url']},\n"
f"\tjson={request_params.get('json')},\n"
f")"
)
self.logging_obj.pre_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
additional_args={
"complete_input_dict": request_params.get("json"),
"request_str": request_str,
},
)
def post_call(self, resp, request_params):
if self.logging_obj is None:
return
self.logging_obj.post_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()),
additional_args={
"status_code": resp.status_code,
"complete_input_dict": request_params.get(
"data", request_params.get("json")
),
},
)
@contextmanager
def request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> Generator[requests.Response, None, None]:
"""
Returns a context manager that yields the response from the request.
"""
self.pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
resp = requests.request(**request_params)
if not resp.ok:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)
@asynccontextmanager
async def async_request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> AsyncGenerator[httpx.Response, None]:
self.pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
# async with AsyncHTTPHandler(timeout=timeout) as client:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(
timeout=request_params.pop("timeout", 600.0), connect=5.0
),
self.pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
resp = requests.request(**request_params)
if not resp.ok:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
# async_handler.client.verify = False
if "json" in request_params:
request_params["data"] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)
@asynccontextmanager
async def async_request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> AsyncGenerator[httpx.Response, None]:
self.pre_call(request_params, input, is_async=True)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(
timeout=request_params.pop("timeout", 600.0), connect=5.0
),
)
if "json" in request_params:
request_params["data"] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
retries = 0
while retries < 3:
if method.upper() == "POST":
resp = await self.async_handler.post(**request_params)
else:
resp = await self.async_handler.get(**request_params)
if resp.status_code not in [200, 201]:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
# await async_handler.close()
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)
if resp.status_code in [429, 503, 504, 520]:
# to handle rate limiting and service unavailable errors
# see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload
await asyncio.sleep(2**retries)
retries += 1
else:
break
if resp.is_error:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
# await async_handler.close()
except Exception as e:
raise e
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)
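The rewritten `async_request` above retries on 429/503/504/520 with exponential backoff before yielding the response. A standalone sketch of that pattern, with `post_with_backoff` and `RETRYABLE` as illustrative names rather than code from the PR:

```python
# Sketch of the retry-with-backoff loop used in async_request;
# the client and error type are stand-ins for AsyncHTTPHandler / WatsonXAIError.
import asyncio
import httpx

RETRYABLE = {429, 503, 504, 520}

async def post_with_backoff(url: str, payload: dict, max_retries: int = 3) -> httpx.Response:
    async with httpx.AsyncClient(timeout=600.0) as client:
        for attempt in range(max_retries):
            resp = await client.post(url, json=payload)
            if resp.status_code in RETRYABLE and attempt < max_retries - 1:
                # back off 1s, 2s, 4s, ... before retrying, mirroring 2**retries
                await asyncio.sleep(2 ** attempt)
                continue
            break
        if resp.is_error:
            raise RuntimeError(f"Error {resp.status_code} ({resp.reason_phrase}): {resp.text}")
        return resp
```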

litellm/main.py

@@ -108,6 +108,7 @@ from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion
from .llms.watsonx import IBMWatsonXAI
from .llms.prompt_templates.factory import (
custom_prompt,
function_call_prompt,
@@ -152,6 +153,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
watsonxai = IBMWatsonXAI()
####### COMPLETION ENDPOINTS ################
@@ -369,6 +371,7 @@ async def acompletion(
or custom_llm_provider == "bedrock"
or custom_llm_provider == "databricks"
or custom_llm_provider == "clarifai"
or custom_llm_provider == "watsonx"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
@@ -2352,7 +2355,7 @@ def completion(
response = response
elif custom_llm_provider == "watsonx":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-response = watsonx.IBMWatsonXAI().completion(
+response = watsonxai.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
@@ -2364,6 +2367,7 @@
encoding=encoding,
logging_obj=logging,
timeout=timeout, # type: ignore
acompletion=acompletion,
)
if (
"stream" in optional_params
@@ -3030,6 +3034,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "databricks"
or custom_llm_provider == "watsonx"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
@@ -3537,13 +3542,14 @@ def embedding(
aembedding=aembedding,
)
elif custom_llm_provider == "watsonx":
-response = watsonx.IBMWatsonXAI().embedding(
+response = watsonxai.embedding(
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
aembedding=aembedding,
)
else:
args = locals()
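In `acompletion`/`aembedding`, adding `"watsonx"` to the async-capable providers means the provider entry point runs in an executor and, when it returns a coroutine (because `acompletion=True`/`aembedding=True` was forwarded), that coroutine is awaited. A simplified sketch of this dispatch, with all names invented for illustration:

```python
# Sketch (assumed, simplified) of how an async wrapper consumes a provider that
# may return either a finished response or a coroutine, as main.py does for watsonx.
import asyncio
from functools import partial

def provider_completion(prompt: str, acompletion: bool = False):
    # Stand-in for watsonxai.completion(): returns a coroutine on the async path.
    async def _do_async():
        await asyncio.sleep(0)  # placeholder for the non-blocking watsonx HTTP call
        return f"async: {prompt}"
    return _do_async() if acompletion else f"sync: {prompt}"

async def acompletion(prompt: str) -> str:
    loop = asyncio.get_running_loop()
    func = partial(provider_completion, prompt, acompletion=True)
    # Run the (possibly blocking) entry point in a thread, then await the coroutine it returns.
    init_response = await loop.run_in_executor(None, func)
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response

print(asyncio.run(acompletion("hello")))
```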

litellm/tests/test_embedding.py

@@ -11,6 +11,7 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
from litellm import embedding, completion, completion_cost
from unittest.mock import MagicMock, patch
litellm.set_verbose = False
@@ -484,14 +485,67 @@ def test_mistral_embeddings():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="local test")
def test_watsonx_embeddings():
def mock_wx_embed_request(method:str, url:str, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {
"model_id": "ibm/slate-30m-english-rtrvr",
"created_at": "2024-01-01T00:00:00.00Z",
"results": [ {"embedding": [0.0]*254} ],
"input_token_count": 8
}
return mock_response
try:
litellm.set_verbose = True
-response = litellm.embedding(
-model="watsonx/ibm/slate-30m-english-rtrvr",
-input=["good morning from litellm"],
-)
+with patch("requests.request", side_effect=mock_wx_embed_request):
+response = litellm.embedding(
+model="watsonx/ibm/slate-30m-english-rtrvr",
+input=["good morning from litellm"],
+token="secret-token"
+)
print(f"response: {response}")
assert isinstance(response.usage, litellm.Usage)
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_watsonx_aembeddings():
def mock_async_client(*args, **kwargs):
mocked_client = MagicMock()
async def mock_send(request, *args, stream: bool = False, **kwags):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {
"model_id": "ibm/slate-30m-english-rtrvr",
"created_at": "2024-01-01T00:00:00.00Z",
"results": [ {"embedding": [0.0]*254} ],
"input_token_count": 8
}
mock_response.is_error = False
return mock_response
mocked_client.send = mock_send
return mocked_client
try:
litellm.set_verbose = True
with patch("httpx.AsyncClient", side_effect=mock_async_client):
response = await litellm.aembedding(
model="watsonx/ibm/slate-30m-english-rtrvr",
input=["good morning from litellm"],
token="secret-token"
)
print(f"response: {response}")
assert isinstance(response.usage, litellm.Usage)
except litellm.RateLimitError as e:

litellm/utils.py

@@ -9736,6 +9736,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "predibase"
or self.custom_llm_provider == "databricks"
or self.custom_llm_provider == "bedrock"
or self.custom_llm_provider == "watsonx"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
):
async for chunk in self.completion_stream:
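With `"watsonx"` added to the async-capable providers in `CustomStreamWrapper`, a streamed async call can be consumed with `async for`. A usage sketch under the same assumptions as above (placeholder model id, credentials taken from the environment):

```python
# Usage sketch (assumed) of async streaming enabled by the utils.py change.
import asyncio
import litellm

async def stream_watsonx():
    stream = await litellm.acompletion(
        model="watsonx/ibm/granite-13b-chat-v2",
        messages=[{"role": "user", "content": "say hello"}],
        stream=True,
    )
    async for chunk in stream:
        # each chunk carries an incremental delta in the OpenAI-style format
        print(chunk.choices[0].delta.content or "", end="", flush=True)

asyncio.run(stream_watsonx())
```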