LiteLLM Minor Fixes & Improvements (12/05/2024) (#7037)

* fix(together_ai/chat): only return response_format + tools for supported models

Fixes https://github.com/BerriAI/litellm/issues/6972

* feat(bedrock/rerank): initial working commit for bedrock rerank api support

Closes https://github.com/BerriAI/litellm/issues/7021

* feat(bedrock/rerank): async bedrock rerank api support

Addresses https://github.com/BerriAI/litellm/issues/7021

* build(model_prices_and_context_window.json): add 'supports_prompt_caching' for bedrock models + clean up cross-region entries from the model list (duplicate information that led to inconsistencies)

* docs(json_mode.md): clarify model support for json schema

Closes https://github.com/BerriAI/litellm/issues/6998

* fix(_service_logger.py): handle dd callback in list

Ensures failed spend tracking is logged to Datadog.

* feat(converse_transformation.py): translate from anthropic format to bedrock format

Closes https://github.com/BerriAI/litellm/issues/7030

* fix: fix linting errors

* test: fix test
Author: Krish Dholakia
Date: 2024-12-05 00:02:31 -08:00 (committed by GitHub)
Parent: 12dfd14b52
Commit: 61b35c12bb
24 changed files with 858 additions and 400 deletions


@@ -51,6 +51,9 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## Check Model Support
### 1. Check if model supports `response_format`
Call `litellm.get_supported_openai_params` to check if a model/provider supports `response_format`.
```python
@@ -61,6 +64,20 @@ params = get_supported_openai_params(model="anthropic.claude-3", custom_llm_prov
assert "response_format" in params
```
### 2. Check if model supports `json_schema`
This is used to check if you can pass
- `response_format={ "type": "json_schema", "json_schema": … , "strict": true }`
- `response_format=<Pydantic Model>`
```python
from litellm import supports_response_schema
assert supports_response_schema(model="gemini-1.5-pro-preview-0215", custom_llm_provider="bedrock")
```
Check out [model_prices_and_context_window.json](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) for a full list of models and their support for `response_schema`.
## Pass in 'json_schema'
To use Structured Outputs, simply specify
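For context, this is roughly how the documented checks combine with a structured-output call; the model name and schema below are illustrative, not part of this change:

```python
from pydantic import BaseModel

from litellm import completion, supports_response_schema


class EventInfo(BaseModel):  # illustrative schema, not part of this PR
    name: str
    date: str


model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"  # illustrative model

# Check support first, as the docs above describe, then pass the Pydantic model directly.
if supports_response_schema(model=model):
    resp = completion(
        model=model,
        messages=[{"role": "user", "content": "Extract: the launch party is on 2024-12-05."}],
        response_format=EventInfo,  # or a raw {"type": "json_schema", ...} dict
    )
    print(resp.choices[0].message.content)
```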


@@ -7,6 +7,7 @@ from litellm._logging import verbose_logger
from litellm.proxy._types import UserAPIKeyAuth
from .integrations.custom_logger import CustomLogger
from .integrations.datadog.datadog import DataDogLogger
from .integrations.prometheus_services import PrometheusServicesLogger
from .types.services import ServiceLoggerPayload, ServiceTypes
@@ -134,9 +135,7 @@ class ServiceLogging(CustomLogger):
await self.prometheusServicesLogger.async_service_success_hook(
payload=payload
)
-elif callback == "datadog":
-from litellm.integrations.datadog.datadog import DataDogLogger
+elif callback == "datadog" or isinstance(callback, DataDogLogger):
await self.init_datadog_logger_if_none()
await self.dd_logger.async_service_success_hook(
payload=payload,
@@ -237,6 +236,7 @@ class ServiceLogging(CustomLogger):
duration=duration,
call_type=call_type,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
await self.init_prometheus_services_logger_if_none()
@@ -244,7 +244,7 @@ class ServiceLogging(CustomLogger):
payload=payload,
error=error,
)
-elif callback == "datadog":
+elif callback == "datadog" or isinstance(callback, DataDogLogger):
await self.init_datadog_logger_if_none()
await self.dd_logger.async_service_failure_hook(
payload=payload,
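A rough sketch of what the new `isinstance(callback, DataDogLogger)` branch enables: registering a DataDog logger *instance* in `litellm.service_callback` (names follow the diff and the new unit test at the bottom of this commit; assumes `DD_API_KEY`/`DD_SITE` are configured for the Datadog integration):

```python
import litellm
from litellm._service_logger import ServiceLogging
from litellm.integrations.datadog.datadog import DataDogLogger
from litellm.types.services import ServiceTypes

# A logger *instance*, not just the "datadog" string, now also receives
# service success/failure events (e.g. failed spend-tracking DB writes).
litellm.service_callback = [DataDogLogger()]

service_logging = ServiceLogging()

# Somewhere inside async proxy code (call_type/error values here are illustrative):
# await service_logging.async_service_failure_hook(
#     service=ServiceTypes.DB,
#     call_type="update_spend",
#     error="db write failed",
#     duration=1.0,
# )
```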


@@ -5,7 +5,7 @@ Translating between OpenAI's `/chat/completion` format and Amazon's `/converse`
import copy
import time
import types
-from typing import List, Optional, Union
+from typing import List, Literal, Optional, Tuple, Union, cast, overload
import httpx
@@ -255,6 +255,59 @@
)
return optional_params
@overload
def _get_cache_point_block(
self, message_block: dict, block_type: Literal["system"]
) -> Optional[SystemContentBlock]:
pass
@overload
def _get_cache_point_block(
self, message_block: dict, block_type: Literal["content_block"]
) -> Optional[ContentBlock]:
pass
def _get_cache_point_block(
self, message_block: dict, block_type: Literal["system", "content_block"]
) -> Optional[Union[SystemContentBlock, ContentBlock]]:
if message_block.get("cache_control", None) is None:
return None
if block_type == "system":
return SystemContentBlock(cachePoint=CachePointBlock(type="default"))
else:
return ContentBlock(cachePoint=CachePointBlock(type="default"))
def _transform_system_message(
self, messages: List[AllMessageValues]
) -> Tuple[List[AllMessageValues], List[SystemContentBlock]]:
system_prompt_indices = []
system_content_blocks: List[SystemContentBlock] = []
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block: Optional[SystemContentBlock] = None
_cache_point_block: Optional[SystemContentBlock] = None
if isinstance(message["content"], str) and len(message["content"]) > 0:
_system_content_block = SystemContentBlock(text=message["content"])
_cache_point_block = self._get_cache_point_block(
cast(dict, message), block_type="system"
)
elif isinstance(message["content"], list):
for m in message["content"]:
if m.get("type", "") == "text" and len(m["text"]) > 0:
_system_content_block = SystemContentBlock(text=m["text"])
_cache_point_block = self._get_cache_point_block(
m, block_type="system"
)
if _system_content_block is not None:
system_content_blocks.append(_system_content_block)
if _cache_point_block is not None:
system_content_blocks.append(_cache_point_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
return messages, system_content_blocks
def _transform_request(
self,
model: str,
@@ -262,24 +315,7 @@
optional_params: dict,
litellm_params: dict,
) -> RequestObject:
-system_prompt_indices = []
-system_content_blocks: List[SystemContentBlock] = []
-for idx, message in enumerate(messages):
-if message["role"] == "system":
-_system_content_block: Optional[SystemContentBlock] = None
-if isinstance(message["content"], str) and len(message["content"]) > 0:
-_system_content_block = SystemContentBlock(text=message["content"])
-elif isinstance(message["content"], list):
-for m in message["content"]:
-if m.get("type", "") == "text" and len(m["text"]) > 0:
-_system_content_block = SystemContentBlock(text=m["text"])
-if _system_content_block is not None:
-system_content_blocks.append(_system_content_block)
-system_prompt_indices.append(idx)
-if len(system_prompt_indices) > 0:
-for idx in reversed(system_prompt_indices):
-messages.pop(idx)
+messages, system_content_blocks = self._transform_system_message(messages)
inference_params = copy.deepcopy(optional_params)
additional_request_keys = []
additional_request_params = {}
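A small sketch of what the new `_transform_system_message` helper produces for a cache-marked system message (mirrors the prompt-caching tests added later in this commit; the commented output shape is approximate):

```python
import litellm

config = litellm.AmazonConverseConfig()

messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "Here is the full text of a complex legal agreement",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {"role": "user", "content": "What are the key terms?"},
]

remaining, system_blocks = config._transform_system_message(messages)
# system_blocks ~ [{"text": "Here is the full ..."}, {"cachePoint": {"type": "default"}}]
# `remaining` no longer contains the system message.
```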


@@ -0,0 +1,159 @@
import copy
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import httpx
from openai.types.image import Image
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.bedrock import BedrockPreparedRequest, BedrockRerankRequest
from litellm.types.rerank import RerankRequest
from litellm.types.utils import RerankResponse
from ...base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
from .transformation import BedrockRerankConfig
if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
else:
AWSPreparedRequest = Any
class BedrockRerankHandler(BaseAWSLLM):
async def arerank(
self,
prepared_request: BedrockPreparedRequest,
):
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
try:
response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return BedrockRerankConfig()._transform_response(response.json())
def rerank(
self,
model: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
optional_params: dict,
logging_obj: LitellmLogging,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False,
api_base: Optional[str] = None,
extra_headers: Optional[dict] = None,
) -> RerankResponse:
request_data = RerankRequest(
model=model,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
)
data = BedrockRerankConfig()._transform_request(request_data)
prepared_request = self._prepare_request(
optional_params=optional_params,
api_base=api_base,
extra_headers=extra_headers,
data=cast(dict, data),
)
logging_obj.pre_call(
input=data,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepared_request["endpoint_url"],
"headers": prepared_request["prepped"].headers,
},
)
if _is_async:
return self.arerank(prepared_request) # type: ignore
client = _get_httpx_client()
try:
response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return BedrockRerankConfig()._transform_response(response.json())
def _prepare_request(
self,
api_base: Optional[str],
extra_headers: Optional[dict],
data: dict,
optional_params: dict,
) -> BedrockPreparedRequest:
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
optional_params
)
### SET RUNTIME ENDPOINT ###
_, proxy_endpoint_url = self.get_runtime_endpoint(
api_base=api_base,
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
aws_region_name=boto3_credentials_info.aws_region_name,
)
proxy_endpoint_url = proxy_endpoint_url.replace(
"bedrock-runtime", "bedrock-agent-runtime"
)
proxy_endpoint_url = f"{proxy_endpoint_url}/rerank"
sigv4 = SigV4Auth(
boto3_credentials_info.credentials,
"bedrock",
boto3_credentials_info.aws_region_name,
)
# Make POST Request
body = json.dumps(data).encode("utf-8")
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=proxy_endpoint_url, data=body, headers=headers
)
sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped = request.prepare()
return BedrockPreparedRequest(
endpoint_url=proxy_endpoint_url,
prepped=prepped,
body=body,
data=data,
)
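For reference, a minimal synchronous call that exercises this handler through `litellm.rerank` (model ARN taken from the new Bedrock rerank test below; assumes standard AWS credentials are available in the environment):

```python
import litellm

response = litellm.rerank(
    model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
    query="hello",
    documents=["hello", "world"],
    top_n=2,
)
print(response.results)  # e.g. [{"index": 0, "relevance_score": ...}, ...]
```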


@@ -0,0 +1,117 @@
"""
Translates from Cohere's `/v1/rerank` input format to Bedrock's `/rerank` input format.
Why separate file? Make it easy to see how transformation works
"""
import uuid
from typing import List, Optional, Union
from litellm.types.llms.bedrock import (
BedrockRerankBedrockRerankingConfiguration,
BedrockRerankConfiguration,
BedrockRerankInlineDocumentSource,
BedrockRerankModelConfiguration,
BedrockRerankQuery,
BedrockRerankRequest,
BedrockRerankSource,
BedrockRerankTextDocument,
BedrockRerankTextQuery,
)
from litellm.types.rerank import (
RerankBilledUnits,
RerankRequest,
RerankResponse,
RerankResponseMeta,
RerankResponseResult,
RerankTokens,
)
class BedrockRerankConfig:
def _transform_sources(
self, documents: List[Union[str, dict]]
) -> List[BedrockRerankSource]:
"""
Transform the sources from RerankRequest format to Bedrock format.
"""
_sources = []
for document in documents:
if isinstance(document, str):
_sources.append(
BedrockRerankSource(
inlineDocumentSource=BedrockRerankInlineDocumentSource(
textDocument=BedrockRerankTextDocument(text=document),
type="TEXT",
),
type="INLINE",
)
)
else:
_sources.append(
BedrockRerankSource(
inlineDocumentSource=BedrockRerankInlineDocumentSource(
jsonDocument=document, type="JSON"
),
type="INLINE",
)
)
return _sources
def _transform_request(self, request_data: RerankRequest) -> BedrockRerankRequest:
"""
Transform the request from RerankRequest format to Bedrock format.
"""
_sources = self._transform_sources(request_data.documents)
return BedrockRerankRequest(
queries=[
BedrockRerankQuery(
textQuery=BedrockRerankTextQuery(text=request_data.query),
type="TEXT",
)
],
rerankingConfiguration=BedrockRerankConfiguration(
bedrockRerankingConfiguration=BedrockRerankBedrockRerankingConfiguration(
modelConfiguration=BedrockRerankModelConfiguration(
modelArn=request_data.model
),
numberOfResults=request_data.top_n or len(request_data.documents),
),
type="BEDROCK_RERANKING_MODEL",
),
sources=_sources,
)
def _transform_response(self, response: dict) -> RerankResponse:
"""
Transform the response from Bedrock into the RerankResponse format.
example input:
{"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
"""
_billed_units = RerankBilledUnits(**response.get("usage", {}))
_tokens = RerankTokens(**response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
_results: Optional[List[RerankResponseResult]] = None
bedrock_results = response.get("results")
if bedrock_results:
_results = [
RerankResponseResult(
index=result.get("index"),
relevance_score=result.get("relevanceScore"),
)
for result in bedrock_results
]
if _results is None:
raise ValueError(f"No results found in the response={response}")
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
results=_results,
meta=rerank_meta,
) # Return response
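A quick sketch of the Cohere-to-Bedrock translation above in isolation (module path as added in this PR; the commented request shape is approximate):

```python
from litellm.llms.bedrock.rerank.transformation import BedrockRerankConfig
from litellm.types.rerank import RerankRequest

request = RerankRequest(
    model="arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
    query="hello",
    documents=["hello", "world"],
    top_n=2,
)

bedrock_body = BedrockRerankConfig()._transform_request(request)
# bedrock_body ~ {
#     "queries": [{"textQuery": {"text": "hello"}, "type": "TEXT"}],
#     "rerankingConfiguration": {
#         "bedrockRerankingConfiguration": {
#             "modelConfiguration": {"modelArn": "arn:aws:bedrock:..."},
#             "numberOfResults": 2,
#         },
#         "type": "BEDROCK_RERANKING_MODEL",
#     },
#     "sources": [{"inlineDocumentSource": {...}, "type": "INLINE"}, ...],
# }
```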


@@ -31,6 +31,6 @@ class JinaAIRerankConfig:
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
-results=_results,
+results=_results,  # type: ignore
meta=rerank_meta,
)  # Return response


@@ -2485,10 +2485,24 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
image_url=image_url
)
_parts.append(_part)  # type: ignore
_cache_point_block = (
litellm.AmazonConverseConfig()._get_cache_point_block(
element, block_type="content_block"
)
)
if _cache_point_block is not None:
_parts.append(_cache_point_block)
user_content.extend(_parts)
else:
_part = BedrockContentBlock(text=messages[msg_i]["content"])
_cache_point_block = (
litellm.AmazonConverseConfig()._get_cache_point_block(
messages[msg_i], block_type="content_block"
)
)
user_content.append(_part)
if _cache_point_block is not None:
user_content.append(_cache_point_block)
msg_i += 1
if user_content:
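End to end, the cache-point plumbing above is what a prompt-caching completion call relies on; a hedged sketch (model name illustrative, mirroring the new `test_prompt_caching` base test added below):

```python
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",  # illustrative model
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Here is the full text of a complex legal agreement" * 400,
                    # translated into a Bedrock cachePoint block by the code above
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
        {"role": "user", "content": "What are the key terms and conditions in this agreement?"},
    ],
    max_tokens=10,
)
print(response.usage)
```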


@@ -6,8 +6,54 @@ Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.
Docs: https://docs.together.ai/reference/completions-1
"""
from typing import Optional
from litellm import get_model_info, verbose_logger
from ..OpenAI.chat.gpt_transformation import OpenAIGPTConfig
class TogetherAIConfig(OpenAIGPTConfig):
-pass
+def get_supported_openai_params(self, model: str) -> list:
"""
Only some together models support response_format / tool calling
Docs: https://docs.together.ai/docs/json-mode
"""
supports_function_calling: Optional[bool] = None
try:
model_info = get_model_info(model, custom_llm_provider="together_ai")
supports_function_calling = model_info.get(
"supports_function_calling", False
)
except Exception as e:
verbose_logger.debug(f"Error getting supported openai params: {e}")
pass
optional_params = super().get_supported_openai_params(model)
if supports_function_calling is not True:
verbose_logger.warning(
"Only some together models support function calling/response_format. Docs - https://docs.together.ai/docs/function-calling"
)
optional_params.remove("tools")
optional_params.remove("tool_choice")
optional_params.remove("function_call")
optional_params.remove("response_format")
return optional_params
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
mapped_openai_params = super().map_openai_params(
non_default_params, optional_params, model, drop_params
)
if "response_format" in mapped_openai_params and mapped_openai_params[
"response_format"
] == {"type": "text"}:
mapped_openai_params.pop("response_format")
return mapped_openai_params
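The effect of the new `get_supported_openai_params` override, mirroring the Together AI unit test added at the end of this commit (requires the local model cost map so `get_model_info` can resolve the model):

```python
import os

import litellm

os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

# Model with function calling support in the model map -> params kept
params = litellm.get_supported_openai_params(
    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", custom_llm_provider="together_ai"
)
assert "response_format" in params and "tools" in params

# Model without function calling support -> params stripped
params = litellm.get_supported_openai_params(
    "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", custom_llm_provider="together_ai"
)
assert "response_format" not in params and "tools" not in params
```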


@@ -29,6 +29,6 @@ class TogetherAIRerankConfig:
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
-results=_results,
+results=_results,  # type: ignore
meta=rerank_meta,
)  # Return response


@@ -12,7 +12,8 @@
"supports_vision": true,
"supports_audio_input": true,
"supports_audio_output": true,
-"supports_prompt_caching": true
+"supports_prompt_caching": true,
"supports_response_schema": true
},
"gpt-4": {
"max_tokens": 4096,
@@ -4818,7 +4819,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
-"supports_pdf_input": true
+"supports_pdf_input": true,
"supports_prompt_caching": true
},
"amazon.nova-lite-v1:0": {
"max_tokens": 4096,
@@ -4830,7 +4832,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
-"supports_pdf_input": true
+"supports_pdf_input": true,
"supports_prompt_caching": true
},
"amazon.nova-pro-v1:0": {
"max_tokens": 4096,
@@ -4842,7 +4845,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
-"supports_pdf_input": true
+"supports_pdf_input": true,
"supports_prompt_caching": true
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 4096,
@@ -4876,7 +4880,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
-"supports_assistant_prefill": true
+"supports_assistant_prefill": true,
"supports_prompt_caching": true
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"max_tokens": 4096,
@@ -4898,7 +4903,8 @@
"litellm_provider": "bedrock",
"mode": "chat",
"supports_assistant_prefill": true,
-"supports_function_calling": true
+"supports_function_calling": true,
"supports_prompt_caching": true
},
"anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096,
@@ -4911,139 +4917,6 @@
"supports_function_calling": true,
"supports_vision": true
},
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"us.anthropic.claude-3-5-sonnet-20241022-v2:0": {
"max_tokens": 8192,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_assistant_prefill": true
},
"us.anthropic.claude-3-haiku-20240307-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"us.anthropic.claude-3-5-haiku-20241022-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000005,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_function_calling": true
},
"us.anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"eu.anthropic.claude-3-5-sonnet-20241022-v2:0": {
"max_tokens": 8192,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_assistant_prefill": true
},
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"eu.anthropic.claude-3-5-haiku-20241022-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000005,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true
},
"eu.anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"anthropic.claude-v1": { "anthropic.claude-v1": {
"max_tokens": 8191, "max_tokens": 8191,
"max_input_tokens": 100000, "max_input_tokens": 100000,
@ -6097,6 +5970,30 @@
"litellm_provider": "together_ai", "litellm_provider": "together_ai",
"mode": "embedding" "mode": "embedding"
}, },
"together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {
"input_cost_per_token": 0.00000018,
"output_cost_per_token": 0.00000018,
"litellm_provider": "together_ai",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"mode": "chat"
},
"together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
"input_cost_per_token": 0.00000088,
"output_cost_per_token": 0.00000088,
"litellm_provider": "together_ai",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"mode": "chat"
},
"together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {
"input_cost_per_token": 0.0000035,
"output_cost_per_token": 0.0000035,
"litellm_provider": "together_ai",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"mode": "chat"
},
"together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1": { "together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1": {
"input_cost_per_token": 0.0000006, "input_cost_per_token": 0.0000006,
"output_cost_per_token": 0.0000006, "output_cost_per_token": 0.0000006,


@@ -39,10 +39,10 @@ model_list:
access_groups: ["private-openai-models"]
router_settings:
-routing_strategy: usage-based-routing-v2
+# routing_strategy: usage-based-routing-v2
#redis_url: "os.environ/REDIS_URL"
redis_host: "os.environ/REDIS_HOST"
redis_port: "os.environ/REDIS_PORT"
litellm_settings:
-success_callback: ["langsmith"]
+callbacks: ["datadog"]


@@ -7,6 +7,7 @@ import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
@@ -21,6 +22,7 @@ cohere_rerank = CohereRerank()
together_rerank = TogetherAIRerank()
azure_ai_rerank = AzureAIRerank()
jina_ai_rerank = JinaAIRerank()
bedrock_rerank = BedrockRerankHandler()
#################################################
@@ -70,7 +72,7 @@ async def arerank(
@client
-def rerank(
+def rerank(  # noqa: PLR0915
model: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
@@ -268,6 +270,27 @@ def rerank(
max_chunks_per_doc=max_chunks_per_doc,
_is_async=_is_async,
)
elif _custom_llm_provider == "bedrock":
api_base = (
dynamic_api_base
or optional_params.api_base
or litellm.api_base
or get_secret("BEDROCK_API_BASE") # type: ignore
)
response = bedrock_rerank.rerank(
model=model,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
_is_async=_is_async,
optional_params=optional_params.model_dump(exclude_unset=True),
api_base=api_base,
logging_obj=litellm_logging_obj,
)
else:
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
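And the async path through `arerank`, which this commit wires up alongside the sync handler (same illustrative model ARN as the Bedrock rerank test; `BEDROCK_API_BASE`/`api_base` stay optional when standard AWS credentials are available):

```python
import asyncio

import litellm


async def main():
    response = await litellm.arerank(
        model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
        query="hello",
        documents=["hello", "world"],
        top_n=2,
    )
    print(response)


asyncio.run(main())
```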


@@ -2,6 +2,7 @@ import json
from typing import Any, List, Literal, Optional, TypedDict, Union
from typing_extensions import (
TYPE_CHECKING,
Protocol,
Required,
Self,
@@ -14,8 +15,13 @@ from typing_extensions import (
from .openai import ChatCompletionToolCallChunk
class CachePointBlock(TypedDict, total=False):
type: Literal["default"]
-class SystemContentBlock(TypedDict):
+class SystemContentBlock(TypedDict, total=False):
text: str
cachePoint: CachePointBlock
class SourceBlock(TypedDict):
@@ -58,6 +64,7 @@ class ContentBlock(TypedDict, total=False):
document: DocumentBlock
toolResult: ToolResultBlock
toolUse: ToolUseBlock
cachePoint: CachePointBlock
class MessageBlock(TypedDict):
@@ -312,3 +319,71 @@ class AmazonStability3TextToImageResponse(TypedDict, total=False):
images: List[str]
seeds: List[str]
finish_reasons: List[str]
if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
else:
AWSPreparedRequest = Any
from pydantic import BaseModel
class BedrockPreparedRequest(TypedDict):
"""
Internal/Helper class for preparing the request for bedrock image generation
"""
endpoint_url: str
prepped: AWSPreparedRequest
body: bytes
data: dict
class BedrockRerankTextQuery(TypedDict):
text: str
class BedrockRerankQuery(TypedDict):
textQuery: BedrockRerankTextQuery
type: Literal["TEXT"]
class BedrockRerankModelConfiguration(TypedDict, total=False):
modelArn: Required[str]
modelConfiguration: dict
class BedrockRerankBedrockRerankingConfiguration(TypedDict):
modelConfiguration: BedrockRerankModelConfiguration
numberOfResults: int
class BedrockRerankConfiguration(TypedDict):
bedrockRerankingConfiguration: BedrockRerankBedrockRerankingConfiguration
type: Literal["BEDROCK_RERANKING_MODEL"]
class BedrockRerankTextDocument(TypedDict, total=False):
text: str
class BedrockRerankInlineDocumentSource(TypedDict, total=False):
jsonDocument: dict
textDocument: BedrockRerankTextDocument
type: Literal["TEXT", "JSON"]
class BedrockRerankSource(TypedDict):
inlineDocumentSource: BedrockRerankInlineDocumentSource
type: Literal["INLINE"]
class BedrockRerankRequest(TypedDict):
"""
Request for Bedrock Rerank API
"""
queries: List[BedrockRerankQuery]
rerankingConfiguration: BedrockRerankConfiguration
sources: List[BedrockRerankSource]
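These TypedDicts are plain dicts at runtime; a short sketch of how the rerank transformation assembles them (shape only, no API call):

```python
from litellm.types.llms.bedrock import (
    BedrockRerankInlineDocumentSource,
    BedrockRerankQuery,
    BedrockRerankSource,
    BedrockRerankTextDocument,
    BedrockRerankTextQuery,
)

# A text query and an inline text document source, as built in _transform_sources above.
query = BedrockRerankQuery(
    textQuery=BedrockRerankTextQuery(text="hello"),
    type="TEXT",
)
source = BedrockRerankSource(
    inlineDocumentSource=BedrockRerankInlineDocumentSource(
        textDocument=BedrockRerankTextDocument(text="hello world"),
        type="TEXT",
    ),
    type="INLINE",
)
print(query, source)
```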


@@ -36,9 +36,14 @@ class RerankResponseMeta(TypedDict, total=False):
tokens: RerankTokens
class RerankResponseResult(TypedDict):
index: int
relevance_score: float
class RerankResponse(BaseModel):
id: str
-results: List[dict]  # Contains index and relevance_score
+results: List[RerankResponseResult]  # Contains index and relevance_score
meta: Optional[RerankResponseMeta] = None  # Contains api_version and billed_units
# Define private attributes using PrivateAttr
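With the typed `results`, a `RerankResponse` now validates each entry as index + relevance_score (mirrors the updated logging test in this commit):

```python
from litellm.types.rerank import RerankResponse

response = RerankResponse(
    id="test",
    results=[{"index": 0, "relevance_score": 0.9}],
)
print(response.results[0]["relevance_score"])
```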


@@ -1874,22 +1874,11 @@ def supports_prompt_caching(
Raises:
Exception: If the given model is not found or there's an error in retrieval.
"""
-try:
-model, custom_llm_provider, _, _ = litellm.get_llm_provider(
-model=model, custom_llm_provider=custom_llm_provider
-)
-model_info = litellm.get_model_info(
-model=model, custom_llm_provider=custom_llm_provider
-)
-if model_info.get("supports_prompt_caching", False) is True:
-return True
-return False
-except Exception as e:
-raise Exception(
-f"Model not found or error in checking prompt caching support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
-)
+return _supports_factory(
+model=model,
+custom_llm_provider=custom_llm_provider,
+key="supports_prompt_caching",
+)
def supports_vision(model: str, custom_llm_provider: Optional[str] = None) -> bool:
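Usage is unchanged by the refactor; `supports_prompt_caching` still reads the flag out of the model cost map (the model name here is illustrative):

```python
import os

import litellm
from litellm.utils import supports_prompt_caching

os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

# True/False depending on the model map entry for this model.
print(
    supports_prompt_caching(
        model="anthropic.claude-3-5-sonnet-20240620-v1:0",
        custom_llm_provider="bedrock",
    )
)
```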


@@ -12,7 +12,8 @@
"supports_vision": true,
"supports_audio_input": true,
"supports_audio_output": true,
-"supports_prompt_caching": true
+"supports_prompt_caching": true,
"supports_response_schema": true
},
"gpt-4": {
"max_tokens": 4096,
@@ -4818,7 +4819,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
-"supports_pdf_input": true
+"supports_pdf_input": true,
"supports_prompt_caching": true
},
"amazon.nova-lite-v1:0": {
"max_tokens": 4096,
@@ -4830,7 +4832,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
-"supports_pdf_input": true
+"supports_pdf_input": true,
"supports_prompt_caching": true
},
"amazon.nova-pro-v1:0": {
"max_tokens": 4096,
@@ -4842,7 +4845,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
-"supports_pdf_input": true
+"supports_pdf_input": true,
"supports_prompt_caching": true
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 4096,
@@ -4876,7 +4880,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
-"supports_assistant_prefill": true
+"supports_assistant_prefill": true,
"supports_prompt_caching": true
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"max_tokens": 4096,
@@ -4898,7 +4903,8 @@
"litellm_provider": "bedrock",
"mode": "chat",
"supports_assistant_prefill": true,
-"supports_function_calling": true
+"supports_function_calling": true,
"supports_prompt_caching": true
},
"anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096,
@@ -4911,139 +4917,6 @@
"supports_function_calling": true,
"supports_vision": true
},
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"us.anthropic.claude-3-5-sonnet-20241022-v2:0": {
"max_tokens": 8192,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_assistant_prefill": true
},
"us.anthropic.claude-3-haiku-20240307-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"us.anthropic.claude-3-5-haiku-20241022-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000005,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_function_calling": true
},
"us.anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"eu.anthropic.claude-3-5-sonnet-20241022-v2:0": {
"max_tokens": 8192,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_assistant_prefill": true
},
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"eu.anthropic.claude-3-5-haiku-20241022-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000005,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true
},
"eu.anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"anthropic.claude-v1": { "anthropic.claude-v1": {
"max_tokens": 8191, "max_tokens": 8191,
"max_input_tokens": 100000, "max_input_tokens": 100000,
@ -6097,6 +5970,30 @@
"litellm_provider": "together_ai", "litellm_provider": "together_ai",
"mode": "embedding" "mode": "embedding"
}, },
"together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {
"input_cost_per_token": 0.00000018,
"output_cost_per_token": 0.00000018,
"litellm_provider": "together_ai",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"mode": "chat"
},
"together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
"input_cost_per_token": 0.00000088,
"output_cost_per_token": 0.00000088,
"litellm_provider": "together_ai",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"mode": "chat"
},
"together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {
"input_cost_per_token": 0.0000035,
"output_cost_per_token": 0.0000035,
"litellm_provider": "together_ai",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"mode": "chat"
},
"together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1": { "together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1": {
"input_cost_per_token": 0.0000006, "input_cost_per_token": 0.0000006,
"output_cost_per_token": 0.0000006, "output_cost_per_token": 0.0000006,


@@ -23,6 +23,34 @@ from litellm.utils import (
from abc import ABC, abstractmethod
def _usage_format_tests(usage: litellm.Usage):
"""
OpenAI prompt caching
- prompt_tokens = sum of non-cache hit tokens + cache-hit tokens
- total_tokens = prompt_tokens + completion_tokens
Example
```
"usage": {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"prompt_tokens_details": {
"cached_tokens": 1920
},
"completion_tokens_details": {
"reasoning_tokens": 0
}
# ANTHROPIC_ONLY #
"cache_creation_input_tokens": 0
}
```
"""
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
assert usage.prompt_tokens > usage.prompt_tokens_details.cached_tokens
class BaseLLMChatTest(ABC):
"""
Abstract base test class that enforces a common test across all test classes.
@@ -273,6 +301,78 @@ class BaseLLMChatTest(ABC):
response = litellm.completion(**base_completion_call_args, messages=messages)
assert response is not None
def test_prompt_caching(self):
litellm.set_verbose = True
from litellm.utils import supports_prompt_caching
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
base_completion_call_args = self.get_base_completion_call_args()
if not supports_prompt_caching(base_completion_call_args["model"], None):
print("Model does not support prompt caching")
pytest.skip("Model does not support prompt caching")
try:
for _ in range(2):
response = litellm.completion(
**base_completion_call_args,
messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement"
* 400,
"cache_control": {"type": "ephemeral"},
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
],
temperature=0.2,
max_tokens=10,
)
_usage_format_tests(response.usage)
print("response=", response)
print("response.usage=", response.usage)
_usage_format_tests(response.usage)
assert "prompt_tokens_details" in response.usage
assert response.usage.prompt_tokens_details.cached_tokens > 0
except litellm.InternalServerError:
pass
@pytest.fixture
def pdf_messages(self):
import base64


@@ -79,6 +79,7 @@ class BaseLLMRerankTest(ABC):
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank(self, sync_mode):
litellm.set_verbose = True
rerank_call_args = self.get_base_rerank_call_args()
custom_llm_provider = self.get_custom_llm_provider()
if sync_mode is True:
@@ -86,7 +87,7 @@ class BaseLLMRerankTest(ABC):
**rerank_call_args,
query="hello",
documents=["hello", "world"],
-top_n=3,
+top_n=2,
)
print("re rank response: ", response)
@@ -102,7 +103,7 @@ class BaseLLMRerankTest(ABC):
**rerank_call_args,
query="hello",
documents=["hello", "world"],
-top_n=3,
+top_n=2,
)
print("async re rank response: ", response)


@@ -666,7 +666,7 @@ from litellm import completion
class TestAnthropicCompletion(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
-return {"model": "claude-3-haiku-20240307"}
+return {"model": "anthropic/claude-3-5-sonnet-20240620"}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""


@@ -1,3 +1,7 @@
"""
Tests Bedrock Completion + Rerank endpoints
"""
# @pytest.mark.skip(reason="AWS Suspended Account")
import os
import sys
@@ -31,6 +35,7 @@ from litellm.llms.bedrock.chat import BedrockLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import _bedrock_tools_pt
from base_llm_unit_tests import BaseLLMChatTest
from base_rerank_unit_tests import BaseLLMRerankTest
# litellm.num_retries = 3
litellm.cache = None
@@ -1971,13 +1976,67 @@ def test_bedrock_base_model_helper():
assert model == "us.amazon.nova-pro-v1:0"
@pytest.mark.parametrize(
"messages, expected_cache_control",
[
(
[ # test system prompt cache
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an AI assistant tasked with analyzing legal documents.",
},
{
"type": "text",
"text": "Here is the full text of a complex legal agreement",
"cache_control": {"type": "ephemeral"},
},
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
],
True,
),
(
[ # test user prompt cache
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
},
],
True,
),
],
)
def test_bedrock_prompt_caching_message(messages, expected_cache_control):
import litellm
import json
transformed_messages = litellm.AmazonConverseConfig()._transform_request(
model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
messages=messages,
optional_params={},
litellm_params={},
)
if expected_cache_control:
assert "cachePoint" in json.dumps(transformed_messages)
else:
assert "cachePoint" not in json.dumps(transformed_messages)
class TestBedrockConverseChat(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
litellm.add_known_models()
return {
-"model": "bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
+"model": "bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
@@ -1991,3 +2050,19 @@ class TestBedrockConverseChat(BaseLLMChatTest):
Todo: if litellm.modify_params is True ensure it's a valid utf-8 sequence
"""
pass
def test_prompt_caching(self):
"""
Remove override once we have access to Bedrock prompt caching
"""
pass
class TestBedrockRerank(BaseLLMRerankTest):
def get_custom_llm_provider(self) -> litellm.LlmProviders:
return litellm.LlmProviders.BEDROCK
def get_base_rerank_call_args(self) -> dict:
return {
"model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
}


@@ -0,0 +1,58 @@
"""
Test TogetherAI LLM
"""
from base_llm_unit_tests import BaseLLMChatTest
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
import pytest
class TestTogetherAI(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
litellm.set_verbose = True
return {"model": "together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1"}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
pass
def test_multilingual_requests(self):
"""
Mistral API raises a 400 BadRequest error when the request contains invalid utf-8 sequences.
"""
pass
@pytest.mark.parametrize(
"model, expected_bool",
[
("meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", True),
("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", False),
],
)
def test_get_supported_response_format_together_ai(
self, model: str, expected_bool: bool
) -> None:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
optional_params = litellm.get_supported_openai_params(
model, custom_llm_provider="together_ai"
)
# Mapped provider
assert isinstance(optional_params, list)
if expected_bool:
assert "response_format" in optional_params
assert "tools" in optional_params
else:
assert "response_format" not in optional_params
assert "tools" not in optional_params


@@ -197,7 +197,7 @@ async def test_async_log_cache_hit_on_callbacks():
),
(
CallTypes.rerank.value,
-{"id": "test", "results": [{"index": 0, "score": 0.9}]},
+{"id": "test", "results": [{"index": 0, "relevance_score": 0.9}]},
RerankResponse,
),
(


@@ -38,76 +38,6 @@ def _usage_format_tests(usage: litellm.Usage):
assert usage.prompt_tokens > usage.prompt_tokens_details.cached_tokens
@pytest.mark.parametrize(
"model",
[
"anthropic/claude-3-5-sonnet-20240620",
# "openai/gpt-4o",
# "deepseek/deepseek-chat",
],
)
def test_prompt_caching_model(model):
try:
for _ in range(2):
response = litellm.completion(
model=model,
messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement"
* 400,
"cache_control": {"type": "ephemeral"},
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
],
temperature=0.2,
max_tokens=10,
)
_usage_format_tests(response.usage)
print("response=", response)
print("response.usage=", response.usage)
_usage_format_tests(response.usage)
assert "prompt_tokens_details" in response.usage
assert response.usage.prompt_tokens_details.cached_tokens > 0
except litellm.InternalServerError:
pass
def test_supports_prompt_caching():
from litellm.utils import supports_prompt_caching


@@ -185,3 +185,22 @@ async def test_log_db_metrics_failure_error_types(exception, should_log):
else:
# Assert failure was NOT logged for non-DB errors
mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_not_called()
@pytest.mark.asyncio
async def test_dd_log_db_spend_failure_metrics():
from litellm._service_logger import ServiceLogging
from litellm.integrations.datadog.datadog import DataDogLogger
dd_logger = DataDogLogger()
with patch.object(dd_logger, "async_service_failure_hook", new_callable=AsyncMock):
service_logging_obj = ServiceLogging()
litellm.service_callback = [dd_logger]
await service_logging_obj.async_service_failure_hook(
service=ServiceTypes.DB,
call_type="test_call_type",
error="test_error",
duration=1.0,
)