Revert "Add support for async streaming to watsonx provider "

Krish Dholakia 2024-05-09 07:44:15 -07:00 committed by GitHub
parent 66a1b581e5
commit 8015bc1c47
4 changed files with 112 additions and 239 deletions

@@ -487,7 +487,7 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
def ibm_granite_pt(messages: list):
"""
IBM's Granite chat models uses the template:
IBM's Granite models uses the template:
<|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message}
See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models
@@ -503,13 +503,12 @@ def ibm_granite_pt(messages: list):
"pre_message": "<|user|>\n",
"post_message": "\n",
},
'assistant': {
'pre_message': '<|assistant|>\n',
'post_message': '\n',
"assistant": {
"pre_message": "<|assistant|>\n",
"post_message": "\n",
},
},
final_prompt_value='<|assistant|>\n',
)
).strip()
### ANTHROPIC ###
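For reference, a minimal standalone sketch of what the Granite role tags in this hunk produce for a short conversation. It is an illustration only (litellm's `custom_prompt` does the real formatting); the tags and the trailing `<|assistant|>` cue are taken from the hunk above.

```python
def granite_prompt(messages: list) -> str:
    # Role tags as shown in the role_dict / docstring above.
    tags = {
        "system": ("<|system|>\n", "\n"),
        "user": ("<|user|>\n", "\n"),
        "assistant": ("<|assistant|>\n", "\n"),
    }
    prompt = ""
    for m in messages:
        pre, post = tags[m["role"]]
        prompt += pre + m["content"] + post
    # final_prompt_value cues the model to answer as the assistant.
    return (prompt + "<|assistant|>\n").strip()

print(granite_prompt([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is watsonx.ai?"},
]))
```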
@@ -1525,9 +1524,24 @@ def prompt_factory(
return mistral_instruct_pt(messages=messages)
elif "meta-llama/llama-3" in model and "instruct" in model:
# https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
return hf_chat_template(
model="meta-llama/Meta-Llama-3-8B-Instruct",
return custom_prompt(
role_dict={
"system": {
"pre_message": "<|start_header_id|>system<|end_header_id|>\n",
"post_message": "<|eot_id|>",
},
"user": {
"pre_message": "<|start_header_id|>user<|end_header_id|>\n",
"post_message": "<|eot_id|>",
},
"assistant": {
"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
"post_message": "<|eot_id|>",
},
},
messages=messages,
initial_prompt_value="<|begin_of_text|>",
final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n",
)
try:
if "meta-llama/llama-2" in model and "chat" in model:

@@ -1,13 +1,12 @@
from enum import Enum
import json, types, time # noqa: E401
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncGenerator, Callable, Dict, Generator, Optional, Any, Union, List
from contextlib import contextmanager
from typing import Callable, Dict, Optional, Any, Union, List
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.utils import Logging, ModelResponse, Usage, get_secret
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.utils import ModelResponse, get_secret, Usage
from .base import BaseLLM
from .prompt_templates import factory as ptf
@@ -193,7 +192,7 @@ class WatsonXAIEndpoint(str, Enum):
class IBMWatsonXAI(BaseLLM):
"""
Class to interface with IBM watsonx.ai API for text generation and embeddings.
Class to interface with IBM Watsonx.ai API for text generation and embeddings.
Reference: https://cloud.ibm.com/apidocs/watsonx-ai
"""
@@ -344,7 +343,7 @@ class IBMWatsonXAI(BaseLLM):
)
if token is None and api_key is not None:
# generate the auth token
if print_verbose is not None:
if print_verbose:
print_verbose("Generating IAM token for Watsonx.ai")
token = self.generate_iam_token(api_key)
elif token is None and api_key is None:
@@ -378,9 +377,8 @@ class IBMWatsonXAI(BaseLLM):
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj: Logging,
optional_params: Optional[dict] = None,
acompletion: bool = None,
logging_obj,
optional_params: dict,
litellm_params: Optional[dict] = None,
logger_fn=None,
timeout: Optional[float] = None,
@@ -404,14 +402,12 @@ class IBMWatsonXAI(BaseLLM):
model, messages, provider, custom_prompt_dict
)
manage_response = self._make_response_manager(async_=(acompletion is True), logging_obj=logging_obj)
def process_text_request(request_params: dict) -> ModelResponse:
with self._manage_response(
request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
) as resp:
json_resp = resp.json()
def process_text_gen_response(json_resp: dict) -> ModelResponse:
if "results" not in json_resp:
raise WatsonXAIError(
status_code=500,
message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
)
generated_text = json_resp["results"][0]["generated_text"]
prompt_tokens = json_resp["results"][0]["input_token_count"]
completion_tokens = json_resp["results"][0]["generated_token_count"]
@@ -430,52 +426,25 @@ class IBMWatsonXAI(BaseLLM):
)
return model_response
def handle_text_request(request_params: dict) -> ModelResponse:
with manage_response(
request_params, input=prompt, timeout=timeout,
) as resp:
json_resp = resp.json()
return process_text_gen_response(json_resp)
async def handle_text_request_async(request_params: dict) -> ModelResponse:
async with manage_response(
request_params, input=prompt, timeout=timeout,
) as resp:
json_resp = resp.json()
return process_text_gen_response(json_resp)
def handle_stream_request(
def process_stream_request(
request_params: dict,
) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
with manage_response(
request_params, stream=True, input=prompt, timeout=timeout,
with self._manage_response(
request_params,
logging_obj=logging_obj,
stream=True,
input=prompt,
timeout=timeout,
) as resp:
streamwrapper = litellm.CustomStreamWrapper(
response = litellm.CustomStreamWrapper(
resp.iter_lines(),
model=model,
custom_llm_provider="watsonx",
logging_obj=logging_obj,
)
return streamwrapper
async def handle_stream_request_async(
request_params: dict,
) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
async with manage_response(
request_params, stream=True, input=prompt, timeout=timeout,
) as resp:
streamwrapper = litellm.CustomStreamWrapper(
resp.aiter_lines(),
model=model,
custom_llm_provider="watsonx",
logging_obj=logging_obj,
)
return streamwrapper
return response
try:
## Get the response from the model
@@ -486,18 +455,10 @@ class IBMWatsonXAI(BaseLLM):
optional_params=optional_params,
print_verbose=print_verbose,
)
if stream and acompletion:
# stream and async text generation
return handle_stream_request_async(req_params)
elif stream:
# streaming text generation
return handle_stream_request(req_params)
elif acompletion:
# async text generation
return handle_text_request_async(req_params)
if stream:
return process_stream_request(req_params)
else:
# regular text generation
return handle_text_request(req_params)
return process_text_request(req_params)
except WatsonXAIError as e:
raise e
except Exception as e:
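Caller-side view of what remains after the revert: `stream=True` for watsonx is still served by the synchronous `process_stream_request` above, while `acompletion` no longer takes a watsonx-specific async branch (see the `acompletion`/`aembedding` hunks further down). A hedged usage sketch; the model id is only an example.

```python
import litellm

response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # example model id (assumed)
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
# CustomStreamWrapper iterates the response lines and yields OpenAI-style chunks.
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```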
@@ -512,7 +473,6 @@ class IBMWatsonXAI(BaseLLM):
model_response=None,
optional_params=None,
encoding=None,
aembedding=None,
):
"""
Send a text embedding request to the IBM Watsonx.ai API.
@@ -547,6 +507,9 @@ class IBMWatsonXAI(BaseLLM):
}
request_params = dict(version=api_params["api_version"])
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
# request = httpx.Request(
# "POST", url, headers=headers, json=payload, params=request_params
# )
req_params = {
"method": "POST",
"url": url,
@@ -554,9 +517,11 @@ class IBMWatsonXAI(BaseLLM):
"json": payload,
"params": request_params,
}
manage_response = self._make_response_manager(async_=(aembedding is True), logging_obj=logging_obj)
with self._manage_response(
req_params, logging_obj=logging_obj, input=input
) as resp:
json_resp = resp.json()
def process_embedding_response(json_resp: dict) -> ModelResponse:
results = json_resp.get("results", [])
embedding_response = []
for idx, result in enumerate(results):
@@ -572,30 +537,6 @@ class IBMWatsonXAI(BaseLLM):
)
return model_response
def handle_embedding_request(request_params: dict) -> ModelResponse:
with manage_response(
request_params, input=input
) as resp:
json_resp = resp.json()
return process_embedding_response(json_resp)
async def handle_embedding_request_async(request_params: dict) -> ModelResponse:
async with manage_response(
request_params, input=input
) as resp:
json_resp = resp.json()
return process_embedding_response(json_resp)
try:
if aembedding:
return handle_embedding_request_async(req_params)
else:
return handle_embedding_request(req_params)
except WatsonXAIError as e:
raise e
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
def generate_iam_token(self, api_key=None, **params):
headers = {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
@@ -617,33 +558,17 @@ class IBMWatsonXAI(BaseLLM):
self.token = iam_access_token
return iam_access_token
def _make_response_manager(
@contextmanager
def _manage_response(
self,
async_: bool,
logging_obj: Logging
) -> Callable[..., Generator[Union[requests.Response, httpx.Response], None, None]]:
"""
Returns a context manager that manages the response from the request.
if async_ is True, returns an async context manager, otherwise returns a regular context manager.
Usage:
```python
manage_response = self._make_response_manager(async_=True, logging_obj=logging_obj)
async with manage_response(request_params) as resp:
...
# or
manage_response = self._make_response_manager(async_=False, logging_obj=logging_obj)
with manage_response(request_params) as resp:
...
```
"""
def pre_call(
request_params: dict,
logging_obj: Any,
stream: bool = False,
input: Optional[Any] = None,
timeout: Optional[float] = None,
):
request_str = (
f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
f"response = {request_params['method']}(\n"
f"\turl={request_params['url']},\n"
f"\tjson={request_params['json']},\n"
f")"
@@ -656,76 +581,29 @@ class IBMWatsonXAI(BaseLLM):
"request_str": request_str,
},
)
def post_call(resp, request_params):
logging_obj.post_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()),
additional_args={
"status_code": resp.status_code,
"complete_input_dict": request_params.get("data", request_params.get("json")),
},
)
@contextmanager
def _manage_response(
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout: float = None,
) -> Generator[requests.Response, None, None]:
"""
Returns a context manager that yields the response from the request.
"""
pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
if stream:
resp = requests.request(
**request_params,
stream=True,
)
resp.raise_for_status()
yield resp
else:
resp = requests.request(**request_params)
resp.raise_for_status()
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
post_call(resp, request_params)
@asynccontextmanager
async def _manage_response_async(
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout: float = None,
) -> AsyncGenerator[httpx.Response, None]:
pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
# async with AsyncHTTPHandler(timeout=timeout) as client:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0),
logging_obj.post_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()),
additional_args={
"status_code": resp.status_code,
"complete_input_dict": request_params["json"],
},
)
# async_handler.client.verify = False
if "json" in request_params:
request_params['data'] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
if method.upper() == "POST":
resp = await self.async_handler.post(**request_params)
else:
resp = await self.async_handler.get(**request_params)
yield resp
# await async_handler.close()
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
post_call(resp, request_params)
if async_:
return _manage_response_async
else:
return _manage_response

@@ -73,7 +73,6 @@ from .llms.azure_text import AzureTextCompletion
from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.watsonx import IBMWatsonXAI
from .llms.prompt_templates.factory import (
prompt_factory,
custom_prompt,
@@ -110,7 +109,6 @@ anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
watsonxai = IBMWatsonXAI()
####### COMPLETION ENDPOINTS ################
@@ -315,7 +313,6 @@ async def acompletion(
or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
or custom_llm_provider == "watsonx"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
@@ -1911,7 +1908,7 @@ def completion(
response = response
elif custom_llm_provider == "watsonx":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = watsonxai.completion(
response = watsonx.IBMWatsonXAI().completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
@@ -1922,8 +1919,7 @@ def completion(
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
timeout=timeout, # type: ignore
)
if (
"stream" in optional_params
@@ -2576,7 +2572,6 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "watsonx"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
@@ -3034,14 +3029,13 @@ def embedding(
aembedding=aembedding,
)
elif custom_llm_provider == "watsonx":
response = watsonxai.embedding(
response = watsonx.IBMWatsonXAI().embedding(
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
aembedding=aembedding,
)
else:
args = locals()
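Caller-side sketch of the embedding path wired up above: after the revert, `aembedding` no longer routes watsonx through a provider-level async branch, so `embedding()` is the synchronous entry point. The model id is only an example.

```python
import litellm

response = litellm.embedding(
    model="watsonx/ibm/slate-30m-english-rtrvr",  # example model id (assumed)
    input=["What is watsonx.ai?"],
)
print(response)  # EmbeddingResponse in the usual OpenAI-style shape
```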

@@ -10567,18 +10567,6 @@ class CustomStreamWrapper:
elif self.custom_llm_provider == "watsonx":
response_obj = self.handle_watsonx_stream(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if getattr(model_response, "usage", None) is None:
model_response.usage = Usage()
if response_obj.get("prompt_tokens") is not None:
prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
if response_obj.get("completion_tokens") is not None:
model_response.usage.completion_tokens = response_obj["completion_tokens"]
model_response.usage.total_tokens = (
getattr(model_response.usage, "prompt_tokens", 0)
+ getattr(model_response.usage, "completion_tokens", 0)
)
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "text-completion-openai":
@@ -10983,7 +10971,6 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "sagemaker"
or self.custom_llm_provider == "gemini"
or self.custom_llm_provider == "cached_response"
or self.custom_llm_provider == "watsonx"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
):
async for chunk in self.completion_stream: