forked from phoenix/litellm-mirror
(feat) make manage_response work with requests.request instead of httpx.Request
This commit is contained in:
parent 9fc30e8b31
commit 777b4b2bbc
1 changed file with 52 additions and 49 deletions
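For orientation before the diff: the change replaces a pre-built httpx.Request object with a plain dict of request parameters that is later splatted straight into requests.request(). A minimal sketch of that pattern, outside the diff (the helper name and URL below are illustrative, not from this commit):

import requests

def build_request_params(url, payload):
    # Hypothetical helper: gather everything requests.request() needs
    # into a plain dict instead of constructing an httpx.Request object.
    return {
        "method": "POST",
        "url": url,
        "headers": {"Content-Type": "application/json"},
        "json": payload,
    }

params = build_request_params("https://example.com/api", {"input": "hello"})
params["timeout"] = 30             # extras such as timeout/stream can be added late
resp = requests.request(**params)  # dict keys map 1:1 onto keyword arguments

The upside of the dict is exactly that late mutability: _manage_response can inject timeout or stream just before sending, which an immutable httpx.Request made awkward.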
@@ -1,3 +1,4 @@
+from enum import Enum
 import json, types, time  # noqa: E401
 from contextlib import contextmanager
 from typing import Callable, Dict, Optional, Any, Union, List
@@ -160,6 +161,15 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     )
     return prompt

+class WatsonXAIEndpoint(str, Enum):
+    TEXT_GENERATION = "/ml/v1/text/generation"
+    TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
+    DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"
+    DEPLOYMENT_TEXT_GENERATION_STREAM = (
+        "/ml/v1/deployments/{deployment_id}/text/generation_stream"
+    )
+    EMBEDDINGS = "/ml/v1/text/embeddings"
+    PROMPTS = "/ml/v1/prompts"

 class IBMWatsonXAI(BaseLLM):
     """
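A side note on the new enum: because WatsonXAIEndpoint subclasses both str and Enum, its members behave as ordinary strings, which is what lets later hunks concatenate them onto a base URL and call .format() on them directly. A standalone sketch of that behavior (the base URL is an assumed example):

from enum import Enum

class WatsonXAIEndpoint(str, Enum):
    EMBEDDINGS = "/ml/v1/text/embeddings"
    DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"

base = "https://example.cloud.ibm.com"  # assumed region URL, for illustration only
url = base.rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS  # str + member yields a str
path = WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.format(deployment_id="my-deploy")
print(url)   # https://example.cloud.ibm.com/ml/v1/text/embeddings
print(path)  # /ml/v1/deployments/my-deploy/text/generation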
@@ -169,14 +179,7 @@ class IBMWatsonXAI(BaseLLM):
     """

     api_version = "2024-03-13"
-    _text_gen_endpoint = "/ml/v1/text/generation"
-    _text_gen_stream_endpoint = "/ml/v1/text/generation_stream"
-    _deployment_text_gen_endpoint = "/ml/v1/deployments/{deployment_id}/text/generation"
-    _deployment_text_gen_stream_endpoint = (
-        "/ml/v1/deployments/{deployment_id}/text/generation_stream"
-    )
-    _embeddings_endpoint = "/ml/v1/text/embeddings"
-    _prompts_endpoint = "/ml/v1/prompts"

     def __init__(self) -> None:
         super().__init__()
@@ -188,7 +191,7 @@ class IBMWatsonXAI(BaseLLM):
         stream: bool,
         optional_params: dict,
         print_verbose: Callable = None,
-    ) -> httpx.Request:
+    ) -> dict:
         """
         Get the request parameters for text generation.
         """
@@ -221,20 +224,23 @@ class IBMWatsonXAI(BaseLLM):
                 )
             deployment_id = "/".join(model_id.split("/")[1:])
             endpoint = (
-                self._deployment_text_gen_stream_endpoint
+                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM
                 if stream
-                else self._deployment_text_gen_endpoint
+                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION
             )
             endpoint = endpoint.format(deployment_id=deployment_id)
         else:
             payload["model_id"] = model_id
             payload["project_id"] = api_params["project_id"]
             endpoint = (
-                self._text_gen_stream_endpoint if stream else self._text_gen_endpoint
+                WatsonXAIEndpoint.TEXT_GENERATION_STREAM
+                if stream
+                else WatsonXAIEndpoint.TEXT_GENERATION
             )
         url = api_params["url"].rstrip("/") + endpoint
-        return httpx.Request(
-            "POST", url, headers=headers, json=payload, params=request_params
+        return dict(
+            method="POST", url=url, headers=headers,
+            json=payload, params=request_params
         )

     def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
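With the return type changed from httpx.Request to dict, the value handed back by _prepare_text_generation_req is now just keyword arguments for requests.request(). Roughly this shape, with every value below invented for illustration:

# Illustrative only: field values are made up, not taken from the diff.
req_params = {
    "method": "POST",
    "url": "https://example.cloud.ibm.com/ml/v1/text/generation",
    "headers": {"Authorization": "Bearer <token>"},
    "json": {"input": "<prompt>", "model_id": "<model>",
             "project_id": "<project>", "parameters": {}},
    "params": {"version": "2024-03-13"},  # API version as a query parameter
}
# Ready to send: requests.request(**req_params)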
@@ -360,9 +366,9 @@ class IBMWatsonXAI(BaseLLM):
             model, messages, provider, custom_prompt_dict
         )

-        def process_text_request(request: httpx.Request) -> ModelResponse:
+        def process_text_request(request_params: dict) -> ModelResponse:
             with self._manage_response(
-                request, logging_obj=logging_obj, input=prompt, timeout=timeout
+                request_params, logging_obj=logging_obj, input=prompt, timeout=timeout
             ) as resp:
                 json_resp = resp.json()

@@ -381,12 +387,12 @@ class IBMWatsonXAI(BaseLLM):
             return model_response

         def process_stream_request(
-            request: httpx.Request,
+            request_params: dict,
         ) -> litellm.CustomStreamWrapper:
             # stream the response - generated chunks will be handled
             # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
             with self._manage_response(
-                request,
+                request_params,
                 logging_obj=logging_obj,
                 stream=True,
                 input=prompt,
@@ -402,7 +408,7 @@ class IBMWatsonXAI(BaseLLM):

         try:
             ## Get the response from the model
-            request = self._prepare_text_generation_req(
+            req_params = self._prepare_text_generation_req(
                 model_id=model,
                 prompt=prompt,
                 stream=stream,
@@ -410,9 +416,9 @@ class IBMWatsonXAI(BaseLLM):
                 print_verbose=print_verbose,
             )
             if stream:
-                return process_stream_request(request)
+                return process_stream_request(req_params)
             else:
-                return process_text_request(request)
+                return process_text_request(req_params)
         except WatsonXAIError as e:
             raise e
         except Exception as e:
@@ -460,12 +466,19 @@ class IBMWatsonXAI(BaseLLM):
             "parameters": optional_params,
         }
         request_params = dict(version=api_params["api_version"])
-        url = api_params["url"].rstrip("/") + self._embeddings_endpoint
-        request = httpx.Request(
-            "POST", url, headers=headers, json=payload, params=request_params
-        )
+        url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
+        # request = httpx.Request(
+        #     "POST", url, headers=headers, json=payload, params=request_params
+        # )
+        req_params = {
+            "method": "POST",
+            "url": url,
+            "headers": headers,
+            "json": payload,
+            "params": request_params,
+        }
         with self._manage_response(
-            request, logging_obj=logging_obj, input=input
+            req_params, logging_obj=logging_obj, input=input
         ) as resp:
             json_resp = resp.json()

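Worth separating in this hunk: request_params here is only the query-string dict carrying the API version, while payload is the JSON body; inside req_params they land on requests' params= and json= keyword arguments respectively. A minimal sketch of the distinction (endpoint and values assumed):

import requests

resp = requests.request(
    method="POST",
    url="https://example.cloud.ibm.com/ml/v1/text/embeddings",  # assumed, for illustration
    headers={"Authorization": "Bearer <token>"},
    json={"inputs": ["hello world"], "model_id": "<model>"},  # serialized into the body
    params={"version": "2024-03-13"},  # appended as ?version=2024-03-13
)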
@@ -508,48 +521,38 @@ class IBMWatsonXAI(BaseLLM):
     @contextmanager
     def _manage_response(
         self,
-        request: httpx.Request,
+        request_params: dict,
         logging_obj: Any,
         stream: bool = False,
         input: Optional[Any] = None,
         timeout: float = None,
     ):
         request_str = (
-            f"response = {request.method}(\n"
-            f"\turl={request.url},\n"
-            f"\tjson={request.content.decode()},\n"
+            f"response = {request_params['method']}(\n"
+            f"\turl={request_params['url']},\n"
+            f"\tjson={request_params['json']},\n"
             f")"
         )
-        json_input = json.loads(request.content.decode())
-        headers = dict(request.headers)
         logging_obj.pre_call(
             input=input,
-            api_key=request.headers.get("Authorization"),
+            api_key=request_params['headers'].get("Authorization"),
             additional_args={
-                "complete_input_dict": json_input,
+                "complete_input_dict": request_params['json'],
                 "request_str": request_str,
             },
         )
+        if timeout:
+            request_params['timeout'] = timeout
         try:
             if stream:
                 resp = requests.request(
-                    method=request.method,
-                    url=str(request.url),
-                    headers=headers,
-                    json=json_input,
+                    **request_params,
                     stream=True,
-                    timeout=timeout,
                 )
-                # resp.raise_for_status()
+                resp.raise_for_status()
                 yield resp
             else:
-                resp = requests.request(
-                    method=request.method,
-                    url=str(request.url),
-                    headers=headers,
-                    json=json_input,
-                    timeout=timeout,
-                )
+                resp = requests.request(**request_params)
                 resp.raise_for_status()
                 yield resp
         except Exception as e:
@@ -557,10 +560,10 @@ class IBMWatsonXAI(BaseLLM):
         if not stream:
             logging_obj.post_call(
                 input=input,
-                api_key=request.headers.get("Authorization"),
+                api_key=request_params['headers'].get("Authorization"),
                 original_response=json.dumps(resp.json()),
                 additional_args={
                     "status_code": resp.status_code,
-                    "complete_input_dict": request,
+                    "complete_input_dict": request_params['json'],
                 },
             )
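Taken together, _manage_response is now a @contextmanager that logs the call, fires the HTTP request built from the params dict, and yields the response. A stripped-down sketch of the same flow, with the litellm logging hooks elided and names assumed:

import requests
from contextlib import contextmanager

@contextmanager
def manage_response(request_params: dict, stream: bool = False, timeout: float = None):
    # Mutate the dict just before sending: this is what the plain-dict
    # refactor buys over an immutable httpx.Request.
    if timeout:
        request_params["timeout"] = timeout
    if stream:
        request_params["stream"] = True
    resp = requests.request(**request_params)
    resp.raise_for_status()
    yield resp

# Usage: callers only ever see a ready, status-checked response.
# with manage_response({"method": "GET", "url": "https://example.com"}) as resp:
#     print(resp.status_code)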