mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
* build(pyproject.toml): add new dev dependencies for type checking
* build: reformat files to fit black
* ci: reformat to fit black
* ci(test-litellm.yml): make test runs clearer
* build(pyproject.toml): add ruff
* fix: fix ruff checks
* build(mypy/): fix mypy linting errors
* fix(hashicorp_secret_manager.py): fix passing cert for TLS auth
* build(mypy/): resolve all mypy errors
* test: update test
* fix: fix black formatting
* build(pre-commit-config.yaml): use poetry run black
* fix(proxy_server.py): fix linting error
* fix: fix ruff safe representation error
167 lines
6 KiB
Python
167 lines
6 KiB
Python
import json
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
|
|
|
import httpx
|
|
|
|
import litellm
|
|
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
|
from litellm.llms.custom_httpx.http_handler import (
|
|
AsyncHTTPHandler,
|
|
HTTPHandler,
|
|
_get_httpx_client,
|
|
get_async_httpx_client,
|
|
)
|
|
from litellm.types.llms.bedrock import BedrockPreparedRequest
|
|
from litellm.types.rerank import RerankRequest
|
|
from litellm.types.utils import RerankResponse
|
|
|
|
from ..base_aws_llm import BaseAWSLLM
|
|
from ..common_utils import BedrockError
|
|
from .transformation import BedrockRerankConfig
|
|
|
|
if TYPE_CHECKING:
    # Real botocore type, visible only to static type checkers.
    from botocore.awsrequest import AWSPreparedRequest
else:
    # At runtime the name is aliased to Any, so importing this module does not
    # import botocore (botocore is imported lazily in _prepare_request instead).
    AWSPreparedRequest = Any
