QA: ensure all bedrock regional models have the same `supports_*` values as the base model + Anthropic nested pydantic object support (#7844)

* build: ensure all regional bedrock models have same supported values as base bedrock model

prevents drift

* test(base_llm_unit_tests.py): add testing for nested pydantic objects

* fix(test_utils.py): add test_get_potential_model_names

* fix(anthropic/chat/transformation.py): support nested pydantic objects

Fixes https://github.com/BerriAI/litellm/issues/7755
Krish Dholakia, 2025-01-17 19:49:12 -08:00 (committed by GitHub)
parent 37ed49fe72
commit 6eb2346fd6
12 changed files with 259 additions and 62 deletions
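
For context, the fixed behavior can be exercised end to end with a nested pydantic response_format. A minimal sketch mirroring the unit test added below (the model name is illustrative; any model whose cost-map entry sets supports_response_schema works):

import litellm
from pydantic import BaseModel

class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]

class EventsList(BaseModel):
    events: list[CalendarEvent]  # nested pydantic model: the case from issue #7755

resp = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # illustrative model name
    messages=[{"role": "user", "content": "List 5 important events in the XIX century"}],
    response_format=EventsList,
)
print(resp.choices[0].message.content)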

View file

@@ -8,6 +8,7 @@ import litellm
 from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
+from litellm.llms.base_llm.base_utils import type_to_response_format_param
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.anthropic import (
     AllAnthropicToolsValues,
@@ -94,6 +95,13 @@ class AnthropicConfig(BaseConfig):
         "user",
     ]

+    def get_json_schema_from_pydantic_object(
+        self, response_format: Union[Any, Dict, None]
+    ) -> Optional[dict]:
+        return type_to_response_format_param(
+            response_format, ref_template="/$defs/{model}"
+        )  # Relevant issue: https://github.com/BerriAI/litellm/issues/7755
+
     def get_cache_control_headers(self) -> dict:
         return {
             "anthropic-version": "2023-06-01",

View file

@@ -1,5 +1,8 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Type, Union
+
+from openai.lib import _parsing, _pydantic
+from pydantic import BaseModel

 from litellm.types.utils import ModelInfoBase
@@ -26,3 +29,39 @@ class BaseLLMModelInfo(ABC):
     @abstractmethod
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         pass
+
+
+def type_to_response_format_param(
+    response_format: Optional[Union[Type[BaseModel], dict]],
+    ref_template: Optional[str] = None,
+) -> Optional[dict]:
+    """
+    Re-implementation of openai's 'type_to_response_format_param' function
+
+    Used for converting pydantic object to api schema.
+    """
+    if response_format is None:
+        return None
+
+    if isinstance(response_format, dict):
+        return response_format
+
+    # type checkers don't narrow the negation of a `TypeGuard` as it isn't
+    # a safe default behaviour but we know that at this point the `response_format`
+    # can only be a `type`
+    if not _parsing._completions.is_basemodel_type(response_format):
+        raise TypeError(f"Unsupported response_format type - {response_format}")
+
+    if ref_template is not None:
+        schema = response_format.model_json_schema(ref_template=ref_template)
+    else:
+        schema = _pydantic.to_strict_json_schema(response_format)
+
+    return {
+        "type": "json_schema",
+        "json_schema": {
+            "schema": schema,
+            "name": response_format.__name__,
+            "strict": True,
+        },
+    }
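
A minimal sketch of the helper's contract, following directly from the branches above:

from pydantic import BaseModel

from litellm.llms.base_llm.base_utils import type_to_response_format_param

class Step(BaseModel):
    explanation: str

param = type_to_response_format_param(response_format=Step)
assert param is not None
assert param["type"] == "json_schema"
assert param["json_schema"]["name"] == "Step"
assert param["json_schema"]["strict"] is True

# dicts pass through unchanged, and None stays None
assert type_to_response_format_param({"type": "json_object"}) == {"type": "json_object"}
assert type_to_response_format_param(None) is None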

View file

@@ -4,13 +4,25 @@ Common base config for all LLM providers
 import types
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Iterator,
+    List,
+    Optional,
+    Type,
+    Union,
+)

 import httpx
+from pydantic import BaseModel

 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse

+from ..base_utils import type_to_response_format_param
+
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -71,6 +83,11 @@ class BaseConfig(ABC):
             and v is not None
         }

+    def get_json_schema_from_pydantic_object(
+        self, response_format: Optional[Union[Type[BaseModel], dict]]
+    ) -> Optional[dict]:
+        return type_to_response_format_param(response_format=response_format)
+
     def should_fake_stream(
         self,
         model: Optional[str],
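
Most providers inherit this default; AnthropicConfig (the first file in this diff) overrides the hook to pass the ref_template. A sketch of the override in action:

from pydantic import BaseModel

from litellm.llms.anthropic.chat.transformation import AnthropicConfig

class Inner(BaseModel):
    x: int

class Outer(BaseModel):
    inner: Inner

out = AnthropicConfig().get_json_schema_from_pydantic_object(Outer)
# Nested refs rendered with '/$defs/{model}' rather than pydantic's default
# '#/$defs/{model}' - see issue #7755 referenced in the first hunk.
print(out["json_schema"]["schema"]["properties"]["inner"])  # {'$ref': '/$defs/Inner'}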

View file

@@ -31,7 +31,14 @@ from litellm.types.llms.openai import (
 from litellm.types.utils import ModelResponse, Usage
 from litellm.utils import CustomStreamWrapper, add_dummy_tool, has_tool_call_blocks

-from ..common_utils import BedrockError, get_bedrock_tool_name
+from ..common_utils import (
+    AmazonBedrockGlobalConfig,
+    BedrockError,
+    get_bedrock_tool_name,
+)
+
+global_config = AmazonBedrockGlobalConfig()
+all_global_regions = global_config.get_all_regions()

 class AmazonConverseConfig:
@@ -573,13 +580,24 @@ class AmazonConverseConfig:
         Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
        AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
         """

         if model.startswith("bedrock/"):
-            model = model.split("/")[1]
+            model = model.split("/", 1)[1]

         if model.startswith("converse/"):
-            model = model.split("/")[1]
+            model = model.split("/", 1)[1]

         potential_region = model.split(".", 1)[0]
+
+        alt_potential_region = model.split("/", 1)[
+            0
+        ]  # in model cost map we store regional information like `/us-west-2/bedrock-model`
+
         if potential_region in self._supported_cross_region_inference_region():
             return model.split(".", 1)[1]
+        elif (
+            alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
+        ):
+            return model.split("/", 1)[1]

         return model
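
A sketch of how _get_base_model now resolves both prefix styles (expected outputs inferred from the code above and the new drift test, not re-verified against a live cost map):

from litellm import AmazonConverseConfig

cfg = AmazonConverseConfig()

# cross-region inference prefix, matched on the "." separator
print(cfg._get_base_model("us.anthropic.claude-3-haiku-20240307-v1:0"))
# expected -> "anthropic.claude-3-haiku-20240307-v1:0"

# cost-map regional prefix, matched on the "/" separator via all_global_regions
print(cfg._get_base_model("us-west-2/mistral.mistral-large-2402-v1:0"))
# expected -> "mistral.mistral-large-2402-v1:0"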

View file

@@ -42,16 +42,35 @@ class AmazonBedrockGlobalConfig:
                 optional_params[mapped_params[param]] = value
         return optional_params

+    def get_all_regions(self) -> List[str]:
+        return (
+            self.get_us_regions()
+            + self.get_eu_regions()
+            + self.get_ap_regions()
+            + self.get_ca_regions()
+            + self.get_sa_regions()
+        )
+
+    def get_ap_regions(self) -> List[str]:
+        return ["ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-south-1"]
+
+    def get_sa_regions(self) -> List[str]:
+        return ["sa-east-1"]
+
     def get_eu_regions(self) -> List[str]:
         """
         Source: https://www.aws-services.info/bedrock.html
         """
         return [
             "eu-west-1",
             "eu-west-2",
+            "eu-west-3",
             "eu-central-1",
         ]

+    def get_ca_regions(self) -> List[str]:
+        return ["ca-central-1"]
+
     def get_us_regions(self) -> List[str]:
         """
         Source: https://www.aws-services.info/bedrock.html
@@ -59,6 +78,7 @@ class AmazonBedrockGlobalConfig:
         return [
             "us-east-2",
             "us-east-1",
+            "us-west-1",
             "us-west-2",
             "us-gov-west-1",
         ]
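
These helpers are plain region lists; get_all_regions concatenates them and feeds the all_global_regions lookup in the converse transformation above. A trivial usage sketch:

from litellm.llms.bedrock.common_utils import AmazonBedrockGlobalConfig

regions = AmazonBedrockGlobalConfig().get_all_regions()
assert "us-west-1" in regions     # added in this diff
assert "ca-central-1" in regions  # from the new get_ca_regions()
assert "sa-east-1" in regions     # from the new get_sa_regions()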

View file

@@ -5364,7 +5364,8 @@
         "input_cost_per_token": 0.000008,
         "output_cost_per_token": 0.000024,
         "litellm_provider": "bedrock",
-        "mode": "chat"
+        "mode": "chat",
+        "supports_function_calling": true
     },
     "bedrock/us-west-2/mistral.mistral-large-2402-v1:0": {
         "max_tokens": 8191,
@@ -5456,7 +5457,8 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_assistant_prefill": true,
-        "supports_prompt_caching": true
+        "supports_prompt_caching": true,
+        "supports_response_schema": true
     },
     "anthropic.claude-3-haiku-20240307-v1:0": {
         "max_tokens": 4096,
@@ -5524,7 +5526,9 @@
         "mode": "chat",
         "supports_function_calling": true,
         "supports_vision": true,
-        "supports_assistant_prefill": true
+        "supports_assistant_prefill": true,
+        "supports_prompt_caching": true,
+        "supports_response_schema": true
     },
     "us.anthropic.claude-3-haiku-20240307-v1:0": {
         "max_tokens": 4096,
@@ -5546,7 +5550,8 @@
         "litellm_provider": "bedrock",
         "mode": "chat",
         "supports_assistant_prefill": true,
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_prompt_caching": true
     },
     "us.anthropic.claude-3-opus-20240229-v1:0": {
         "max_tokens": 4096,
@@ -5591,7 +5596,9 @@
         "mode": "chat",
         "supports_function_calling": true,
         "supports_vision": true,
-        "supports_assistant_prefill": true
+        "supports_assistant_prefill": true,
+        "supports_prompt_caching": true,
+        "supports_response_schema": true
     },
     "eu.anthropic.claude-3-haiku-20240307-v1:0": {
         "max_tokens": 4096,
@@ -5612,7 +5619,10 @@
         "output_cost_per_token": 0.000005,
         "litellm_provider": "bedrock",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_assistant_prefill": true,
+        "supports_prompt_caching": true,
+        "supports_response_schema": true
     },
     "eu.anthropic.claude-3-opus-20240229-v1:0": {
         "max_tokens": 4096,

View file

@@ -1,5 +1,5 @@
 model_list:
-  - model_name: embedding-small
+  - model_name: bedrock/us.anthropic.claude-3-haiku-20240307-v1:0
     litellm_params:
-      model: openai/text-embedding-3-small
+      model: bedrock/us.anthropic.claude-3-haiku-20240307-v1:0

View file

@@ -178,7 +178,10 @@ from openai import OpenAIError as OriginalError
 from litellm.llms.base_llm.audio_transcription.transformation import (
     BaseAudioTranscriptionConfig,
 )
-from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
+from litellm.llms.base_llm.base_utils import (
+    BaseLLMModelInfo,
+    type_to_response_format_param,
+)
 from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
@@ -1474,7 +1477,7 @@ def create_pretrained_tokenizer(
     try:
         tokenizer = Tokenizer.from_pretrained(
-            identifier, revision=revision, auth_token=auth_token
+            identifier, revision=revision, auth_token=auth_token  # type: ignore
        )
     except Exception as e:
         verbose_logger.error(
@@ -2773,11 +2776,26 @@ def get_optional_params(  # noqa: PLR0915
             message=f"Function calling is not supported by {custom_llm_provider}.",
         )

-    if "response_format" in non_default_params:
-        non_default_params["response_format"] = type_to_response_format_param(
-            response_format=non_default_params["response_format"]
-        )
+    provider_config: Optional[BaseConfig] = None
+    if custom_llm_provider is not None and custom_llm_provider in [
+        provider.value for provider in LlmProviders
+    ]:
+        provider_config = ProviderConfigManager.get_provider_chat_config(
+            model=model, provider=LlmProviders(custom_llm_provider)
+        )
+
+    if "response_format" in non_default_params:
+        if provider_config is not None:
+            non_default_params["response_format"] = (
+                provider_config.get_json_schema_from_pydantic_object(
+                    response_format=non_default_params["response_format"]
+                )
+            )
+        else:
+            non_default_params["response_format"] = type_to_response_format_param(
+                response_format=non_default_params["response_format"]
+            )

     if "tools" in non_default_params and isinstance(
         non_default_params, list
     ):  # fixes https://github.com/BerriAI/litellm/issues/4933
@@ -2835,13 +2853,6 @@
             message=f"{custom_llm_provider} does not support parameters: {unsupported_params}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
         )

-    provider_config: Optional[BaseConfig] = None
-    if custom_llm_provider is not None and custom_llm_provider in [
-        provider.value for provider in LlmProviders
-    ]:
-        provider_config = ProviderConfigManager.get_provider_chat_config(
-            model=model, provider=LlmProviders(custom_llm_provider)
-        )
-
     supported_params = get_supported_openai_params(
         model=model, custom_llm_provider=custom_llm_provider
     )
@@ -4964,36 +4975,6 @@ def _should_retry(status_code: int):
     return False

-def type_to_response_format_param(
-    response_format: Optional[Union[Type[BaseModel], dict]],
-) -> Optional[dict]:
-    """
-    Re-implementation of openai's 'type_to_response_format_param' function
-
-    Used for converting pydantic object to api schema.
-    """
-    if response_format is None:
-        return None
-
-    if isinstance(response_format, dict):
-        return response_format
-
-    # type checkers don't narrow the negation of a `TypeGuard` as it isn't
-    # a safe default behaviour but we know that at this point the `response_format`
-    # can only be a `type`
-    if not _parsing._completions.is_basemodel_type(response_format):
-        raise TypeError(f"Unsupported response_format type - {response_format}")
-
-    return {
-        "type": "json_schema",
-        "json_schema": {
-            "schema": _pydantic.to_strict_json_schema(response_format),
-            "name": response_format.__name__,
-            "strict": True,
-        },
-    }
-
 def _get_retry_after_from_exception_header(
     response_headers: Optional[httpx.Headers] = None,
 ):
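
Net effect of the utils.py changes: the provider_config lookup now runs before response_format handling, so pydantic-to-schema conversion goes through the provider's own get_json_schema_from_pydantic_object hook when a config exists (giving Anthropic its ref_template behavior) and falls back to the generic helper otherwise. A sketch of the dispatch, with import paths assumed from this diff:

from litellm import LlmProviders  # assumed re-export; the enum the diff references
from litellm.utils import ProviderConfigManager

provider_config = ProviderConfigManager.get_provider_chat_config(
    model="claude-3-5-sonnet-20240620", provider=LlmProviders("anthropic")
)
# Expected to be AnthropicConfig, whose override adds ref_template="/$defs/{model}".
print(type(provider_config).__name__)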

View file

@@ -5364,7 +5364,8 @@
         "input_cost_per_token": 0.000008,
         "output_cost_per_token": 0.000024,
         "litellm_provider": "bedrock",
-        "mode": "chat"
+        "mode": "chat",
+        "supports_function_calling": true
     },
     "bedrock/us-west-2/mistral.mistral-large-2402-v1:0": {
         "max_tokens": 8191,
@@ -5456,7 +5457,8 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_assistant_prefill": true,
-        "supports_prompt_caching": true
+        "supports_prompt_caching": true,
+        "supports_response_schema": true
     },
     "anthropic.claude-3-haiku-20240307-v1:0": {
         "max_tokens": 4096,
@@ -5524,7 +5526,9 @@
         "mode": "chat",
         "supports_function_calling": true,
         "supports_vision": true,
-        "supports_assistant_prefill": true
+        "supports_assistant_prefill": true,
+        "supports_prompt_caching": true,
+        "supports_response_schema": true
     },
     "us.anthropic.claude-3-haiku-20240307-v1:0": {
         "max_tokens": 4096,
@@ -5546,7 +5550,8 @@
         "litellm_provider": "bedrock",
         "mode": "chat",
         "supports_assistant_prefill": true,
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_prompt_caching": true
     },
     "us.anthropic.claude-3-opus-20240229-v1:0": {
         "max_tokens": 4096,
@@ -5591,7 +5596,9 @@
         "mode": "chat",
         "supports_function_calling": true,
         "supports_vision": true,
-        "supports_assistant_prefill": true
+        "supports_assistant_prefill": true,
+        "supports_prompt_caching": true,
+        "supports_response_schema": true
     },
     "eu.anthropic.claude-3-haiku-20240307-v1:0": {
         "max_tokens": 4096,
@@ -5612,7 +5619,10 @@
         "output_cost_per_token": 0.000005,
         "litellm_provider": "bedrock",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_assistant_prefill": true,
+        "supports_prompt_caching": true,
+        "supports_response_schema": true
     },
     "eu.anthropic.claude-3-opus-20240229-v1:0": {
         "max_tokens": 4096,

View file

@@ -259,6 +259,59 @@ class BaseLLMChatTest(ABC):
         except litellm.InternalServerError:
             pytest.skip("Model is overloaded")

+    @pytest.mark.flaky(retries=6, delay=1)
+    def test_json_response_pydantic_obj_nested_obj(self):
+        litellm.set_verbose = True
+        from pydantic import BaseModel
+
+        from litellm.utils import supports_response_schema
+
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+
+    @pytest.mark.flaky(retries=6, delay=1)
+    def test_json_response_nested_pydantic_obj(self):
+        from pydantic import BaseModel
+
+        from litellm.utils import supports_response_schema
+
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+
+        class CalendarEvent(BaseModel):
+            name: str
+            date: str
+            participants: list[str]
+
+        class EventsList(BaseModel):
+            events: list[CalendarEvent]
+
+        messages = [
+            {"role": "user", "content": "List 5 important events in the XIX century"}
+        ]
+
+        base_completion_call_args = self.get_base_completion_call_args()
+        if not supports_response_schema(base_completion_call_args["model"], None):
+            pytest.skip(
+                f"Model={base_completion_call_args['model']} does not support response schema"
+            )
+
+        try:
+            res = self.completion_function(
+                **base_completion_call_args,
+                messages=messages,
+                response_format=EventsList,
+                timeout=60,
+            )
+            assert res is not None
+
+            print(res.choices[0].message)
+
+            assert res.choices[0].message.content is not None
+            assert res.choices[0].message.tool_calls is None
+        except litellm.Timeout:
+            pytest.skip("Model took too long to respond")
+        except litellm.InternalServerError:
+            pytest.skip("Model is overloaded")
+
     @pytest.mark.flaky(retries=6, delay=1)
     def test_json_response_format_stream(self):
         """

View file

@@ -307,3 +307,35 @@ def test_get_model_info_custom_model_router():
     info = get_model_info("openai/meta-llama/Meta-Llama-3-8B-Instruct")
     print("info", info)
     assert info is not None
+
+
+def test_get_model_info_bedrock_models():
+    """
+    Check for drift in base model info for bedrock models and regional model info for bedrock models.
+    """
+    from litellm import AmazonConverseConfig
+
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+    litellm.model_cost = litellm.get_model_cost_map(url="")
+
+    for k, v in litellm.model_cost.items():
+        if v["litellm_provider"] == "bedrock":
+            k = k.replace("*/", "")
+            potential_commitments = [
+                "1-month-commitment",
+                "3-month-commitment",
+                "6-month-commitment",
+            ]
+            if any(commitment in k for commitment in potential_commitments):
+                for commitment in potential_commitments:
+                    k = k.replace(f"{commitment}/", "")
+            base_model = AmazonConverseConfig()._get_base_model(k)
+            base_model_info = litellm.model_cost[base_model]
+            for base_model_key, base_model_value in base_model_info.items():
+                if base_model_key.startswith("supports_"):
+                    assert (
+                        base_model_key in v
+                    ), f"{base_model_key} is not in model cost map for {k}"
+                    assert (
+                        v[base_model_key] == base_model_value
+                    ), f"{base_model_key} is not equal to {base_model_value} for model {k}"

View file

@@ -1471,3 +1471,12 @@ def test_pick_cheapest_chat_model_from_llm_provider():
     assert len(pick_cheapest_chat_models_from_llm_provider("openai", n=3)) == 3
     assert len(pick_cheapest_chat_models_from_llm_provider("unknown", n=1)) == 0
+
+
+def test_get_potential_model_names():
+    from litellm.utils import _get_potential_model_names
+
+    assert _get_potential_model_names(
+        model="bedrock/ap-northeast-1/anthropic.claude-instant-v1",
+        custom_llm_provider="bedrock",
+    )