forked from phoenix/litellm-mirror
(fix - watsonx) Fixed issues with watsonx embedding/async endpoints
This commit is contained in:
parent c7338f9798
commit 06e6f52358
2 changed files with 186 additions and 119 deletions
|
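The practical effect of this change is that watsonx completions and embeddings can be awaited through litellm's async entry points instead of only the blocking ones. A minimal usage sketch, not part of this diff (the model id is a placeholder and watsonx credentials are assumed to be configured in the environment):

```python
import asyncio

import litellm


async def main():
    # hypothetical watsonx model id; any "watsonx/..." model is routed through
    # the async path enabled by this commit
    resp = await litellm.acompletion(
        model="watsonx/ibm/granite-13b-chat-v2",
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(resp.choices[0].message.content)


asyncio.run(main())
```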
litellm/llms/watsonx.py

@@ -1,5 +1,6 @@
 from enum import Enum
 import json, types, time  # noqa: E401
+import asyncio
 from contextlib import asynccontextmanager, contextmanager
 from typing import (
     Callable,
@@ -393,6 +394,35 @@ class IBMWatsonXAI(BaseLLM):
             "api_version": api_version,
         }

+    def _process_text_gen_response(
+        self, json_resp: dict, model_response: Union[ModelResponse, None] = None
+    ) -> ModelResponse:
+        if "results" not in json_resp:
+            raise WatsonXAIError(
+                status_code=500,
+                message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
+            )
+        if model_response is None:
+            model_response = ModelResponse(model=json_resp.get("model_id", None))
+        generated_text = json_resp["results"][0]["generated_text"]
+        prompt_tokens = json_resp["results"][0]["input_token_count"]
+        completion_tokens = json_resp["results"][0]["generated_token_count"]
+        model_response["choices"][0]["message"]["content"] = generated_text
+        model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
+        if json_resp.get("created_at"):
+            model_response["created"] = datetime.fromisoformat(
+                json_resp["created_at"]
+            ).timestamp()
+        else:
+            model_response["created"] = int(time.time())
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
     def completion(
         self,
         model: str,
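For context on the new helper, here is a hedged sketch of the kind of Watsonx.ai text-generation payload `_process_text_gen_response` consumes. The values are fabricated, but the keys are exactly the ones read above:

```python
# Illustrative only: a made-up Watsonx.ai generation payload with the fields the
# helper reads (results[0].generated_text, token counts, stop_reason, created_at).
json_resp = {
    "model_id": "ibm/granite-13b-chat-v2",  # hypothetical model id
    "created_at": "2024-05-09T18:00:00.000+00:00",
    "results": [
        {
            "generated_text": "Hello! How can I help?",
            "generated_token_count": 8,
            "input_token_count": 12,
            "stop_reason": "eos_token",
        }
    ],
}
# Fed through the helper, this becomes choices[0].message.content == "Hello! How can I help?"
# and usage == Usage(prompt_tokens=12, completion_tokens=8, total_tokens=20).
```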
@@ -531,6 +561,29 @@ class IBMWatsonXAI(BaseLLM):
         except Exception as e:
             raise WatsonXAIError(status_code=500, message=str(e))

+    def _process_embedding_response(self, json_resp: dict, model_response:Union[ModelResponse,None]=None) -> ModelResponse:
+        if model_response is None:
+            model_response = ModelResponse(model=json_resp.get("model_id", None))
+        results = json_resp.get("results", [])
+        embedding_response = []
+        for idx, result in enumerate(results):
+            embedding_response.append(
+                {
+                    "object": "embedding",
+                    "index": idx,
+                    "embedding": result["embedding"],
+                }
+            )
+        model_response["object"] = "list"
+        model_response["data"] = embedding_response
+        input_tokens = json_resp.get("input_token_count", 0)
+        model_response.usage = Usage(
+            prompt_tokens=input_tokens,
+            completion_tokens=0,
+            total_tokens=input_tokens,
+        )
+        return model_response
+
     def embedding(
         self,
         model: str,
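Similarly, a hedged sketch of the embeddings payload `_process_embedding_response` expects and the OpenAI-style shape it produces (values fabricated):

```python
# Illustrative only: a made-up Watsonx.ai embeddings payload.
json_resp = {
    "model_id": "ibm/slate-30m-english-rtrvr",  # hypothetical embedding model id
    "input_token_count": 9,
    "results": [
        {"embedding": [0.012, -0.034, 0.056]},
        {"embedding": [0.021, 0.008, -0.077]},
    ],
}
# The helper maps this to an OpenAI-style response:
#   object == "list"
#   data == [{"object": "embedding", "index": 0, "embedding": [...]}, ...]
#   usage == Usage(prompt_tokens=9, completion_tokens=0, total_tokens=9)
```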
@@ -664,127 +717,135 @@ class IBMWatsonXAI(BaseLLM):
         return [res["model_id"] for res in json_resp["resources"]]


 class RequestManager:
     """
     Returns a context manager that manages the response from the request.
     if async_ is True, returns an async context manager, otherwise returns a regular context manager.

     Usage:
     ```python
     request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
     request_manager = RequestManager(logging_obj=logging_obj)
-    async with request_manager.request(request_params) as resp:
+    with request_manager.request(request_params) as resp:
         ...
     # or
-    with request_manager.async_request(request_params) as resp:
+    async with request_manager.async_request(request_params) as resp:
         ...
     ```
     """

     def __init__(self, logging_obj=None):
         self.logging_obj = logging_obj

     def pre_call(
         self,
         request_params: dict,
         input: Optional[Any] = None,
+        is_async: Optional[bool] = False,
     ):
         if self.logging_obj is None:
             return
         request_str = (
-            f"response = {request_params['method']}(\n"
+            f"response = {'await ' if is_async else ''}{request_params['method']}(\n"
             f"\turl={request_params['url']},\n"
             f"\tjson={request_params.get('json')},\n"
             f")"
         )
         self.logging_obj.pre_call(
             input=input,
             api_key=request_params["headers"].get("Authorization"),
             additional_args={
                 "complete_input_dict": request_params.get("json"),
                 "request_str": request_str,
             },
         )

     def post_call(self, resp, request_params):
         if self.logging_obj is None:
             return
         self.logging_obj.post_call(
             input=input,
             api_key=request_params["headers"].get("Authorization"),
             original_response=json.dumps(resp.json()),
             additional_args={
                 "status_code": resp.status_code,
                 "complete_input_dict": request_params.get(
                     "data", request_params.get("json")
                 ),
             },
         )

     @contextmanager
     def request(
         self,
         request_params: dict,
         stream: bool = False,
         input: Optional[Any] = None,
         timeout=None,
     ) -> Generator[requests.Response, None, None]:
         """
         Returns a context manager that yields the response from the request.
         """
         self.pre_call(request_params, input)
         if timeout:
             request_params["timeout"] = timeout
         if stream:
             request_params["stream"] = stream
         try:
             resp = requests.request(**request_params)
             if not resp.ok:
                 raise WatsonXAIError(
                     status_code=resp.status_code,
                     message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
                 )
             yield resp
         except Exception as e:
             raise WatsonXAIError(status_code=500, message=str(e))
         if not stream:
             self.post_call(resp, request_params)

-    @asynccontextmanager
     async def async_request(
         self,
         request_params: dict,
         stream: bool = False,
         input: Optional[Any] = None,
         timeout=None,
     ) -> AsyncGenerator[httpx.Response, None]:
-        self.pre_call(request_params, input)
+        self.pre_call(request_params, input, is_async=True)
         if timeout:
             request_params["timeout"] = timeout
         if stream:
             request_params["stream"] = stream
         try:
-            # async with AsyncHTTPHandler(timeout=timeout) as client:
             self.async_handler = AsyncHTTPHandler(
                 timeout=httpx.Timeout(
                     timeout=request_params.pop("timeout", 600.0), connect=5.0
                 ),
             )
-            # async_handler.client.verify = False
             if "json" in request_params:
                 request_params["data"] = json.dumps(request_params.pop("json", {}))
             method = request_params.pop("method")
-            if method.upper() == "POST":
-                resp = await self.async_handler.post(**request_params)
-            else:
-                resp = await self.async_handler.get(**request_params)
-            if resp.status_code not in [200, 201]:
+            retries = 0
+            while retries < 3:
+                if method.upper() == "POST":
+                    resp = await self.async_handler.post(**request_params)
+                else:
+                    resp = await self.async_handler.get(**request_params)
+                if resp.status_code in [429, 503, 504, 520]:
+                    # to handle rate limiting and service unavailable errors
+                    # see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload
+                    await asyncio.sleep(2**retries)
+                    retries += 1
+                else:
+                    break
+            if resp.is_error:
                 raise WatsonXAIError(
                     status_code=resp.status_code,
                     message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
                 )
             yield resp
             # await async_handler.close()
         except Exception as e:
+            raise e
             raise WatsonXAIError(status_code=500, message=str(e))
         if not stream:
             self.post_call(resp, request_params)
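The retry loop added to `async_request` is the core of the async fix: responses with status 429/503/504/520 are retried up to three times with exponential backoff before the result is checked for errors. A standalone sketch of the same pattern, using a plain `httpx` client rather than litellm's `AsyncHTTPHandler` (illustrative only, not part of the patch):

```python
import asyncio

import httpx

RETRYABLE = {429, 503, 504, 520}  # the status codes the diff treats as transient


async def post_with_backoff(url: str, payload: dict, max_retries: int = 3) -> httpx.Response:
    """Retry transient failures with 1s, 2s, 4s backoff, mirroring the loop above."""
    async with httpx.AsyncClient(timeout=600.0) as client:
        retries = 0
        while True:
            resp = await client.post(url, json=payload)
            if resp.status_code in RETRYABLE and retries < max_retries:
                await asyncio.sleep(2**retries)
                retries += 1
                continue
            return resp
```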
litellm/main.py

@@ -108,6 +108,7 @@ from .llms.databricks import DatabricksChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.predibase import PredibaseChatCompletion
+from .llms.watsonx import IBMWatsonXAI
 from .llms.prompt_templates.factory import (
     custom_prompt,
     function_call_prompt,
@@ -152,6 +153,7 @@ triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
+watsonxai = IBMWatsonXAI()
 ####### COMPLETION ENDPOINTS ################

@@ -369,6 +371,7 @@ async def acompletion(
             or custom_llm_provider == "bedrock"
             or custom_llm_provider == "databricks"
             or custom_llm_provider == "clarifai"
+            or custom_llm_provider == "watsonx"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -2352,7 +2355,7 @@ def completion(
             response = response
         elif custom_llm_provider == "watsonx":
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            response = watsonx.IBMWatsonXAI().completion(
+            response = watsonxai.completion(
                 model=model,
                 messages=messages,
                 custom_prompt_dict=custom_prompt_dict,
@@ -2364,6 +2367,7 @@ def completion(
                 encoding=encoding,
                 logging_obj=logging,
                 timeout=timeout,  # type: ignore
+                acompletion=acompletion,
             )
             if (
                 "stream" in optional_params
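Forwarding `acompletion` into the handler is what lets one `completion()` entry point serve both sync and async callers: when the caller intends to await, the handler can hand back a coroutine instead of a finished response. A hedged sketch of that dispatch pattern (names are illustrative, not litellm's exact internals):

```python
from typing import Any, Coroutine, Union


def completion(acompletion: bool = False, **kwargs) -> Union[dict, Coroutine[Any, Any, dict]]:
    """Return a result directly, or an awaitable when the caller asked for async."""
    if acompletion:
        return _acompletion(**kwargs)  # a coroutine; the async caller awaits it
    return _sync_completion(**kwargs)


async def _acompletion(**kwargs) -> dict:
    ...  # would issue the request through an async HTTP client


def _sync_completion(**kwargs) -> dict:
    ...  # would issue the request through a blocking HTTP client
```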
@@ -3030,6 +3034,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
         or custom_llm_provider == "ollama"
         or custom_llm_provider == "vertex_ai"
         or custom_llm_provider == "databricks"
+        or custom_llm_provider == "watsonx"
     ):  # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
@@ -3537,13 +3542,14 @@ def embedding(
             aembedding=aembedding,
         )
     elif custom_llm_provider == "watsonx":
-        response = watsonx.IBMWatsonXAI().embedding(
+        response = watsonxai.embedding(
             model=model,
             input=input,
             encoding=encoding,
             logging_obj=logging,
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
+            aembedding=aembedding,
         )
     else:
         args = locals()
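With `aembedding` now forwarded as well, the embedding path mirrors the completion path. A minimal usage sketch, not part of this diff (placeholder model id; watsonx credentials assumed to be configured in the environment):

```python
import asyncio

import litellm


async def main():
    resp = await litellm.aembedding(
        model="watsonx/ibm/slate-30m-english-rtrvr",  # hypothetical watsonx embedding model
        input=["first document", "second document"],
    )
    print(len(resp.data), resp.usage.prompt_tokens)


asyncio.run(main())
```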