|
|
|
|
|
|
class BedrockRerankHandler(BaseAWSLLM):
    """
    Handler for Amazon Bedrock rerank requests.

    The request body is produced by ``BedrockRerankConfig``, signed with
    botocore's SigV4 signer, and POSTed to the ``bedrock-agent-runtime``
    ``/rerank`` endpoint using litellm's httpx-based HTTP handlers, either
    synchronously (``rerank``) or asynchronously (``arerank``).
    """

    async def arerank(
        self,
        prepared_request: BedrockPreparedRequest,
        client: Optional[AsyncHTTPHandler] = None,
        logging_obj: Optional[LitellmLogging] = None,
    ) -> RerankResponse:
        """
        Send an already-signed rerank request asynchronously.

        Args:
            prepared_request: Output of ``_prepare_request``: endpoint URL,
                the prepared (signed) botocore request whose headers are
                reused, and the JSON-encoded body.
            client: Optional async HTTP handler. When omitted, a litellm
                Bedrock async client is obtained via ``get_async_httpx_client``.
            logging_obj: Optional litellm logging object. When given, the raw
                response is recorded via ``post_call``, matching the sync path.

        Returns:
            The transformed ``RerankResponse``.

        Raises:
            BedrockError: With the HTTP status code on a non-2xx response, or
                with status 408 when httpx reports a timeout.
        """
        if client is None:
            client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
        try:
            response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"])  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            raise BedrockError(
                status_code=err.response.status_code, message=err.response.text
            ) from err
        except httpx.TimeoutException as err:
            raise BedrockError(
                status_code=408, message="Timeout error occurred."
            ) from err

        # Shared with the sync path so both record the raw response before
        # transforming it.
        return self._handle_rerank_response(response, logging_obj)

    def rerank(
        self,
        model: str,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        optional_params: dict,
        logging_obj: LitellmLogging,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        _is_async: Optional[bool] = False,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> RerankResponse:
        """
        Rerank ``documents`` against ``query`` using a Bedrock rerank model.

        Args:
            model: Model identifier. It is placed in the ``RerankRequest`` and
                also passed to the AWS credential lookup.
            query: Query text the documents are ranked against.
            documents: Documents to rank, as strings or dicts.
            optional_params: Provider params; used here to resolve AWS
                credentials, region and runtime endpoint.
            logging_obj: litellm logging object (``pre_call`` / ``post_call``).
            top_n: Optional number of results to return.
            rank_fields: Optional fields to rank on for dict documents.
            return_documents: Forwarded into the ``RerankRequest``.
            max_chunks_per_doc: Accepted but not used by this handler.
            _is_async: When truthy, return the ``arerank`` coroutine instead
                of performing the request synchronously.
            api_base: Optional override for the runtime endpoint.
            extra_headers: Extra HTTP headers. A caller-supplied
                ``Authorization`` header is kept instead of the SigV4 one.
            client: Optional HTTP handler. A sync ``HTTPHandler`` is used for
                sync calls, and an ``AsyncHTTPHandler`` for async calls.

        Returns:
            A ``RerankResponse``, or, when ``_is_async`` is set, a coroutine
            that resolves to one.

        Raises:
            BedrockError: On non-2xx responses or httpx timeouts (status 408).
            ImportError: If botocore is not installed.
        """
        request_data = RerankRequest(
            model=model,
            query=query,
            documents=documents,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
        )
        data = BedrockRerankConfig()._transform_request(request_data)

        prepared_request = self._prepare_request(
            model=model,
            optional_params=optional_params,
            api_base=api_base,
            extra_headers=extra_headers,
            data=cast(dict, data),
        )

        logging_obj.pre_call(
            input=data,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": prepared_request["endpoint_url"],
                "headers": prepared_request["prepped"].headers,
            },
        )

        if _is_async:
            # Only an AsyncHTTPHandler is forwarded; anything else makes
            # arerank fall back to its default async client.
            async_client = client if isinstance(client, AsyncHTTPHandler) else None
            return self.arerank(prepared_request, client=async_client, logging_obj=logging_obj)  # type: ignore

        if not isinstance(client, HTTPHandler):
            client = _get_httpx_client()
        try:
            response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"])  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            raise BedrockError(
                status_code=err.response.status_code, message=err.response.text
            ) from err
        except httpx.TimeoutException as err:
            raise BedrockError(
                status_code=408, message="Timeout error occurred."
            ) from err

        return self._handle_rerank_response(response, logging_obj)

    def _handle_rerank_response(
        self,
        response: httpx.Response,
        logging_obj: Optional[LitellmLogging],
    ) -> RerankResponse:
        """
        Log the raw HTTP response (when a logger is given) and transform it.

        Args:
            response: Successful HTTP response from the rerank endpoint.
            logging_obj: litellm logging object, or None to skip logging.

        Returns:
            The ``RerankResponse`` built by
            ``BedrockRerankConfig._transform_response`` from the JSON body.
        """
        if logging_obj is not None:
            logging_obj.post_call(
                original_response=response.text,
                api_key="",
            )
        return BedrockRerankConfig()._transform_response(response.json())

    def _prepare_request(
        self,
        model: str,
        api_base: Optional[str],
        extra_headers: Optional[dict],
        data: dict,
        optional_params: dict,
    ) -> BedrockPreparedRequest:
        """
        Build and SigV4-sign the POST request for the Bedrock rerank endpoint.

        Args:
            model: Model identifier, passed to the AWS credential lookup.
            api_base: Optional override for the runtime endpoint.
            extra_headers: Extra headers merged over the JSON content type.
                A caller-supplied ``Authorization`` header replaces the SigV4
                one after signing.
            data: Request payload, serialized as UTF-8 JSON.
            optional_params: Provider params used to resolve credentials,
                region and runtime endpoint.

        Returns:
            A ``BedrockPreparedRequest`` holding the final endpoint URL, the
            prepared (signed) botocore request, the encoded body and ``data``.

        Raises:
            ImportError: If botocore (installed with boto3) is unavailable.
        """
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError as err:
            raise ImportError(
                "Missing boto3 to call bedrock. Run 'pip install boto3'."
            ) from err

        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params, model
        )

        ### SET RUNTIME ENDPOINT ###
        _, runtime_endpoint = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        # Rerank requests go to the bedrock-agent-runtime host rather than
        # the bedrock-runtime host returned above.
        endpoint_url = (
            runtime_endpoint.replace("bedrock-runtime", "bedrock-agent-runtime")
            + "/rerank"
        )

        sigv4 = SigV4Auth(
            boto3_credentials_info.credentials,
            "bedrock",
            boto3_credentials_info.aws_region_name,
        )

        # Make POST Request
        body = json.dumps(data).encode("utf-8")
        headers = {"Content-Type": "application/json", **(extra_headers or {})}
        request = AWSRequest(
            method="POST", url=endpoint_url, data=body, headers=headers
        )
        sigv4.add_auth(request)
        if extra_headers is not None and "Authorization" in extra_headers:
            # add_auth writes its own Authorization header; restore the
            # caller-supplied one so it is not overwritten by SigV4.
            request.headers["Authorization"] = extra_headers["Authorization"]
        prepped = request.prepare()

        return BedrockPreparedRequest(
            endpoint_url=endpoint_url,
            prepped=prepped,
            body=body,
            data=data,
        )
|