Merge branch 'main' into feature/watsonx-integration

This commit is contained in:
Simon Sanchez Viloria 2024-05-10 12:09:09 +02:00
commit e1372de9ee
23 changed files with 8026 additions and 271 deletions

View file

@ -451,9 +451,6 @@ class IBMWatsonXAI(BaseLLM):
return streamwrapper
# create the function to manage the request to watsonx.ai
# manage_request = self._make_request_manager(
# async_=(acompletion is True), logging_obj=logging_obj
# )
self.request_manager = RequestManager(logging_obj)
def handle_text_request(request_params: dict) -> ModelResponse:
@ -576,9 +573,6 @@ class IBMWatsonXAI(BaseLLM):
"json": payload,
"params": request_params,
}
# manage_request = self._make_request_manager(
# async_=(aembedding is True), logging_obj=logging_obj
# )
request_manager = RequestManager(logging_obj)
def process_embedding_response(json_resp: dict) -> ModelResponse:
@ -654,143 +648,12 @@ class IBMWatsonXAI(BaseLLM):
request_params = dict(version=api_params["api_version"])
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
req_params = dict(method="GET", url=url, headers=headers, params=request_params)
# manage_request = self._make_request_manager(async_=False, logging_obj=None)
with RequestManager(logging_obj=None).request(req_params) as resp:
json_resp = resp.json()
if not ids_only:
return json_resp
return [res["model_id"] for res in json_resp["resources"]]
def _make_request_manager(
self, async_: bool, logging_obj=None
) -> Callable[
...,
Union[ContextManager[requests.Response], AsyncContextManager[httpx.Response]],
]:
"""
Returns a context manager that manages the response from the request.
if async_ is True, returns an async context manager, otherwise returns a regular context manager.
Usage:
```python
manage_request = self._make_request_manager(async_=True, logging_obj=logging_obj)
async with manage_request(request_params) as resp:
...
# or
manage_request = self._make_request_manager(async_=False, logging_obj=logging_obj)
with manage_request(request_params) as resp:
...
```
"""
def pre_call(
request_params: dict,
input: Optional[Any] = None,
):
if logging_obj is None:
return
request_str = (
f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
f"\turl={request_params['url']},\n"
f"\tjson={request_params.get('json')},\n"
f")"
)
logging_obj.pre_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
additional_args={
"complete_input_dict": request_params.get("json"),
"request_str": request_str,
},
)
def post_call(resp, request_params):
if logging_obj is None:
return
logging_obj.post_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()),
additional_args={
"status_code": resp.status_code,
"complete_input_dict": request_params.get(
"data", request_params.get("json")
),
},
)
@contextmanager
def _manage_request(
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> Generator[requests.Response, None, None]:
"""
Returns a context manager that yields the response from the request.
"""
pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
resp = requests.request(**request_params)
if not resp.ok:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
post_call(resp, request_params)
@asynccontextmanager
async def _manage_request_async(
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> AsyncGenerator[httpx.Response, None]:
pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
# async with AsyncHTTPHandler(timeout=timeout) as client:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(
timeout=request_params.pop("timeout", 600.0), connect=5.0
),
)
# async_handler.client.verify = False
if "json" in request_params:
request_params["data"] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
if method.upper() == "POST":
resp = await self.async_handler.post(**request_params)
else:
resp = await self.async_handler.get(**request_params)
if resp.status_code not in [200, 201]:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
# await async_handler.close()
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
post_call(resp, request_params)
if async_:
return _manage_request_async
else:
return _manage_request
class RequestManager:
"""
Returns a context manager that manages the response from the request.