(feat) make manage_response work with a request params dict instead of httpx.Request

This commit is contained in:
Simon Sanchez Viloria 2024-04-24 12:55:25 +02:00
parent 9fc30e8b31
commit 777b4b2bbc

View file

@ -1,3 +1,4 @@
from enum import Enum
import json, types, time # noqa: E401 import json, types, time # noqa: E401
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable, Dict, Optional, Any, Union, List from typing import Callable, Dict, Optional, Any, Union, List
@ -160,6 +161,15 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
) )
return prompt return prompt
class WatsonXAIEndpoint(str, Enum):
    """Relative REST paths for the IBM watsonx.ai API.

    Members subclass ``str``, so they concatenate directly onto a base
    URL (e.g. ``api_url.rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS``).
    """

    TEXT_GENERATION = "/ml/v1/text/generation"
    TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
    # Deployment paths carry a {deployment_id} placeholder filled via str.format().
    DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"
    DEPLOYMENT_TEXT_GENERATION_STREAM = "/ml/v1/deployments/{deployment_id}/text/generation_stream"
    EMBEDDINGS = "/ml/v1/text/embeddings"
    PROMPTS = "/ml/v1/prompts"
class IBMWatsonXAI(BaseLLM): class IBMWatsonXAI(BaseLLM):
""" """
@ -169,14 +179,7 @@ class IBMWatsonXAI(BaseLLM):
""" """
api_version = "2024-03-13" api_version = "2024-03-13"
_text_gen_endpoint = "/ml/v1/text/generation"
_text_gen_stream_endpoint = "/ml/v1/text/generation_stream"
_deployment_text_gen_endpoint = "/ml/v1/deployments/{deployment_id}/text/generation"
_deployment_text_gen_stream_endpoint = (
"/ml/v1/deployments/{deployment_id}/text/generation_stream"
)
_embeddings_endpoint = "/ml/v1/text/embeddings"
_prompts_endpoint = "/ml/v1/prompts"
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -188,7 +191,7 @@ class IBMWatsonXAI(BaseLLM):
stream: bool, stream: bool,
optional_params: dict, optional_params: dict,
print_verbose: Callable = None, print_verbose: Callable = None,
) -> httpx.Request: ) -> dict:
""" """
Get the request parameters for text generation. Get the request parameters for text generation.
""" """
@ -221,20 +224,23 @@ class IBMWatsonXAI(BaseLLM):
) )
deployment_id = "/".join(model_id.split("/")[1:]) deployment_id = "/".join(model_id.split("/")[1:])
endpoint = ( endpoint = (
self._deployment_text_gen_stream_endpoint WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM
if stream if stream
else self._deployment_text_gen_endpoint else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION
) )
endpoint = endpoint.format(deployment_id=deployment_id) endpoint = endpoint.format(deployment_id=deployment_id)
else: else:
payload["model_id"] = model_id payload["model_id"] = model_id
payload["project_id"] = api_params["project_id"] payload["project_id"] = api_params["project_id"]
endpoint = ( endpoint = (
self._text_gen_stream_endpoint if stream else self._text_gen_endpoint WatsonXAIEndpoint.TEXT_GENERATION_STREAM
if stream
else WatsonXAIEndpoint.TEXT_GENERATION
) )
url = api_params["url"].rstrip("/") + endpoint url = api_params["url"].rstrip("/") + endpoint
return httpx.Request( return dict(
"POST", url, headers=headers, json=payload, params=request_params method="POST", url=url, headers=headers,
json=payload, params=request_params
) )
def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict: def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
@ -360,9 +366,9 @@ class IBMWatsonXAI(BaseLLM):
model, messages, provider, custom_prompt_dict model, messages, provider, custom_prompt_dict
) )
def process_text_request(request: httpx.Request) -> ModelResponse: def process_text_request(request_params: dict) -> ModelResponse:
with self._manage_response( with self._manage_response(
request, logging_obj=logging_obj, input=prompt, timeout=timeout request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
) as resp: ) as resp:
json_resp = resp.json() json_resp = resp.json()
@ -381,12 +387,12 @@ class IBMWatsonXAI(BaseLLM):
return model_response return model_response
def process_stream_request( def process_stream_request(
request: httpx.Request, request_params: dict,
) -> litellm.CustomStreamWrapper: ) -> litellm.CustomStreamWrapper:
# stream the response - generated chunks will be handled # stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
with self._manage_response( with self._manage_response(
request, request_params,
logging_obj=logging_obj, logging_obj=logging_obj,
stream=True, stream=True,
input=prompt, input=prompt,
@ -402,7 +408,7 @@ class IBMWatsonXAI(BaseLLM):
try: try:
## Get the response from the model ## Get the response from the model
request = self._prepare_text_generation_req( req_params = self._prepare_text_generation_req(
model_id=model, model_id=model,
prompt=prompt, prompt=prompt,
stream=stream, stream=stream,
@ -410,9 +416,9 @@ class IBMWatsonXAI(BaseLLM):
print_verbose=print_verbose, print_verbose=print_verbose,
) )
if stream: if stream:
return process_stream_request(request) return process_stream_request(req_params)
else: else:
return process_text_request(request) return process_text_request(req_params)
except WatsonXAIError as e: except WatsonXAIError as e:
raise e raise e
except Exception as e: except Exception as e:
@ -460,12 +466,19 @@ class IBMWatsonXAI(BaseLLM):
"parameters": optional_params, "parameters": optional_params,
} }
request_params = dict(version=api_params["api_version"]) request_params = dict(version=api_params["api_version"])
url = api_params["url"].rstrip("/") + self._embeddings_endpoint url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
request = httpx.Request( # request = httpx.Request(
"POST", url, headers=headers, json=payload, params=request_params # "POST", url, headers=headers, json=payload, params=request_params
) # )
req_params = {
"method": "POST",
"url": url,
"headers": headers,
"json": payload,
"params": request_params,
}
with self._manage_response( with self._manage_response(
request, logging_obj=logging_obj, input=input req_params, logging_obj=logging_obj, input=input
) as resp: ) as resp:
json_resp = resp.json() json_resp = resp.json()
@ -508,48 +521,38 @@ class IBMWatsonXAI(BaseLLM):
@contextmanager @contextmanager
def _manage_response( def _manage_response(
self, self,
request: httpx.Request, request_params: dict,
logging_obj: Any, logging_obj: Any,
stream: bool = False, stream: bool = False,
input: Optional[Any] = None, input: Optional[Any] = None,
timeout: float = None, timeout: float = None,
): ):
request_str = ( request_str = (
f"response = {request.method}(\n" f"response = {request_params['method']}(\n"
f"\turl={request.url},\n" f"\turl={request_params['url']},\n"
f"\tjson={request.content.decode()},\n" f"\tjson={request_params['json']},\n"
f")" f")"
) )
json_input = json.loads(request.content.decode())
headers = dict(request.headers)
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
api_key=request.headers.get("Authorization"), api_key=request_params['headers'].get("Authorization"),
additional_args={ additional_args={
"complete_input_dict": json_input, "complete_input_dict": request_params['json'],
"request_str": request_str, "request_str": request_str,
}, },
) )
if timeout:
request_params['timeout'] = timeout
try: try:
if stream: if stream:
resp = requests.request( resp = requests.request(
method=request.method, **request_params,
url=str(request.url),
headers=headers,
json=json_input,
stream=True, stream=True,
timeout=timeout,
) )
# resp.raise_for_status() resp.raise_for_status()
yield resp yield resp
else: else:
resp = requests.request( resp = requests.request(**request_params)
method=request.method,
url=str(request.url),
headers=headers,
json=json_input,
timeout=timeout,
)
resp.raise_for_status() resp.raise_for_status()
yield resp yield resp
except Exception as e: except Exception as e:
@ -557,10 +560,10 @@ class IBMWatsonXAI(BaseLLM):
if not stream: if not stream:
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key=request.headers.get("Authorization"), api_key=request_params['headers'].get("Authorization"),
original_response=json.dumps(resp.json()), original_response=json.dumps(resp.json()),
additional_args={ additional_args={
"status_code": resp.status_code, "status_code": resp.status_code,
"complete_input_dict": request, "complete_input_dict": request_params['json'],
}, },
) )