forked from phoenix/litellm-mirror
Merge pull request #3546 from BerriAI/revert-3479-feature/watsonx-integration
Revert "Add support for async streaming to watsonx provider "
This commit is contained in:
commit
4c8787f896
4 changed files with 112 additions and 239 deletions
|
@ -487,7 +487,7 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
|
|||
|
||||
def ibm_granite_pt(messages: list):
|
||||
"""
|
||||
IBM's Granite chat models uses the template:
|
||||
IBM's Granite models uses the template:
|
||||
<|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message}
|
||||
|
||||
See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models
|
||||
|
@ -503,13 +503,12 @@ def ibm_granite_pt(messages: list):
|
|||
"pre_message": "<|user|>\n",
|
||||
"post_message": "\n",
|
||||
},
|
||||
'assistant': {
|
||||
'pre_message': '<|assistant|>\n',
|
||||
'post_message': '\n',
|
||||
"assistant": {
|
||||
"pre_message": "<|assistant|>\n",
|
||||
"post_message": "\n",
|
||||
},
|
||||
},
|
||||
final_prompt_value='<|assistant|>\n',
|
||||
)
|
||||
).strip()
|
||||
|
||||
|
||||
### ANTHROPIC ###
|
||||
|
@ -1525,9 +1524,24 @@ def prompt_factory(
|
|||
return mistral_instruct_pt(messages=messages)
|
||||
elif "meta-llama/llama-3" in model and "instruct" in model:
|
||||
# https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
|
||||
return hf_chat_template(
|
||||
model="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
return custom_prompt(
|
||||
role_dict={
|
||||
"system": {
|
||||
"pre_message": "<|start_header_id|>system<|end_header_id|>\n",
|
||||
"post_message": "<|eot_id|>",
|
||||
},
|
||||
"user": {
|
||||
"pre_message": "<|start_header_id|>user<|end_header_id|>\n",
|
||||
"post_message": "<|eot_id|>",
|
||||
},
|
||||
"assistant": {
|
||||
"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
|
||||
"post_message": "<|eot_id|>",
|
||||
},
|
||||
},
|
||||
messages=messages,
|
||||
initial_prompt_value="<|begin_of_text|>",
|
||||
final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n",
|
||||
)
|
||||
try:
|
||||
if "meta-llama/llama-2" in model and "chat" in model:
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
from enum import Enum
|
||||
import json, types, time # noqa: E401
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import AsyncGenerator, Callable, Dict, Generator, Optional, Any, Union, List
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Dict, Optional, Any, Union, List
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
import litellm
|
||||
from litellm.utils import Logging, ModelResponse, Usage, get_secret
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.utils import ModelResponse, get_secret, Usage
|
||||
|
||||
from .base import BaseLLM
|
||||
from .prompt_templates import factory as ptf
|
||||
|
@ -193,7 +192,7 @@ class WatsonXAIEndpoint(str, Enum):
|
|||
|
||||
class IBMWatsonXAI(BaseLLM):
|
||||
"""
|
||||
Class to interface with IBM watsonx.ai API for text generation and embeddings.
|
||||
Class to interface with IBM Watsonx.ai API for text generation and embeddings.
|
||||
|
||||
Reference: https://cloud.ibm.com/apidocs/watsonx-ai
|
||||
"""
|
||||
|
@ -344,7 +343,7 @@ class IBMWatsonXAI(BaseLLM):
|
|||
)
|
||||
if token is None and api_key is not None:
|
||||
# generate the auth token
|
||||
if print_verbose is not None:
|
||||
if print_verbose:
|
||||
print_verbose("Generating IAM token for Watsonx.ai")
|
||||
token = self.generate_iam_token(api_key)
|
||||
elif token is None and api_key is None:
|
||||
|
@ -378,9 +377,8 @@ class IBMWatsonXAI(BaseLLM):
|
|||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj: Logging,
|
||||
optional_params: Optional[dict] = None,
|
||||
acompletion: bool = None,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
litellm_params: Optional[dict] = None,
|
||||
logger_fn=None,
|
||||
timeout: Optional[float] = None,
|
||||
|
@ -404,14 +402,12 @@ class IBMWatsonXAI(BaseLLM):
|
|||
model, messages, provider, custom_prompt_dict
|
||||
)
|
||||
|
||||
manage_response = self._make_response_manager(async_=(acompletion is True), logging_obj=logging_obj)
|
||||
def process_text_request(request_params: dict) -> ModelResponse:
|
||||
with self._manage_response(
|
||||
request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
|
||||
def process_text_gen_response(json_resp: dict) -> ModelResponse:
|
||||
if "results" not in json_resp:
|
||||
raise WatsonXAIError(
|
||||
status_code=500,
|
||||
message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
|
||||
)
|
||||
generated_text = json_resp["results"][0]["generated_text"]
|
||||
prompt_tokens = json_resp["results"][0]["input_token_count"]
|
||||
completion_tokens = json_resp["results"][0]["generated_token_count"]
|
||||
|
@ -430,52 +426,25 @@ class IBMWatsonXAI(BaseLLM):
|
|||
)
|
||||
return model_response
|
||||
|
||||
def handle_text_request(request_params: dict) -> ModelResponse:
|
||||
with manage_response(
|
||||
request_params, input=prompt, timeout=timeout,
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
|
||||
return process_text_gen_response(json_resp)
|
||||
|
||||
async def handle_text_request_async(request_params: dict) -> ModelResponse:
|
||||
async with manage_response(
|
||||
request_params, input=prompt, timeout=timeout,
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
return process_text_gen_response(json_resp)
|
||||
|
||||
def handle_stream_request(
|
||||
def process_stream_request(
|
||||
request_params: dict,
|
||||
) -> litellm.CustomStreamWrapper:
|
||||
# stream the response - generated chunks will be handled
|
||||
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
|
||||
with manage_response(
|
||||
request_params, stream=True, input=prompt, timeout=timeout,
|
||||
with self._manage_response(
|
||||
request_params,
|
||||
logging_obj=logging_obj,
|
||||
stream=True,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
streamwrapper = litellm.CustomStreamWrapper(
|
||||
response = litellm.CustomStreamWrapper(
|
||||
resp.iter_lines(),
|
||||
model=model,
|
||||
custom_llm_provider="watsonx",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
async def handle_stream_request_async(
|
||||
request_params: dict,
|
||||
) -> litellm.CustomStreamWrapper:
|
||||
# stream the response - generated chunks will be handled
|
||||
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
|
||||
async with manage_response(
|
||||
request_params, stream=True, input=prompt, timeout=timeout,
|
||||
) as resp:
|
||||
streamwrapper = litellm.CustomStreamWrapper(
|
||||
resp.aiter_lines(),
|
||||
model=model,
|
||||
custom_llm_provider="watsonx",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
return response
|
||||
|
||||
try:
|
||||
## Get the response from the model
|
||||
|
@ -486,18 +455,10 @@ class IBMWatsonXAI(BaseLLM):
|
|||
optional_params=optional_params,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if stream and acompletion:
|
||||
# stream and async text generation
|
||||
return handle_stream_request_async(req_params)
|
||||
elif stream:
|
||||
# streaming text generation
|
||||
return handle_stream_request(req_params)
|
||||
elif acompletion:
|
||||
# async text generation
|
||||
return handle_text_request_async(req_params)
|
||||
if stream:
|
||||
return process_stream_request(req_params)
|
||||
else:
|
||||
# regular text generation
|
||||
return handle_text_request(req_params)
|
||||
return process_text_request(req_params)
|
||||
except WatsonXAIError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
|
@ -512,7 +473,6 @@ class IBMWatsonXAI(BaseLLM):
|
|||
model_response=None,
|
||||
optional_params=None,
|
||||
encoding=None,
|
||||
aembedding=None,
|
||||
):
|
||||
"""
|
||||
Send a text embedding request to the IBM Watsonx.ai API.
|
||||
|
@ -547,6 +507,9 @@ class IBMWatsonXAI(BaseLLM):
|
|||
}
|
||||
request_params = dict(version=api_params["api_version"])
|
||||
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
|
||||
# request = httpx.Request(
|
||||
# "POST", url, headers=headers, json=payload, params=request_params
|
||||
# )
|
||||
req_params = {
|
||||
"method": "POST",
|
||||
"url": url,
|
||||
|
@ -554,9 +517,11 @@ class IBMWatsonXAI(BaseLLM):
|
|||
"json": payload,
|
||||
"params": request_params,
|
||||
}
|
||||
manage_response = self._make_response_manager(async_=(aembedding is True), logging_obj=logging_obj)
|
||||
with self._manage_response(
|
||||
req_params, logging_obj=logging_obj, input=input
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
|
||||
def process_embedding_response(json_resp: dict) -> ModelResponse:
|
||||
results = json_resp.get("results", [])
|
||||
embedding_response = []
|
||||
for idx, result in enumerate(results):
|
||||
|
@ -572,30 +537,6 @@ class IBMWatsonXAI(BaseLLM):
|
|||
)
|
||||
return model_response
|
||||
|
||||
def handle_embedding_request(request_params: dict) -> ModelResponse:
|
||||
with manage_response(
|
||||
request_params, input=input
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
return process_embedding_response(json_resp)
|
||||
|
||||
async def handle_embedding_request_async(request_params: dict) -> ModelResponse:
|
||||
async with manage_response(
|
||||
request_params, input=input
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
return process_embedding_response(json_resp)
|
||||
|
||||
try:
|
||||
if aembedding:
|
||||
return handle_embedding_request_async(req_params)
|
||||
else:
|
||||
return handle_embedding_request(req_params)
|
||||
except WatsonXAIError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
|
||||
def generate_iam_token(self, api_key=None, **params):
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
|
@ -617,33 +558,17 @@ class IBMWatsonXAI(BaseLLM):
|
|||
self.token = iam_access_token
|
||||
return iam_access_token
|
||||
|
||||
def _make_response_manager(
|
||||
@contextmanager
|
||||
def _manage_response(
|
||||
self,
|
||||
async_: bool,
|
||||
logging_obj: Logging
|
||||
) -> Callable[..., Generator[Union[requests.Response, httpx.Response], None, None]]:
|
||||
"""
|
||||
Returns a context manager that manages the response from the request.
|
||||
if async_ is True, returns an async context manager, otherwise returns a regular context manager.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
manage_response = self._make_response_manager(async_=True, logging_obj=logging_obj)
|
||||
async with manage_response(request_params) as resp:
|
||||
...
|
||||
# or
|
||||
manage_response = self._make_response_manager(async_=False, logging_obj=logging_obj)
|
||||
with manage_response(request_params) as resp:
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
def pre_call(
|
||||
request_params: dict,
|
||||
logging_obj: Any,
|
||||
stream: bool = False,
|
||||
input: Optional[Any] = None,
|
||||
timeout: Optional[float] = None,
|
||||
):
|
||||
request_str = (
|
||||
f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
|
||||
f"response = {request_params['method']}(\n"
|
||||
f"\turl={request_params['url']},\n"
|
||||
f"\tjson={request_params['json']},\n"
|
||||
f")"
|
||||
|
@ -656,76 +581,29 @@ class IBMWatsonXAI(BaseLLM):
|
|||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
def post_call(resp, request_params):
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=request_params["headers"].get("Authorization"),
|
||||
original_response=json.dumps(resp.json()),
|
||||
additional_args={
|
||||
"status_code": resp.status_code,
|
||||
"complete_input_dict": request_params.get("data", request_params.get("json")),
|
||||
},
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _manage_response(
|
||||
request_params: dict,
|
||||
stream: bool = False,
|
||||
input: Optional[Any] = None,
|
||||
timeout: float = None,
|
||||
) -> Generator[requests.Response, None, None]:
|
||||
"""
|
||||
Returns a context manager that yields the response from the request.
|
||||
"""
|
||||
pre_call(request_params, input)
|
||||
if timeout:
|
||||
request_params["timeout"] = timeout
|
||||
if stream:
|
||||
request_params["stream"] = stream
|
||||
try:
|
||||
if stream:
|
||||
resp = requests.request(
|
||||
**request_params,
|
||||
stream=True,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
yield resp
|
||||
else:
|
||||
resp = requests.request(**request_params)
|
||||
resp.raise_for_status()
|
||||
yield resp
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
if not stream:
|
||||
post_call(resp, request_params)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _manage_response_async(
|
||||
request_params: dict,
|
||||
stream: bool = False,
|
||||
input: Optional[Any] = None,
|
||||
timeout: float = None,
|
||||
) -> AsyncGenerator[httpx.Response, None]:
|
||||
pre_call(request_params, input)
|
||||
if timeout:
|
||||
request_params["timeout"] = timeout
|
||||
if stream:
|
||||
request_params["stream"] = stream
|
||||
try:
|
||||
# async with AsyncHTTPHandler(timeout=timeout) as client:
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0),
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=request_params["headers"].get("Authorization"),
|
||||
original_response=json.dumps(resp.json()),
|
||||
additional_args={
|
||||
"status_code": resp.status_code,
|
||||
"complete_input_dict": request_params["json"],
|
||||
},
|
||||
)
|
||||
# async_handler.client.verify = False
|
||||
if "json" in request_params:
|
||||
request_params['data'] = json.dumps(request_params.pop("json", {}))
|
||||
method = request_params.pop("method")
|
||||
if method.upper() == "POST":
|
||||
resp = await self.async_handler.post(**request_params)
|
||||
else:
|
||||
resp = await self.async_handler.get(**request_params)
|
||||
yield resp
|
||||
# await async_handler.close()
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
if not stream:
|
||||
post_call(resp, request_params)
|
||||
|
||||
if async_:
|
||||
return _manage_response_async
|
||||
else:
|
||||
return _manage_response
|
||||
|
|
|
@ -73,7 +73,6 @@ from .llms.azure_text import AzureTextCompletion
|
|||
from .llms.anthropic import AnthropicChatCompletion
|
||||
from .llms.anthropic_text import AnthropicTextCompletion
|
||||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.watsonx import IBMWatsonXAI
|
||||
from .llms.prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
custom_prompt,
|
||||
|
@ -110,7 +109,6 @@ anthropic_text_completions = AnthropicTextCompletion()
|
|||
azure_chat_completions = AzureChatCompletion()
|
||||
azure_text_completions = AzureTextCompletion()
|
||||
huggingface = Huggingface()
|
||||
watsonxai = IBMWatsonXAI()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@ -315,7 +313,6 @@ async def acompletion(
|
|||
or custom_llm_provider == "gemini"
|
||||
or custom_llm_provider == "sagemaker"
|
||||
or custom_llm_provider == "anthropic"
|
||||
or custom_llm_provider == "watsonx"
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -1911,7 +1908,7 @@ def completion(
|
|||
response = response
|
||||
elif custom_llm_provider == "watsonx":
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
response = watsonxai.completion(
|
||||
response = watsonx.IBMWatsonXAI().completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
|
@ -1922,8 +1919,7 @@ def completion(
|
|||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
timeout=timeout,
|
||||
timeout=timeout, # type: ignore
|
||||
)
|
||||
if (
|
||||
"stream" in optional_params
|
||||
|
@ -2576,7 +2572,6 @@ async def aembedding(*args, **kwargs):
|
|||
or custom_llm_provider == "fireworks_ai"
|
||||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "vertex_ai"
|
||||
or custom_llm_provider == "watsonx"
|
||||
): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
# Await normally
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -3034,14 +3029,13 @@ def embedding(
|
|||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "watsonx":
|
||||
response = watsonxai.embedding(
|
||||
response = watsonx.IBMWatsonXAI().embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
aembedding=aembedding,
|
||||
)
|
||||
else:
|
||||
args = locals()
|
||||
|
|
|
@ -10567,18 +10567,6 @@ class CustomStreamWrapper:
|
|||
elif self.custom_llm_provider == "watsonx":
|
||||
response_obj = self.handle_watsonx_stream(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
print_verbose(f"completion obj content: {completion_obj['content']}")
|
||||
if getattr(model_response, "usage", None) is None:
|
||||
model_response.usage = Usage()
|
||||
if response_obj.get("prompt_tokens") is not None:
|
||||
prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
|
||||
model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
|
||||
if response_obj.get("completion_tokens") is not None:
|
||||
model_response.usage.completion_tokens = response_obj["completion_tokens"]
|
||||
model_response.usage.total_tokens = (
|
||||
getattr(model_response.usage, "prompt_tokens", 0)
|
||||
+ getattr(model_response.usage, "completion_tokens", 0)
|
||||
)
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
elif self.custom_llm_provider == "text-completion-openai":
|
||||
|
@ -10983,7 +10971,6 @@ class CustomStreamWrapper:
|
|||
or self.custom_llm_provider == "sagemaker"
|
||||
or self.custom_llm_provider == "gemini"
|
||||
or self.custom_llm_provider == "cached_response"
|
||||
or self.custom_llm_provider == "watsonx"
|
||||
or self.custom_llm_provider in litellm.openai_compatible_endpoints
|
||||
):
|
||||
async for chunk in self.completion_stream:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue