forked from phoenix/litellm-mirror
(fix) watsonx.py: Fixed linting errors and make sure stream chunk always return usage
This commit is contained in: parent 66a1b581e5 · commit 170fd11c82
2 changed files with 279 additions and 92 deletions
First changed file (watsonx.py):

@@ -1,12 +1,25 @@
 from enum import Enum
 import json, types, time  # noqa: E401
 from contextlib import asynccontextmanager, contextmanager
-from typing import AsyncGenerator, Callable, Dict, Generator, Optional, Any, Union, List
+from typing import (
+    Callable,
+    Dict,
+    Generator,
+    AsyncGenerator,
+    Iterator,
+    AsyncIterator,
+    Optional,
+    Any,
+    Union,
+    List,
+    ContextManager,
+    AsyncContextManager,
+)
 
 import httpx  # type: ignore
 import requests  # type: ignore
 import litellm
-from litellm.utils import Logging, ModelResponse, Usage, get_secret
+from litellm.utils import ModelResponse, Usage, get_secret
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 
 from .base import BaseLLM
@@ -189,6 +202,7 @@ class WatsonXAIEndpoint(str, Enum):
     )
     EMBEDDINGS = "/ml/v1/text/embeddings"
     PROMPTS = "/ml/v1/prompts"
+    AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"
 
 
 class IBMWatsonXAI(BaseLLM):
@@ -378,12 +392,12 @@ class IBMWatsonXAI(BaseLLM):
         model_response: ModelResponse,
         print_verbose: Callable,
         encoding,
-        logging_obj: Logging,
-        optional_params: Optional[dict] = None,
-        acompletion: bool = None,
-        litellm_params: Optional[dict] = None,
+        logging_obj,
+        optional_params=None,
+        acompletion=None,
+        litellm_params=None,
         logger_fn=None,
-        timeout: Optional[float] = None,
+        timeout=None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -403,8 +417,6 @@ class IBMWatsonXAI(BaseLLM):
         prompt = convert_messages_to_prompt(
             model, messages, provider, custom_prompt_dict
         )
 
-        manage_response = self._make_response_manager(async_=(acompletion is True), logging_obj=logging_obj)
-
         def process_text_gen_response(json_resp: dict) -> ModelResponse:
             if "results" not in json_resp:
@@ -419,62 +431,72 @@ class IBMWatsonXAI(BaseLLM):
             model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
             model_response["created"] = int(time.time())
             model_response["model"] = model
-            setattr(
-                model_response,
-                "usage",
-                Usage(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=prompt_tokens + completion_tokens,
-                ),
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             )
+            setattr(model_response, "usage", usage)
             return model_response
 
+        def process_stream_response(
+            stream_resp: Union[Iterator[str], AsyncIterator],
+        ) -> litellm.CustomStreamWrapper:
+            streamwrapper = litellm.CustomStreamWrapper(
+                stream_resp,
+                model=model,
+                custom_llm_provider="watsonx",
+                logging_obj=logging_obj,
+            )
+            return streamwrapper
+
+        # create the function to manage the request to watsonx.ai
+        # manage_request = self._make_request_manager(
+        #     async_=(acompletion is True), logging_obj=logging_obj
+        # )
+        self.request_manager = RequestManager(logging_obj)
+
         def handle_text_request(request_params: dict) -> ModelResponse:
-            with manage_response(
-                request_params, input=prompt, timeout=timeout,
+            with self.request_manager.request(
+                request_params,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
                 json_resp = resp.json()
 
             return process_text_gen_response(json_resp)
 
         async def handle_text_request_async(request_params: dict) -> ModelResponse:
-            async with manage_response(
-                request_params, input=prompt, timeout=timeout,
+            async with self.request_manager.async_request(
+                request_params,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
                 json_resp = resp.json()
             return process_text_gen_response(json_resp)
 
-        def handle_stream_request(
-            request_params: dict,
-        ) -> litellm.CustomStreamWrapper:
+        def handle_stream_request(request_params: dict) -> litellm.CustomStreamWrapper:
             # stream the response - generated chunks will be handled
             # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            with manage_response(
-                request_params, stream=True, input=prompt, timeout=timeout,
+            with self.request_manager.request(
+                request_params,
+                stream=True,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
-                streamwrapper = litellm.CustomStreamWrapper(
-                    resp.iter_lines(),
-                    model=model,
-                    custom_llm_provider="watsonx",
-                    logging_obj=logging_obj,
-                )
+                streamwrapper = process_stream_response(resp.iter_lines())
             return streamwrapper
 
-        async def handle_stream_request_async(
-            request_params: dict,
-        ) -> litellm.CustomStreamWrapper:
+        async def handle_stream_request_async(request_params: dict) -> litellm.CustomStreamWrapper:
             # stream the response - generated chunks will be handled
             # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            async with manage_response(
-                request_params, stream=True, input=prompt, timeout=timeout,
+            async with self.request_manager.async_request(
+                request_params,
+                stream=True,
+                input=prompt,
+                timeout=timeout,
             ) as resp:
-                streamwrapper = litellm.CustomStreamWrapper(
-                    resp.aiter_lines(),
-                    model=model,
-                    custom_llm_provider="watsonx",
-                    logging_obj=logging_obj,
-                )
+                streamwrapper = process_stream_response(resp.aiter_lines())
             return streamwrapper
 
         try:
@@ -486,13 +508,13 @@ class IBMWatsonXAI(BaseLLM):
                 optional_params=optional_params,
                 print_verbose=print_verbose,
             )
-            if stream and acompletion:
+            if stream and (acompletion is True):
                 # stream and async text generation
                 return handle_stream_request_async(req_params)
             elif stream:
                 # streaming text generation
                 return handle_stream_request(req_params)
-            elif acompletion:
+            elif (acompletion is True):
                 # async text generation
                 return handle_text_request_async(req_params)
             else:
@@ -554,43 +576,48 @@ class IBMWatsonXAI(BaseLLM):
             "json": payload,
             "params": request_params,
         }
-        manage_response = self._make_response_manager(async_=(aembedding is True), logging_obj=logging_obj)
+        # manage_request = self._make_request_manager(
+        #     async_=(aembedding is True), logging_obj=logging_obj
+        # )
+        request_manager = RequestManager(logging_obj)
 
         def process_embedding_response(json_resp: dict) -> ModelResponse:
             results = json_resp.get("results", [])
             embedding_response = []
             for idx, result in enumerate(results):
                 embedding_response.append(
-                    {"object": "embedding", "index": idx, "embedding": result["embedding"]}
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": result["embedding"],
+                    }
                 )
             model_response["object"] = "list"
             model_response["data"] = embedding_response
             model_response["model"] = model
             input_tokens = json_resp.get("input_token_count", 0)
             model_response.usage = Usage(
-                prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
             )
             return model_response
 
-        def handle_embedding_request(request_params: dict) -> ModelResponse:
-            with manage_response(
-                request_params, input=input
-            ) as resp:
+        def handle_embedding(request_params: dict) -> ModelResponse:
+            with request_manager.request(request_params, input=input) as resp:
                 json_resp = resp.json()
             return process_embedding_response(json_resp)
 
-        async def handle_embedding_request_async(request_params: dict) -> ModelResponse:
-            async with manage_response(
-                request_params, input=input
-            ) as resp:
+        async def handle_aembedding(request_params: dict) -> ModelResponse:
+            async with request_manager.async_request(request_params, input=input) as resp:
                 json_resp = resp.json()
             return process_embedding_response(json_resp)
 
         try:
-            if aembedding:
-                return handle_embedding_request_async(req_params)
+            if aembedding is True:
+                return handle_aembedding(req_params)
             else:
-                return handle_embedding_request(req_params)
+                return handle_embedding(req_params)
         except WatsonXAIError as e:
             raise e
         except Exception as e:
@@ -616,64 +643,88 @@ class IBMWatsonXAI(BaseLLM):
         iam_access_token = json_data["access_token"]
         self.token = iam_access_token
         return iam_access_token
 
-    def _make_response_manager(
-        self,
-        async_: bool,
-        logging_obj: Logging
-    ) -> Callable[..., Generator[Union[requests.Response, httpx.Response], None, None]]:
+    def get_available_models(self, *, ids_only: bool = True, **params):
+        api_params = self._get_api_params(params)
+        headers = {
+            "Authorization": f"Bearer {api_params['token']}",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        request_params = dict(version=api_params["api_version"])
+        url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
+        req_params = dict(method="GET", url=url, headers=headers, params=request_params)
+        # manage_request = self._make_request_manager(async_=False, logging_obj=None)
+        with RequestManager(logging_obj=None).request(req_params) as resp:
+            json_resp = resp.json()
+        if not ids_only:
+            return json_resp
+        return [res["model_id"] for res in json_resp["resources"]]
+
+    def _make_request_manager(
+        self, async_: bool, logging_obj=None
+    ) -> Callable[
+        ...,
+        Union[ContextManager[requests.Response], AsyncContextManager[httpx.Response]],
+    ]:
         """
         Returns a context manager that manages the response from the request.
         if async_ is True, returns an async context manager, otherwise returns a regular context manager.
 
         Usage:
         ```python
-        manage_response = self._make_response_manager(async_=True, logging_obj=logging_obj)
-        async with manage_response(request_params) as resp:
+        manage_request = self._make_request_manager(async_=True, logging_obj=logging_obj)
+        async with manage_request(request_params) as resp:
             ...
         # or
-        manage_response = self._make_response_manager(async_=False, logging_obj=logging_obj)
-        with manage_response(request_params) as resp:
+        manage_request = self._make_request_manager(async_=False, logging_obj=logging_obj)
+        with manage_request(request_params) as resp:
             ...
         ```
         """
 
         def pre_call(
             request_params: dict,
-            input:Optional[Any]=None,
+            input: Optional[Any] = None,
         ):
+            if logging_obj is None:
+                return
             request_str = (
                 f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
                 f"\turl={request_params['url']},\n"
-                f"\tjson={request_params['json']},\n"
+                f"\tjson={request_params.get('json')},\n"
                 f")"
             )
             logging_obj.pre_call(
                 input=input,
                 api_key=request_params["headers"].get("Authorization"),
                 additional_args={
-                    "complete_input_dict": request_params["json"],
+                    "complete_input_dict": request_params.get("json"),
                     "request_str": request_str,
                 },
             )
 
         def post_call(resp, request_params):
+            if logging_obj is None:
+                return
             logging_obj.post_call(
                 input=input,
                 api_key=request_params["headers"].get("Authorization"),
                 original_response=json.dumps(resp.json()),
                 additional_args={
                     "status_code": resp.status_code,
-                    "complete_input_dict": request_params.get("data", request_params.get("json")),
+                    "complete_input_dict": request_params.get(
+                        "data", request_params.get("json")
+                    ),
                 },
             )
 
         @contextmanager
-        def _manage_response(
+        def _manage_request(
            request_params: dict,
            stream: bool = False,
            input: Optional[Any] = None,
-           timeout: float = None,
+           timeout=None,
        ) -> Generator[requests.Response, None, None]:
            """
            Returns a context manager that yields the response from the request.
@@ -685,20 +736,23 @@ class IBMWatsonXAI(BaseLLM):
                 request_params["stream"] = stream
             try:
                 resp = requests.request(**request_params)
-                resp.raise_for_status()
+                if not resp.ok:
+                    raise WatsonXAIError(
+                        status_code=resp.status_code,
+                        message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
+                    )
                 yield resp
             except Exception as e:
                 raise WatsonXAIError(status_code=500, message=str(e))
             if not stream:
                 post_call(resp, request_params)
 
 
         @asynccontextmanager
-        async def _manage_response_async(
+        async def _manage_request_async(
             request_params: dict,
             stream: bool = False,
             input: Optional[Any] = None,
-            timeout: float = None,
+            timeout=None,
         ) -> AsyncGenerator[httpx.Response, None]:
             pre_call(request_params, input)
             if timeout:
@@ -708,16 +762,23 @@ class IBMWatsonXAI(BaseLLM):
             try:
                 # async with AsyncHTTPHandler(timeout=timeout) as client:
                 self.async_handler = AsyncHTTPHandler(
-                    timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0),
+                    timeout=httpx.Timeout(
+                        timeout=request_params.pop("timeout", 600.0), connect=5.0
+                    ),
                 )
                 # async_handler.client.verify = False
                 if "json" in request_params:
-                    request_params['data'] = json.dumps(request_params.pop("json", {}))
+                    request_params["data"] = json.dumps(request_params.pop("json", {}))
                 method = request_params.pop("method")
                 if method.upper() == "POST":
                     resp = await self.async_handler.post(**request_params)
                 else:
                     resp = await self.async_handler.get(**request_params)
+                if resp.status_code not in [200, 201]:
+                    raise WatsonXAIError(
+                        status_code=resp.status_code,
+                        message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
+                    )
                 yield resp
                 # await async_handler.close()
             except Exception as e:
@@ -726,6 +787,132 @@ class IBMWatsonXAI(BaseLLM):
                 post_call(resp, request_params)
 
         if async_:
-            return _manage_response_async
+            return _manage_request_async
         else:
-            return _manage_response
+            return _manage_request
+
+
+class RequestManager:
+    """
+    Returns a context manager that manages the response from the request.
+    if async_ is True, returns an async context manager, otherwise returns a regular context manager.
+
+    Usage:
+    ```python
+    request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization": "Bearer token"}, json={"key": "value"})
+    request_manager = RequestManager(logging_obj=logging_obj)
+    async with request_manager.async_request(request_params) as resp:
+        ...
+    # or
+    with request_manager.request(request_params) as resp:
+        ...
+    ```
+    """
+
+    def __init__(self, logging_obj=None):
+        self.logging_obj = logging_obj
+
+    def pre_call(
+        self,
+        request_params: dict,
+        input: Optional[Any] = None,
+    ):
+        if self.logging_obj is None:
+            return
+        request_str = (
+            f"response = {request_params['method']}(\n"
+            f"\turl={request_params['url']},\n"
+            f"\tjson={request_params.get('json')},\n"
+            f")"
+        )
+        self.logging_obj.pre_call(
+            input=input,
+            api_key=request_params["headers"].get("Authorization"),
+            additional_args={
+                "complete_input_dict": request_params.get("json"),
+                "request_str": request_str,
+            },
+        )
+
+    def post_call(self, resp, request_params):
+        if self.logging_obj is None:
+            return
+        self.logging_obj.post_call(
+            input=input,
+            api_key=request_params["headers"].get("Authorization"),
+            original_response=json.dumps(resp.json()),
+            additional_args={
+                "status_code": resp.status_code,
+                "complete_input_dict": request_params.get(
+                    "data", request_params.get("json")
+                ),
+            },
+        )
+
+    @contextmanager
+    def request(
+        self,
+        request_params: dict,
+        stream: bool = False,
+        input: Optional[Any] = None,
+        timeout=None,
+    ) -> Generator[requests.Response, None, None]:
+        """
+        Returns a context manager that yields the response from the request.
+        """
+        self.pre_call(request_params, input)
+        if timeout:
+            request_params["timeout"] = timeout
+        if stream:
+            request_params["stream"] = stream
+        try:
+            resp = requests.request(**request_params)
+            if not resp.ok:
+                raise WatsonXAIError(
+                    status_code=resp.status_code,
+                    message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
+                )
+            yield resp
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
+        if not stream:
+            self.post_call(resp, request_params)
+
+    @asynccontextmanager
+    async def async_request(
+        self,
+        request_params: dict,
+        stream: bool = False,
+        input: Optional[Any] = None,
+        timeout=None,
+    ) -> AsyncGenerator[httpx.Response, None]:
+        self.pre_call(request_params, input)
+        if timeout:
+            request_params["timeout"] = timeout
+        if stream:
+            request_params["stream"] = stream
+        try:
+            # async with AsyncHTTPHandler(timeout=timeout) as client:
+            self.async_handler = AsyncHTTPHandler(
+                timeout=httpx.Timeout(
+                    timeout=request_params.pop("timeout", 600.0), connect=5.0
+                ),
+            )
+            # async_handler.client.verify = False
+            if "json" in request_params:
+                request_params["data"] = json.dumps(request_params.pop("json", {}))
+            method = request_params.pop("method")
+            if method.upper() == "POST":
+                resp = await self.async_handler.post(**request_params)
+            else:
+                resp = await self.async_handler.get(**request_params)
+            if resp.status_code not in [200, 201]:
+                raise WatsonXAIError(
+                    status_code=resp.status_code,
+                    message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
+                )
+            yield resp
+            # await async_handler.close()
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
+        if not stream:
+            self.post_call(resp, request_params)
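For orientation, a minimal sketch of how the new RequestManager above is meant to be driven, mirroring its docstring and the get_available_models usage; the URL, token, and payload are placeholder values, not anything taken from the commit:

```python
# Sketch only (assumed values): exercises the RequestManager added in this commit.
request_params = dict(
    method="POST",
    url="https://api.example.com",              # placeholder endpoint
    headers={"Authorization": "Bearer token"},  # placeholder credential
    json={"key": "value"},                      # placeholder payload
)

manager = RequestManager(logging_obj=None)  # pass a litellm logging object to emit pre/post-call logs

# Synchronous path: wraps requests.request() and raises WatsonXAIError on a non-OK status.
with manager.request(request_params, input="example prompt") as resp:
    data = resp.json()

# Asynchronous path (inside an async function): wraps AsyncHTTPHandler the same way.
# async with manager.async_request(request_params, input="example prompt") as resp:
#     data = resp.json()
```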
Second changed file (litellm/utils.py, which defines CustomStreamWrapper):

@@ -10285,7 +10285,7 @@ class CustomStreamWrapper:
                     response = chunk.replace("data: ", "").strip()
                     parsed_response = json.loads(response)
                 else:
-                    return {"text": "", "is_finished": False}
+                    return {"text": "", "is_finished": False, "prompt_tokens": 0, "completion_tokens": 0}
             else:
                 print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
                 raise ValueError(
@@ -10300,8 +10300,8 @@ class CustomStreamWrapper:
                 "text": text,
                 "is_finished": is_finished,
                 "finish_reason": finish_reason,
-                "prompt_tokens": results[0].get("input_token_count", None),
-                "completion_tokens": results[0].get("generated_token_count", None),
+                "prompt_tokens": results[0].get("input_token_count", 0),
+                "completion_tokens": results[0].get("generated_token_count", 0),
             }
             return {"text": "", "is_finished": False}
         except Exception as e:
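The two CustomStreamWrapper hunks above are what the commit title refers to: parsed watsonx stream chunks now always carry integer prompt_tokens and completion_tokens (defaulting to 0) instead of sometimes None or missing. A small illustrative snippet, not part of the commit, of what that guarantees downstream:

```python
# Illustrative only: chunk dicts shaped like handle_watsonx_stream's return value.
chunks = [
    {"text": "", "is_finished": False, "prompt_tokens": 0, "completion_tokens": 0},  # empty keep-alive chunk
    {"text": "Hello", "is_finished": True, "finish_reason": "stop",
     "prompt_tokens": 4, "completion_tokens": 2},
]

# With the 0 fallbacks, arithmetic on token counts can never hit a None.
total_tokens = sum(c["prompt_tokens"] + c["completion_tokens"] for c in chunks)
print(total_tokens)  # 6
```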