forked from phoenix/litellm-mirror
Merge pull request #3561 from simonsanvil/feature/watsonx-integration
(fix) Fixed linting and other bugs with watsonx provider
This commit is contained in:
commit
d33e49411d
3 changed files with 310 additions and 101 deletions
|
@ -1,12 +1,26 @@
|
|||
from enum import Enum
|
||||
import json, types, time # noqa: E401
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Dict, Optional, Any, Union, List
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
AsyncGenerator,
|
||||
Iterator,
|
||||
AsyncIterator,
|
||||
Optional,
|
||||
Any,
|
||||
Union,
|
||||
List,
|
||||
ContextManager,
|
||||
AsyncContextManager,
|
||||
)
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, get_secret, Usage
|
||||
from litellm.utils import ModelResponse, Usage, get_secret
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
from .base import BaseLLM
|
||||
from .prompt_templates import factory as ptf
|
||||
|
@ -188,11 +202,12 @@ class WatsonXAIEndpoint(str, Enum):
|
|||
)
|
||||
EMBEDDINGS = "/ml/v1/text/embeddings"
|
||||
PROMPTS = "/ml/v1/prompts"
|
||||
AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"
|
||||
|
||||
|
||||
class IBMWatsonXAI(BaseLLM):
|
||||
"""
|
||||
Class to interface with IBM Watsonx.ai API for text generation and embeddings.
|
||||
Class to interface with IBM watsonx.ai API for text generation and embeddings.
|
||||
|
||||
Reference: https://cloud.ibm.com/apidocs/watsonx-ai
|
||||
"""
|
||||
|
@ -343,7 +358,7 @@ class IBMWatsonXAI(BaseLLM):
|
|||
)
|
||||
if token is None and api_key is not None:
|
||||
# generate the auth token
|
||||
if print_verbose:
|
||||
if print_verbose is not None:
|
||||
print_verbose("Generating IAM token for Watsonx.ai")
|
||||
token = self.generate_iam_token(api_key)
|
||||
elif token is None and api_key is None:
|
||||
|
@ -378,10 +393,11 @@ class IBMWatsonXAI(BaseLLM):
|
|||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
litellm_params: Optional[dict] = None,
|
||||
optional_params=None,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
timeout: Optional[float] = None,
|
||||
timeout=None,
|
||||
):
|
||||
"""
|
||||
Send a text generation request to the IBM Watsonx.ai API.
|
||||
|
@ -402,12 +418,12 @@ class IBMWatsonXAI(BaseLLM):
|
|||
model, messages, provider, custom_prompt_dict
|
||||
)
|
||||
|
||||
def process_text_request(request_params: dict) -> ModelResponse:
|
||||
with self._manage_response(
|
||||
request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
|
||||
def process_text_gen_response(json_resp: dict) -> ModelResponse:
|
||||
if "results" not in json_resp:
|
||||
raise WatsonXAIError(
|
||||
status_code=500,
|
||||
message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
|
||||
)
|
||||
generated_text = json_resp["results"][0]["generated_text"]
|
||||
prompt_tokens = json_resp["results"][0]["input_token_count"]
|
||||
completion_tokens = json_resp["results"][0]["generated_token_count"]
|
||||
|
@ -415,36 +431,70 @@ class IBMWatsonXAI(BaseLLM):
|
|||
model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
setattr(
|
||||
model_response,
|
||||
"usage",
|
||||
Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
def process_stream_request(
|
||||
request_params: dict,
|
||||
def process_stream_response(
|
||||
stream_resp: Union[Iterator[str], AsyncIterator],
|
||||
) -> litellm.CustomStreamWrapper:
|
||||
streamwrapper = litellm.CustomStreamWrapper(
|
||||
stream_resp,
|
||||
model=model,
|
||||
custom_llm_provider="watsonx",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
# create the function to manage the request to watsonx.ai
|
||||
self.request_manager = RequestManager(logging_obj)
|
||||
|
||||
def handle_text_request(request_params: dict) -> ModelResponse:
|
||||
with self.request_manager.request(
|
||||
request_params,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
|
||||
return process_text_gen_response(json_resp)
|
||||
|
||||
async def handle_text_request_async(request_params: dict) -> ModelResponse:
|
||||
async with self.request_manager.async_request(
|
||||
request_params,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
return process_text_gen_response(json_resp)
|
||||
|
||||
def handle_stream_request(request_params: dict) -> litellm.CustomStreamWrapper:
|
||||
# stream the response - generated chunks will be handled
|
||||
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
|
||||
with self._manage_response(
|
||||
with self.request_manager.request(
|
||||
request_params,
|
||||
logging_obj=logging_obj,
|
||||
stream=True,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
response = litellm.CustomStreamWrapper(
|
||||
resp.iter_lines(),
|
||||
model=model,
|
||||
custom_llm_provider="watsonx",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return response
|
||||
streamwrapper = process_stream_response(resp.iter_lines())
|
||||
return streamwrapper
|
||||
|
||||
async def handle_stream_request_async(request_params: dict) -> litellm.CustomStreamWrapper:
|
||||
# stream the response - generated chunks will be handled
|
||||
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
|
||||
async with self.request_manager.async_request(
|
||||
request_params,
|
||||
stream=True,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
streamwrapper = process_stream_response(resp.aiter_lines())
|
||||
return streamwrapper
|
||||
|
||||
try:
|
||||
## Get the response from the model
|
||||
|
@ -455,10 +505,18 @@ class IBMWatsonXAI(BaseLLM):
|
|||
optional_params=optional_params,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if stream:
|
||||
return process_stream_request(req_params)
|
||||
if stream and (acompletion is True):
|
||||
# stream and async text generation
|
||||
return handle_stream_request_async(req_params)
|
||||
elif stream:
|
||||
# streaming text generation
|
||||
return handle_stream_request(req_params)
|
||||
elif (acompletion is True):
|
||||
# async text generation
|
||||
return handle_text_request_async(req_params)
|
||||
else:
|
||||
return process_text_request(req_params)
|
||||
# regular text generation
|
||||
return handle_text_request(req_params)
|
||||
except WatsonXAIError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
|
@ -473,6 +531,7 @@ class IBMWatsonXAI(BaseLLM):
|
|||
model_response=None,
|
||||
optional_params=None,
|
||||
encoding=None,
|
||||
aembedding=None,
|
||||
):
|
||||
"""
|
||||
Send a text embedding request to the IBM Watsonx.ai API.
|
||||
|
@ -507,9 +566,6 @@ class IBMWatsonXAI(BaseLLM):
|
|||
}
|
||||
request_params = dict(version=api_params["api_version"])
|
||||
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
|
||||
# request = httpx.Request(
|
||||
# "POST", url, headers=headers, json=payload, params=request_params
|
||||
# )
|
||||
req_params = {
|
||||
"method": "POST",
|
||||
"url": url,
|
||||
|
@ -517,25 +573,49 @@ class IBMWatsonXAI(BaseLLM):
|
|||
"json": payload,
|
||||
"params": request_params,
|
||||
}
|
||||
with self._manage_response(
|
||||
req_params, logging_obj=logging_obj, input=input
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
request_manager = RequestManager(logging_obj)
|
||||
|
||||
results = json_resp.get("results", [])
|
||||
embedding_response = []
|
||||
for idx, result in enumerate(results):
|
||||
embedding_response.append(
|
||||
{"object": "embedding", "index": idx, "embedding": result["embedding"]}
|
||||
def process_embedding_response(json_resp: dict) -> ModelResponse:
|
||||
results = json_resp.get("results", [])
|
||||
embedding_response = []
|
||||
for idx, result in enumerate(results):
|
||||
embedding_response.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": result["embedding"],
|
||||
}
|
||||
)
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = embedding_response
|
||||
model_response["model"] = model
|
||||
input_tokens = json_resp.get("input_token_count", 0)
|
||||
model_response.usage = Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=input_tokens,
|
||||
)
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = embedding_response
|
||||
model_response["model"] = model
|
||||
input_tokens = json_resp.get("input_token_count", 0)
|
||||
model_response.usage = Usage(
|
||||
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
|
||||
)
|
||||
return model_response
|
||||
return model_response
|
||||
|
||||
def handle_embedding(request_params: dict) -> ModelResponse:
|
||||
with request_manager.request(request_params, input=input) as resp:
|
||||
json_resp = resp.json()
|
||||
return process_embedding_response(json_resp)
|
||||
|
||||
async def handle_aembedding(request_params: dict) -> ModelResponse:
|
||||
async with request_manager.async_request(request_params, input=input) as resp:
|
||||
json_resp = resp.json()
|
||||
return process_embedding_response(json_resp)
|
||||
|
||||
try:
|
||||
if aembedding is True:
|
||||
return handle_embedding(req_params)
|
||||
else:
|
||||
return handle_aembedding(req_params)
|
||||
except WatsonXAIError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
|
||||
def generate_iam_token(self, api_key=None, **params):
|
||||
headers = {}
|
||||
|
@ -558,52 +638,144 @@ class IBMWatsonXAI(BaseLLM):
|
|||
self.token = iam_access_token
|
||||
return iam_access_token
|
||||
|
||||
@contextmanager
|
||||
def _manage_response(
|
||||
self,
|
||||
request_params: dict,
|
||||
logging_obj: Any,
|
||||
stream: bool = False,
|
||||
input: Optional[Any] = None,
|
||||
timeout: Optional[float] = None,
|
||||
):
|
||||
request_str = (
|
||||
f"response = {request_params['method']}(\n"
|
||||
f"\turl={request_params['url']},\n"
|
||||
f"\tjson={request_params['json']},\n"
|
||||
f")"
|
||||
)
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=request_params["headers"].get("Authorization"),
|
||||
additional_args={
|
||||
"complete_input_dict": request_params["json"],
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
if timeout:
|
||||
request_params["timeout"] = timeout
|
||||
try:
|
||||
if stream:
|
||||
resp = requests.request(
|
||||
**request_params,
|
||||
stream=True,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
yield resp
|
||||
else:
|
||||
resp = requests.request(**request_params)
|
||||
resp.raise_for_status()
|
||||
yield resp
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
if not stream:
|
||||
logging_obj.post_call(
|
||||
def get_available_models(self, *, ids_only: bool = True, **params):
    """
    Query the foundation-model-specs endpoint and return the available models.

    `params` is forwarded to `self._get_api_params` to resolve the url, token
    and api version. When `ids_only` is true (the default) a list with the
    `model_id` of every entry under `"resources"` is returned; otherwise the
    decoded JSON response is returned unchanged.
    """
    api_params = self._get_api_params(params)
    auth_headers = {
        "Authorization": f"Bearer {api_params['token']}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    query = {"version": api_params["api_version"]}
    # Plain concatenation (not an f-string) so the str-valued enum member
    # contributes its value.
    endpoint = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
    req_params = {
        "method": "GET",
        "url": endpoint,
        "headers": auth_headers,
        "params": query,
    }
    # No logging object: this lookup is not reported to the litellm logger.
    manager = RequestManager(logging_obj=None)
    with manager.request(req_params) as resp:
        json_resp = resp.json()
    if ids_only:
        return [res["model_id"] for res in json_resp["resources"]]
    return json_resp
|
||||
|
||||
class RequestManager:
    """
    Send HTTP requests to watsonx.ai and report them to a litellm logging object.

    `request` is a synchronous context manager built on `requests.request`;
    `async_request` is the asynchronous counterpart built on `AsyncHTTPHandler`.
    Both call `logging_obj.pre_call` before sending and, for non-streaming
    requests, `logging_obj.post_call` once the `with` body has finished.
    When `logging_obj` is None, logging is skipped entirely.

    Usage:
    ```python
    request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
    request_manager = RequestManager(logging_obj=logging_obj)
    with request_manager.request(request_params) as resp:
        ...
    # or
    async with request_manager.async_request(request_params) as resp:
        ...
    ```
    """

    def __init__(self, logging_obj=None):
        self.logging_obj = logging_obj

    def pre_call(
        self,
        request_params: dict,
        input: Optional[Any] = None,
    ):
        """Report the outgoing request to the logging object, if there is one."""
        if self.logging_obj is None:
            return
        request_str = (
            f"response = {request_params['method']}(\n"
            f"\turl={request_params['url']},\n"
            f"\tjson={request_params.get('json')},\n"
            f")"
        )
        self.logging_obj.pre_call(
            input=input,
            api_key=request_params["headers"].get("Authorization"),
            additional_args={
                "complete_input_dict": request_params.get("json"),
                "request_str": request_str,
            },
        )

    def post_call(self, resp, request_params, input: Optional[Any] = None):
        """
        Report the received response to the logging object, if there is one.

        `input` is the original request input; it defaults to None so existing
        two-argument callers keep working. (Previously the builtin `input`
        function was forwarded to the logger by accident.)
        """
        if self.logging_obj is None:
            return
        self.logging_obj.post_call(
            input=input,
            api_key=request_params["headers"].get("Authorization"),
            original_response=json.dumps(resp.json()),
            additional_args={
                "status_code": resp.status_code,
                # The async path serialises the payload under "data"; the
                # sync path keeps it under "json".
                "complete_input_dict": request_params.get(
                    "data", request_params.get("json")
                ),
            },
        )

    @contextmanager
    def request(
        self,
        request_params: dict,
        stream: bool = False,
        input: Optional[Any] = None,
        timeout=None,
    ) -> "Generator[requests.Response, None, None]":
        """
        Returns a context manager that yields the response from the request.

        Raises WatsonXAIError with the response's status code when the response
        is not OK, and WatsonXAIError(500) for any other failure, including
        exceptions raised inside the `with` body.
        """
        self.pre_call(request_params, input)
        # Work on a shallow copy so the caller's dict is left untouched.
        params = dict(request_params)
        if timeout:
            params["timeout"] = timeout
        if stream:
            params["stream"] = stream
        try:
            resp = requests.request(**params)
            if not resp.ok:
                raise WatsonXAIError(
                    status_code=resp.status_code,
                    message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
                )
            yield resp
        except WatsonXAIError:
            # Keep the real status code instead of re-wrapping it as a 500.
            raise
        except Exception as e:
            raise WatsonXAIError(status_code=500, message=str(e)) from e
        if not stream:
            self.post_call(resp, params, input)

    @asynccontextmanager
    async def async_request(
        self,
        request_params: dict,
        stream: bool = False,
        input: Optional[Any] = None,
        timeout=None,
    ) -> "AsyncGenerator[httpx.Response, None]":
        """
        Returns an async context manager that yields the response from the request.

        A "json" payload is serialised and sent as "data". "POST" goes through
        the handler's `post`; any other method goes through its `get`.
        Error handling matches `request`.
        """
        self.pre_call(request_params, input)
        # Work on a shallow copy: the keys popped below must not disappear
        # from the caller's dict (e.g. when the same params are reused).
        params = dict(request_params)
        if timeout:
            params["timeout"] = timeout
        if stream:
            params["stream"] = stream
        try:
            # NOTE(review): the handler is not closed here; streamed responses
            # appear to be consumed after this context exits - confirm before
            # adding a close().
            self.async_handler = AsyncHTTPHandler(
                timeout=httpx.Timeout(
                    timeout=params.pop("timeout", 600.0), connect=5.0
                ),
            )
            if "json" in params:
                params["data"] = json.dumps(params.pop("json", {}))
            method = params.pop("method")
            if method.upper() == "POST":
                resp = await self.async_handler.post(**params)
            else:
                resp = await self.async_handler.get(**params)
            if resp.status_code not in [200, 201]:
                # httpx.Response exposes the reason as `reason_phrase`.
                raise WatsonXAIError(
                    status_code=resp.status_code,
                    message=f"Error {resp.status_code} ({resp.reason_phrase}): {resp.text}",
                )
            yield resp
        except WatsonXAIError:
            # Keep the real status code instead of re-wrapping it as a 500.
            raise
        except Exception as e:
            raise WatsonXAIError(status_code=500, message=str(e)) from e
        if not stream:
            self.post_call(resp, params, input)
|
|
@ -3236,6 +3236,24 @@ def test_completion_watsonx():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
def test_completion_stream_watsonx():
    """
    Stream a short watsonx completion and print each chunk as it arrives.

    A `litellm.APIError` is tolerated; any other exception fails the test.
    """
    litellm.set_verbose = True
    model_name = "watsonx/ibm/granite-13b-chat-v2"
    request_kwargs = {
        "model": model_name,
        "messages": messages,
        "stop": ["stop"],
        "max_tokens": 20,
        "stream": True,
    }
    try:
        stream = completion(**request_kwargs)
        for part in stream:
            print(part)
    except litellm.APIError:
        # Deliberately ignored: an API-side error does not fail this test.
        pass
    except Exception as err:
        pytest.fail(f"Error occurred: {err}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider, model, project, region_name, token",
|
||||
|
@ -3300,6 +3318,25 @@ async def test_acompletion_watsonx():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
@pytest.mark.asyncio
async def test_acompletion_stream_watsonx():
    """
    Stream a watsonx completion through `litellm.acompletion` and print each chunk.

    Any exception raised while requesting or iterating fails the test.
    """
    litellm.set_verbose = True
    model_name = "watsonx/ibm/granite-13b-chat-v2"
    print("testing watsonx")
    request_kwargs = {
        "model": model_name,
        "messages": messages,
        "temperature": 0.2,
        "max_tokens": 80,
        "stream": True,
    }
    try:
        stream = await litellm.acompletion(**request_kwargs)
        # No assertions on content: the test only checks that streaming completes.
        async for part in stream:
            print(part)
    except Exception as err:
        pytest.fail(f"Error occurred: {err}")
|
||||
|
||||
|
||||
# test_completion_palm_stream()
|
||||
|
||||
|
|
|
@ -10430,7 +10430,7 @@ class CustomStreamWrapper:
|
|||
response = chunk.replace("data: ", "").strip()
|
||||
parsed_response = json.loads(response)
|
||||
else:
|
||||
return {"text": "", "is_finished": False}
|
||||
return {"text": "", "is_finished": False, "prompt_tokens": 0, "completion_tokens": 0}
|
||||
else:
|
||||
print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
|
||||
raise ValueError(
|
||||
|
@ -10445,8 +10445,8 @@ class CustomStreamWrapper:
|
|||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
"prompt_tokens": results[0].get("input_token_count", None),
|
||||
"completion_tokens": results[0].get("generated_token_count", None),
|
||||
"prompt_tokens": results[0].get("input_token_count", 0),
|
||||
"completion_tokens": results[0].get("generated_token_count", 0),
|
||||
}
|
||||
return {"text": "", "is_finished": False}
|
||||
except Exception as e:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue