mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
LiteLLM Minor Fixes & Improvements (12/05/2024) (#7037)
* fix(together_ai/chat): only return response_format + tools for supported models Fixes https://github.com/BerriAI/litellm/issues/6972 * feat(bedrock/rerank): initial working commit for bedrock rerank api support Closes https://github.com/BerriAI/litellm/issues/7021 * feat(bedrock/rerank): async bedrock rerank api support Addresses https://github.com/BerriAI/litellm/issues/7021 * build(model_prices_and_context_window.json): add 'supports_prompt_caching' for bedrock models + cleanup cross-region from model list (duplicate information - lead to inconsistencies ) * docs(json_mode.md): clarify model support for json schema Closes https://github.com/BerriAI/litellm/issues/6998 * fix(_service_logger.py): handle dd callback in list ensure failed spend tracking is logged to datadog * feat(converse_transformation.py): translate from anthropic format to bedrock format Closes https://github.com/BerriAI/litellm/issues/7030 * fix: fix linting errors * test: fix test
This commit is contained in:
parent
12dfd14b52
commit
61b35c12bb
24 changed files with 858 additions and 400 deletions
159
litellm/llms/bedrock/rerank/handler.py
Normal file
159
litellm/llms/bedrock/rerank/handler.py
Normal file
|
@ -0,0 +1,159 @@
|
|||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
from openai.types.image import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.bedrock import BedrockPreparedRequest, BedrockRerankRequest
|
||||
from litellm.types.rerank import RerankRequest
|
||||
from litellm.types.utils import RerankResponse
|
||||
|
||||
from ...base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError
|
||||
from .transformation import BedrockRerankConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.awsrequest import AWSPreparedRequest
|
||||
else:
|
||||
AWSPreparedRequest = Any
|
||||
|
||||
|
||||
class BedrockRerankHandler(BaseAWSLLM):
|
||||
async def arerank(
|
||||
self,
|
||||
prepared_request: BedrockPreparedRequest,
|
||||
):
|
||||
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
|
||||
try:
|
||||
response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return BedrockRerankConfig()._transform_response(response.json())
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str,
|
||||
documents: List[Union[str, Dict[str, Any]]],
|
||||
optional_params: dict,
|
||||
logging_obj: LitellmLogging,
|
||||
top_n: Optional[int] = None,
|
||||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
_is_async: Optional[bool] = False,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
) -> RerankResponse:
|
||||
request_data = RerankRequest(
|
||||
model=model,
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n,
|
||||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
)
|
||||
data = BedrockRerankConfig()._transform_request(request_data)
|
||||
|
||||
prepared_request = self._prepare_request(
|
||||
optional_params=optional_params,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
data=cast(dict, data),
|
||||
)
|
||||
|
||||
logging_obj.pre_call(
|
||||
input=data,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": prepared_request["endpoint_url"],
|
||||
"headers": prepared_request["prepped"].headers,
|
||||
},
|
||||
)
|
||||
|
||||
if _is_async:
|
||||
return self.arerank(prepared_request) # type: ignore
|
||||
|
||||
client = _get_httpx_client()
|
||||
try:
|
||||
response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return BedrockRerankConfig()._transform_response(response.json())
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
extra_headers: Optional[dict],
|
||||
data: dict,
|
||||
optional_params: dict,
|
||||
) -> BedrockPreparedRequest:
|
||||
try:
|
||||
import boto3
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
||||
optional_params
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
_, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=boto3_credentials_info.aws_region_name,
|
||||
)
|
||||
proxy_endpoint_url = proxy_endpoint_url.replace(
|
||||
"bedrock-runtime", "bedrock-agent-runtime"
|
||||
)
|
||||
proxy_endpoint_url = f"{proxy_endpoint_url}/rerank"
|
||||
sigv4 = SigV4Auth(
|
||||
boto3_credentials_info.credentials,
|
||||
"bedrock",
|
||||
boto3_credentials_info.aws_region_name,
|
||||
)
|
||||
# Make POST Request
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=proxy_endpoint_url, data=body, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
return BedrockPreparedRequest(
|
||||
endpoint_url=proxy_endpoint_url,
|
||||
prepped=prepped,
|
||||
body=body,
|
||||
data=data,
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue