Merge pull request #3546 from BerriAI/revert-3479-feature/watsonx-integration

Revert "Add support for async streaming to watsonx provider "
Krish Dholakia 2024-05-09 07:44:32 -07:00 committed by GitHub
commit 4c8787f896
4 changed files with 112 additions and 239 deletions

View file

@@ -487,7 +487,7 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
 def ibm_granite_pt(messages: list):
     """
-    IBM's Granite chat models uses the template:
+    IBM's Granite models uses the template:
     <|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message}
 
     See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models
@@ -503,13 +503,12 @@ def ibm_granite_pt(messages: list):
                 "pre_message": "<|user|>\n",
                 "post_message": "\n",
             },
-            'assistant': {
-                'pre_message': '<|assistant|>\n',
-                'post_message': '\n',
+            "assistant": {
+                "pre_message": "<|assistant|>\n",
+                "post_message": "\n",
             },
         },
-        final_prompt_value='<|assistant|>\n',
-    )
+    ).strip()
 
 
 ### ANTHROPIC ###
@@ -1525,9 +1524,24 @@
         return mistral_instruct_pt(messages=messages)
     elif "meta-llama/llama-3" in model and "instruct" in model:
         # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
-        return hf_chat_template(
-            model="meta-llama/Meta-Llama-3-8B-Instruct",
+        return custom_prompt(
+            role_dict={
+                "system": {
+                    "pre_message": "<|start_header_id|>system<|end_header_id|>\n",
+                    "post_message": "<|eot_id|>",
+                },
+                "user": {
+                    "pre_message": "<|start_header_id|>user<|end_header_id|>\n",
+                    "post_message": "<|eot_id|>",
+                },
+                "assistant": {
+                    "pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
+                    "post_message": "<|eot_id|>",
+                },
+            },
             messages=messages,
+            initial_prompt_value="<|begin_of_text|>",
+            final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n",
         )
     try:
         if "meta-llama/llama-2" in model and "chat" in model:

View file

@@ -1,13 +1,12 @@
 from enum import Enum
 import json, types, time  # noqa: E401
-from contextlib import asynccontextmanager, contextmanager
-from typing import AsyncGenerator, Callable, Dict, Generator, Optional, Any, Union, List
+from contextlib import contextmanager
+from typing import Callable, Dict, Optional, Any, Union, List
 import httpx  # type: ignore
 import requests  # type: ignore
 import litellm
-from litellm.utils import Logging, ModelResponse, Usage, get_secret
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.utils import ModelResponse, get_secret, Usage
 
 from .base import BaseLLM
 from .prompt_templates import factory as ptf
@@ -193,7 +192,7 @@ class WatsonXAIEndpoint(str, Enum):
 class IBMWatsonXAI(BaseLLM):
     """
-    Class to interface with IBM watsonx.ai API for text generation and embeddings.
+    Class to interface with IBM Watsonx.ai API for text generation and embeddings.
 
     Reference: https://cloud.ibm.com/apidocs/watsonx-ai
     """
@@ -344,7 +343,7 @@
             )
         if token is None and api_key is not None:
             # generate the auth token
-            if print_verbose is not None:
+            if print_verbose:
                 print_verbose("Generating IAM token for Watsonx.ai")
             token = self.generate_iam_token(api_key)
         elif token is None and api_key is None:
@@ -378,9 +377,8 @@
         model_response: ModelResponse,
         print_verbose: Callable,
         encoding,
-        logging_obj: Logging,
-        optional_params: Optional[dict] = None,
-        acompletion: bool = None,
+        logging_obj,
+        optional_params: dict,
         litellm_params: Optional[dict] = None,
         logger_fn=None,
         timeout: Optional[float] = None,
@@ -404,14 +402,12 @@
             model, messages, provider, custom_prompt_dict
         )
 
-        manage_response = self._make_response_manager(async_=(acompletion is True), logging_obj=logging_obj)
-
-        def process_text_gen_response(json_resp: dict) -> ModelResponse:
-            if "results" not in json_resp:
-                raise WatsonXAIError(
-                    status_code=500,
-                    message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
-                )
+        def process_text_request(request_params: dict) -> ModelResponse:
+            with self._manage_response(
+                request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
+            ) as resp:
+                json_resp = resp.json()
+
             generated_text = json_resp["results"][0]["generated_text"]
             prompt_tokens = json_resp["results"][0]["input_token_count"]
             completion_tokens = json_resp["results"][0]["generated_token_count"]
@@ -430,52 +426,25 @@
             )
             return model_response
 
-        def handle_text_request(request_params: dict) -> ModelResponse:
-            with manage_response(
-                request_params, input=prompt, timeout=timeout,
-            ) as resp:
-                json_resp = resp.json()
-            return process_text_gen_response(json_resp)
-
-        async def handle_text_request_async(request_params: dict) -> ModelResponse:
-            async with manage_response(
-                request_params, input=prompt, timeout=timeout,
-            ) as resp:
-                json_resp = resp.json()
-            return process_text_gen_response(json_resp)
-
-        def handle_stream_request(
+        def process_stream_request(
             request_params: dict,
         ) -> litellm.CustomStreamWrapper:
             # stream the response - generated chunks will be handled
             # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            with manage_response(
-                request_params, stream=True, input=prompt, timeout=timeout,
+            with self._manage_response(
+                request_params,
+                logging_obj=logging_obj,
+                stream=True,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
-                streamwrapper = litellm.CustomStreamWrapper(
+                response = litellm.CustomStreamWrapper(
                     resp.iter_lines(),
                     model=model,
                     custom_llm_provider="watsonx",
                     logging_obj=logging_obj,
                 )
-                return streamwrapper
-
-        async def handle_stream_request_async(
-            request_params: dict,
-        ) -> litellm.CustomStreamWrapper:
-            # stream the response - generated chunks will be handled
-            # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            async with manage_response(
-                request_params, stream=True, input=prompt, timeout=timeout,
-            ) as resp:
-                streamwrapper = litellm.CustomStreamWrapper(
-                    resp.aiter_lines(),
-                    model=model,
-                    custom_llm_provider="watsonx",
-                    logging_obj=logging_obj,
-                )
-                return streamwrapper
+                return response
 
         try:
             ## Get the response from the model
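Aside (not from the diff): the restored streaming path hands requests' line iterator to CustomStreamWrapper. A minimal, self-contained sketch of that sync pattern follows; the URL, headers, and payload are placeholders, not the real watsonx request.

import requests

# Minimal sketch of the sync streaming pattern the revert restores (illustrative only).
def stream_lines(url: str, headers: dict, payload: dict, timeout: float = 600.0):
    with requests.post(url, headers=headers, json=payload, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():  # one generated-event line at a time
            if line:
                yield line.decode("utf-8")

# for chunk in stream_lines("https://example.com/generate", {}, {"input": "hi"}):
#     print(chunk)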
@@ -486,18 +455,10 @@
                 optional_params=optional_params,
                 print_verbose=print_verbose,
             )
-            if stream and acompletion:
-                # stream and async text generation
-                return handle_stream_request_async(req_params)
-            elif stream:
-                # streaming text generation
-                return handle_stream_request(req_params)
-            elif acompletion:
-                # async text generation
-                return handle_text_request_async(req_params)
+            if stream:
+                return process_stream_request(req_params)
             else:
-                # regular text generation
-                return handle_text_request(req_params)
+                return process_text_request(req_params)
         except WatsonXAIError as e:
             raise e
         except Exception as e:
@@ -512,7 +473,6 @@
         model_response=None,
         optional_params=None,
         encoding=None,
-        aembedding=None,
     ):
         """
         Send a text embedding request to the IBM Watsonx.ai API.
@@ -547,6 +507,9 @@
         }
         request_params = dict(version=api_params["api_version"])
         url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
+        # request = httpx.Request(
+        #     "POST", url, headers=headers, json=payload, params=request_params
+        # )
         req_params = {
             "method": "POST",
             "url": url,
@@ -554,47 +517,25 @@
             "json": payload,
             "params": request_params,
         }
-        manage_response = self._make_response_manager(async_=(aembedding is True), logging_obj=logging_obj)
-
-        def process_embedding_response(json_resp: dict) -> ModelResponse:
-            results = json_resp.get("results", [])
-            embedding_response = []
-            for idx, result in enumerate(results):
-                embedding_response.append(
-                    {"object": "embedding", "index": idx, "embedding": result["embedding"]}
-                )
-            model_response["object"] = "list"
-            model_response["data"] = embedding_response
-            model_response["model"] = model
-            input_tokens = json_resp.get("input_token_count", 0)
-            model_response.usage = Usage(
-                prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-            )
-            return model_response
-
-        def handle_embedding_request(request_params: dict) -> ModelResponse:
-            with manage_response(
-                request_params, input=input
-            ) as resp:
-                json_resp = resp.json()
-            return process_embedding_response(json_resp)
-
-        async def handle_embedding_request_async(request_params: dict) -> ModelResponse:
-            async with manage_response(
-                request_params, input=input
-            ) as resp:
-                json_resp = resp.json()
-            return process_embedding_response(json_resp)
-
-        try:
-            if aembedding:
-                return handle_embedding_request_async(req_params)
-            else:
-                return handle_embedding_request(req_params)
-        except WatsonXAIError as e:
-            raise e
-        except Exception as e:
-            raise WatsonXAIError(status_code=500, message=str(e))
+        with self._manage_response(
+            req_params, logging_obj=logging_obj, input=input
+        ) as resp:
+            json_resp = resp.json()
+
+        results = json_resp.get("results", [])
+        embedding_response = []
+        for idx, result in enumerate(results):
+            embedding_response.append(
+                {"object": "embedding", "index": idx, "embedding": result["embedding"]}
+            )
+        model_response["object"] = "list"
+        model_response["data"] = embedding_response
+        model_response["model"] = model
+        input_tokens = json_resp.get("input_token_count", 0)
+        model_response.usage = Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        )
+        return model_response
 
     def generate_iam_token(self, api_key=None, **params):
         headers = {}
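For readers unfamiliar with the watsonx embeddings payload handled above, here is a small illustrative mapping from a response-shaped dict to the OpenAI-style fields the restored code fills in. The sample values are made up; only the key names (results, embedding, input_token_count) come from the code in this hunk.

# Illustrative only: mirrors the mapping in the restored embedding() body above.
sample_json_resp = {
    "results": [
        {"embedding": [0.1, 0.2, 0.3]},
        {"embedding": [0.4, 0.5, 0.6]},
    ],
    "input_token_count": 7,
}

data = [
    {"object": "embedding", "index": idx, "embedding": result["embedding"]}
    for idx, result in enumerate(sample_json_resp.get("results", []))
]
usage = {
    "prompt_tokens": sample_json_resp.get("input_token_count", 0),
    "completion_tokens": 0,
    "total_tokens": sample_json_resp.get("input_token_count", 0),
}
print({"object": "list", "data": data, "usage": usage})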
@@ -617,115 +558,52 @@
         self.token = iam_access_token
         return iam_access_token
 
-    def _make_response_manager(
-        self,
-        async_: bool,
-        logging_obj: Logging
-    ) -> Callable[..., Generator[Union[requests.Response, httpx.Response], None, None]]:
-        """
-        Returns a context manager that manages the response from the request.
-        if async_ is True, returns an async context manager, otherwise returns a regular context manager.
-
-        Usage:
-        ```python
-        manage_response = self._make_response_manager(async_=True, logging_obj=logging_obj)
-        async with manage_response(request_params) as resp:
-            ...
-        # or
-        manage_response = self._make_response_manager(async_=False, logging_obj=logging_obj)
-        with manage_response(request_params) as resp:
-            ...
-        ```
-        """
-
-        def pre_call(
-            request_params: dict,
-            input:Optional[Any]=None,
-        ):
-            request_str = (
-                f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
-                f"\turl={request_params['url']},\n"
-                f"\tjson={request_params['json']},\n"
-                f")"
-            )
-            logging_obj.pre_call(
-                input=input,
-                api_key=request_params["headers"].get("Authorization"),
-                additional_args={
-                    "complete_input_dict": request_params["json"],
-                    "request_str": request_str,
-                },
-            )
-
-        def post_call(resp, request_params):
-            logging_obj.post_call(
-                input=input,
-                api_key=request_params["headers"].get("Authorization"),
-                original_response=json.dumps(resp.json()),
-                additional_args={
-                    "status_code": resp.status_code,
-                    "complete_input_dict": request_params.get("data", request_params.get("json")),
-                },
-            )
-
-        @contextmanager
-        def _manage_response(
-            request_params: dict,
-            stream: bool = False,
-            input: Optional[Any] = None,
-            timeout: float = None,
-        ) -> Generator[requests.Response, None, None]:
-            """
-            Returns a context manager that yields the response from the request.
-            """
-            pre_call(request_params, input)
-            if timeout:
-                request_params["timeout"] = timeout
-            if stream:
-                request_params["stream"] = stream
-            try:
-                resp = requests.request(**request_params)
-                resp.raise_for_status()
-                yield resp
-            except Exception as e:
-                raise WatsonXAIError(status_code=500, message=str(e))
-            if not stream:
-                post_call(resp, request_params)
-
-        @asynccontextmanager
-        async def _manage_response_async(
-            request_params: dict,
-            stream: bool = False,
-            input: Optional[Any] = None,
-            timeout: float = None,
-        ) -> AsyncGenerator[httpx.Response, None]:
-            pre_call(request_params, input)
-            if timeout:
-                request_params["timeout"] = timeout
-            if stream:
-                request_params["stream"] = stream
-            try:
-                # async with AsyncHTTPHandler(timeout=timeout) as client:
-                self.async_handler = AsyncHTTPHandler(
-                    timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0),
-                )
-                # async_handler.client.verify = False
-                if "json" in request_params:
-                    request_params['data'] = json.dumps(request_params.pop("json", {}))
-                method = request_params.pop("method")
-                if method.upper() == "POST":
-                    resp = await self.async_handler.post(**request_params)
-                else:
-                    resp = await self.async_handler.get(**request_params)
-                yield resp
-                # await async_handler.close()
-            except Exception as e:
-                raise WatsonXAIError(status_code=500, message=str(e))
-            if not stream:
-                post_call(resp, request_params)
-
-        if async_:
-            return _manage_response_async
-        else:
-            return _manage_response
+    @contextmanager
+    def _manage_response(
+        self,
+        request_params: dict,
+        logging_obj: Any,
+        stream: bool = False,
+        input: Optional[Any] = None,
+        timeout: Optional[float] = None,
+    ):
+        request_str = (
+            f"response = {request_params['method']}(\n"
+            f"\turl={request_params['url']},\n"
+            f"\tjson={request_params['json']},\n"
+            f")"
+        )
+        logging_obj.pre_call(
+            input=input,
+            api_key=request_params["headers"].get("Authorization"),
+            additional_args={
+                "complete_input_dict": request_params["json"],
+                "request_str": request_str,
+            },
+        )
+        if timeout:
+            request_params["timeout"] = timeout
+        try:
+            if stream:
+                resp = requests.request(
+                    **request_params,
+                    stream=True,
+                )
+                resp.raise_for_status()
+                yield resp
+            else:
+                resp = requests.request(**request_params)
+                resp.raise_for_status()
+                yield resp
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
+        if not stream:
+            logging_obj.post_call(
+                input=input,
+                api_key=request_params["headers"].get("Authorization"),
+                original_response=json.dumps(resp.json()),
+                additional_args={
+                    "status_code": resp.status_code,
+                    "complete_input_dict": request_params["json"],
+                },
+            )
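Background (not shown in this diff): generate_iam_token, whose signature appears as context above, exchanges an IBM Cloud API key for a bearer token before the Watsonx calls are made. A rough sketch of that exchange, following IBM Cloud's IAM documentation rather than the method body here:

import requests

# Rough sketch of an IAM token exchange for watsonx (illustrative; the endpoint and
# grant type come from IBM Cloud's IAM docs, not from this diff).
def get_iam_token(api_key: str) -> str:
    resp = requests.post(
        "https://iam.cloud.ibm.com/identity/token",
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data={
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": api_key,
        },
    )
    resp.raise_for_status()
    return resp.json()["access_token"]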

View file

@@ -73,7 +73,6 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
-from .llms.watsonx import IBMWatsonXAI
 from .llms.prompt_templates.factory import (
     prompt_factory,
     custom_prompt,
@@ -110,7 +109,6 @@ anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()
 azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
-watsonxai = IBMWatsonXAI()
 
 ####### COMPLETION ENDPOINTS ################
@@ -315,7 +313,6 @@
             or custom_llm_provider == "gemini"
             or custom_llm_provider == "sagemaker"
             or custom_llm_provider == "anthropic"
-            or custom_llm_provider == "watsonx"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -1911,7 +1908,7 @@ def completion(
             response = response
         elif custom_llm_provider == "watsonx":
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            response = watsonxai.completion(
+            response = watsonx.IBMWatsonXAI().completion(
                 model=model,
                 messages=messages,
                 custom_prompt_dict=custom_prompt_dict,
@@ -1922,8 +1919,7 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
-                acompletion=acompletion,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             if (
                 "stream" in optional_params
@@ -2576,7 +2572,6 @@ async def aembedding(*args, **kwargs):
             or custom_llm_provider == "fireworks_ai"
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "vertex_ai"
-            or custom_llm_provider == "watsonx"
         ):  # currently implemented aiohttp calls for just azure and openai, soon all.
             # Await normally
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -3034,14 +3029,13 @@ def embedding(
             aembedding=aembedding,
         )
     elif custom_llm_provider == "watsonx":
-        response = watsonxai.embedding(
+        response = watsonx.IBMWatsonXAI().embedding(
             model=model,
             input=input,
             encoding=encoding,
             logging_obj=logging,
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
-            aembedding=aembedding,
         )
     else:
         args = locals()
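What this routing change means for callers: watsonx requests go back through the synchronous IBMWatsonXAI client, and with watsonx removed from the providers that acompletion()/aembedding() await natively, async calls fall back to the thread-executor path shown above. A minimal usage sketch follows; the model id and the credential setup are assumptions, not taken from this diff.

import litellm

# Illustrative usage only; assumes watsonx credentials are already configured for litellm.
response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # assumed provider-prefixed model id
    messages=[{"role": "user", "content": "Say hello"}],
)
print(response.choices[0].message.content)

# Streaming still goes through the sync CustomStreamWrapper path:
for chunk in litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
):
    print(chunk)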

View file

@@ -10567,18 +10567,6 @@ class CustomStreamWrapper:
             elif self.custom_llm_provider == "watsonx":
                 response_obj = self.handle_watsonx_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
-                print_verbose(f"completion obj content: {completion_obj['content']}")
-                if getattr(model_response, "usage", None) is None:
-                    model_response.usage = Usage()
-                if response_obj.get("prompt_tokens") is not None:
-                    prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
-                    model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
-                if response_obj.get("completion_tokens") is not None:
-                    model_response.usage.completion_tokens = response_obj["completion_tokens"]
-                model_response.usage.total_tokens = (
-                    getattr(model_response.usage, "prompt_tokens", 0)
-                    + getattr(model_response.usage, "completion_tokens", 0)
-                )
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "text-completion-openai":
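To make the deleted bookkeeping easier to follow: handle_watsonx_stream returns a dict whose text, is_finished, finish_reason, and optional token-count keys are read above, and the removed lines folded those counts into model_response.usage chunk by chunk. An illustrative sketch of that accumulation with a made-up chunk sequence:

# Illustrative accumulation mirroring the deleted lines above; the chunk dicts are
# fabricated, only the key names follow what the stream handler is read for.
chunks = [
    {"text": "Hel", "is_finished": False, "finish_reason": None, "prompt_tokens": 5},
    {"text": "lo", "is_finished": True, "finish_reason": "stop", "completion_tokens": 2},
]

usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
text = ""
for chunk in chunks:
    text += chunk["text"]
    if chunk.get("prompt_tokens") is not None:
        usage["prompt_tokens"] += chunk["prompt_tokens"]      # accumulated
    if chunk.get("completion_tokens") is not None:
        usage["completion_tokens"] = chunk["completion_tokens"]  # overwritten per chunk
    usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]

print(text, usage)  # Hello {'prompt_tokens': 5, 'completion_tokens': 2, 'total_tokens': 7}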
@@ -10983,7 +10971,6 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "sagemaker"
                 or self.custom_llm_provider == "gemini"
                 or self.custom_llm_provider == "cached_response"
-                or self.custom_llm_provider == "watsonx"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream: