Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 03:34:10 +00:00

Merge branch 'main' into feature/watsonx-integration

Commit e1372de9ee: 23 changed files with 8026 additions and 271 deletions.
@@ -451,9 +451,6 @@ class IBMWatsonXAI(BaseLLM):
             return streamwrapper
 
         # create the function to manage the request to watsonx.ai
-        # manage_request = self._make_request_manager(
-        #     async_=(acompletion is True), logging_obj=logging_obj
-        # )
         self.request_manager = RequestManager(logging_obj)
 
         def handle_text_request(request_params: dict) -> ModelResponse:
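With the commented-out factory calls gone, both request paths go through the `RequestManager` instance. A minimal sketch of the resulting call sites, assuming the synchronous `request` context manager visible in the model-listing hunk below plus an async counterpart whose name (`async_request`) is a guess; the helper functions here are hypothetical:

```python
def call_watsonx(request_manager, req_params: dict, prompt: str) -> dict:
    # Sync path (acompletion is False), mirroring the removed
    # `with manage_request(...)` pattern.
    with request_manager.request(req_params, input=prompt) as resp:
        return resp.json()


async def acall_watsonx(request_manager, req_params: dict, prompt: str) -> dict:
    # Async path (acompletion is True); `async_request` is an assumed name.
    async with request_manager.async_request(req_params, input=prompt) as resp:
        return resp.json()
```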
@@ -576,9 +573,6 @@ class IBMWatsonXAI(BaseLLM):
             "json": payload,
             "params": request_params,
         }
-        # manage_request = self._make_request_manager(
-        #     async_=(aembedding is True), logging_obj=logging_obj
-        # )
         request_manager = RequestManager(logging_obj)
 
         def process_embedding_response(json_resp: dict) -> ModelResponse:
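The fragment above closes the embedding request dict; its leading keys are not shown in this hunk, so the sketch below infers them from the GET request assembled for model listing in the next hunk. Hypothetical illustration only:

```python
def handle_embedding_request(request_manager, url, headers, payload, request_params, input):
    # Keys other than "json" and "params" are inferred, not shown in the hunk.
    req_params = {
        "method": "POST",
        "url": url,
        "headers": headers,
        "json": payload,
        "params": request_params,
    }
    with request_manager.request(req_params, input=input) as resp:
        return process_embedding_response(resp.json())
```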
@@ -654,143 +648,12 @@ class IBMWatsonXAI(BaseLLM):
         request_params = dict(version=api_params["api_version"])
         url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
         req_params = dict(method="GET", url=url, headers=headers, params=request_params)
-        # manage_request = self._make_request_manager(async_=False, logging_obj=None)
         with RequestManager(logging_obj=None).request(req_params) as resp:
             json_resp = resp.json()
         if not ids_only:
             return json_resp
         return [res["model_id"] for res in json_resp["resources"]]
 
-    def _make_request_manager(
-        self, async_: bool, logging_obj=None
-    ) -> Callable[
-        ...,
-        Union[ContextManager[requests.Response], AsyncContextManager[httpx.Response]],
-    ]:
-        """
-        Returns a context manager that manages the response from the request.
-        if async_ is True, returns an async context manager, otherwise returns a regular context manager.
-
-        Usage:
-        ```python
-        manage_request = self._make_request_manager(async_=True, logging_obj=logging_obj)
-        async with manage_request(request_params) as resp:
-            ...
-        # or
-        manage_request = self._make_request_manager(async_=False, logging_obj=logging_obj)
-        with manage_request(request_params) as resp:
-            ...
-        ```
-        """
-
-        def pre_call(
-            request_params: dict,
-            input: Optional[Any] = None,
-        ):
-            if logging_obj is None:
-                return
-            request_str = (
-                f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
-                f"\turl={request_params['url']},\n"
-                f"\tjson={request_params.get('json')},\n"
-                f")"
-            )
-            logging_obj.pre_call(
-                input=input,
-                api_key=request_params["headers"].get("Authorization"),
-                additional_args={
-                    "complete_input_dict": request_params.get("json"),
-                    "request_str": request_str,
-                },
-            )
-
-        def post_call(resp, request_params):
-            if logging_obj is None:
-                return
-            logging_obj.post_call(
-                input=input,
-                api_key=request_params["headers"].get("Authorization"),
-                original_response=json.dumps(resp.json()),
-                additional_args={
-                    "status_code": resp.status_code,
-                    "complete_input_dict": request_params.get(
-                        "data", request_params.get("json")
-                    ),
-                },
-            )
-
-        @contextmanager
-        def _manage_request(
-            request_params: dict,
-            stream: bool = False,
-            input: Optional[Any] = None,
-            timeout=None,
-        ) -> Generator[requests.Response, None, None]:
-            """
-            Returns a context manager that yields the response from the request.
-            """
-            pre_call(request_params, input)
-            if timeout:
-                request_params["timeout"] = timeout
-            if stream:
-                request_params["stream"] = stream
-            try:
-                resp = requests.request(**request_params)
-                if not resp.ok:
-                    raise WatsonXAIError(
-                        status_code=resp.status_code,
-                        message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
-                    )
-                yield resp
-            except Exception as e:
-                raise WatsonXAIError(status_code=500, message=str(e))
-            if not stream:
-                post_call(resp, request_params)
-
-        @asynccontextmanager
-        async def _manage_request_async(
-            request_params: dict,
-            stream: bool = False,
-            input: Optional[Any] = None,
-            timeout=None,
-        ) -> AsyncGenerator[httpx.Response, None]:
-            pre_call(request_params, input)
-            if timeout:
-                request_params["timeout"] = timeout
-            if stream:
-                request_params["stream"] = stream
-            try:
-                # async with AsyncHTTPHandler(timeout=timeout) as client:
-                self.async_handler = AsyncHTTPHandler(
-                    timeout=httpx.Timeout(
-                        timeout=request_params.pop("timeout", 600.0), connect=5.0
-                    ),
-                )
-                # async_handler.client.verify = False
-                if "json" in request_params:
-                    request_params["data"] = json.dumps(request_params.pop("json", {}))
-                method = request_params.pop("method")
-                if method.upper() == "POST":
-                    resp = await self.async_handler.post(**request_params)
-                else:
-                    resp = await self.async_handler.get(**request_params)
-                if resp.status_code not in [200, 201]:
-                    raise WatsonXAIError(
-                        status_code=resp.status_code,
-                        message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
-                    )
-                yield resp
-                # await async_handler.close()
-            except Exception as e:
-                raise WatsonXAIError(status_code=500, message=str(e))
-            if not stream:
-                post_call(resp, request_params)
-
-        if async_:
-            return _manage_request_async
-        else:
-            return _manage_request
-
 
 class RequestManager:
     """
     Returns a context manager that manages the response from the request.
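The body of the new `RequestManager` class is cut off in this view. As a rough guide to the refactor, it presumably lifts the deleted `pre_call`/`post_call` closures and the synchronous `_manage_request` logic into methods. A minimal sketch under that assumption; only the constructor signature and the `request()` call are confirmed by the hunks above, the rest is inferred from the deleted code, and `WatsonXAIError` is assumed to be the exception defined earlier in the same module:

```python
import json
from contextlib import contextmanager
from typing import Any, Generator, Optional

import requests


class RequestManager:
    """Sketch: packages the logging and error handling of the deleted
    _make_request_manager factory as instance methods."""

    def __init__(self, logging_obj=None):
        self.logging_obj = logging_obj

    def pre_call(self, request_params: dict, input: Optional[Any] = None) -> None:
        if self.logging_obj is None:
            return
        self.logging_obj.pre_call(
            input=input,
            api_key=request_params["headers"].get("Authorization"),
            additional_args={"complete_input_dict": request_params.get("json")},
        )

    def post_call(self, resp, request_params: dict, input: Optional[Any] = None) -> None:
        if self.logging_obj is None:
            return
        self.logging_obj.post_call(
            input=input,
            api_key=request_params["headers"].get("Authorization"),
            original_response=json.dumps(resp.json()),
            additional_args={"status_code": resp.status_code},
        )

    @contextmanager
    def request(
        self,
        request_params: dict,
        stream: bool = False,
        input: Optional[Any] = None,
        timeout=None,
    ) -> Generator[requests.Response, None, None]:
        # Same control flow as the removed _manage_request closure, except
        # that HTTP failures are raised outside the try block so their real
        # status code is not swallowed and re-wrapped as a 500.
        self.pre_call(request_params, input)
        if timeout:
            request_params["timeout"] = timeout
        if stream:
            request_params["stream"] = stream
        try:
            resp = requests.request(**request_params)
        except Exception as e:
            # WatsonXAIError is assumed to be defined earlier in this module.
            raise WatsonXAIError(status_code=500, message=str(e))
        if not resp.ok:
            raise WatsonXAIError(
                status_code=resp.status_code,
                message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
            )
        yield resp
        if not stream:
            self.post_call(resp, request_params, input)
```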
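The logging hooks are duck-typed: the manager only ever calls `pre_call` and `post_call`. For local experimentation, a stand-in logger is enough to exercise the sketch above; the class, endpoint URL, and version value below are illustrative placeholders, not part of litellm:

```python
class PrintLogger:
    """Illustrative stand-in for litellm's logging object; implements only
    the two hooks the request manager drives."""

    def pre_call(self, input, api_key, additional_args=None):
        print("request:", (additional_args or {}).get("complete_input_dict"))

    def post_call(self, input, api_key, original_response, additional_args=None):
        print("status:", (additional_args or {}).get("status_code"))
        print("response:", original_response)


# Hypothetical usage with the sketch above; URL, token, and version
# are placeholders.
manager = RequestManager(logging_obj=PrintLogger())
req_params = {
    "method": "GET",
    "url": "https://us-south.ml.cloud.ibm.com/ml/v1/foundation_model_specs",
    "headers": {"Authorization": "Bearer <token>"},
    "params": {"version": "2024-03-13"},
}
with manager.request(req_params) as resp:
    print(resp.json())
```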