mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
QA: ensure all bedrock regional models have same supported_
as base + Anthropic nested pydantic object support (#7844)
* build: ensure all regional bedrock models have same supported values as base bedrock model prevents drift * test(base_llm_unit_tests.py): add testing for nested pydantic objects * fix(test_utils.py): add test_get_potential_model_names * fix(anthropic/chat/transformation.py): support nested pydantic objects Fixes https://github.com/BerriAI/litellm/issues/7755
This commit is contained in:
parent
37ed49fe72
commit
6eb2346fd6
12 changed files with 259 additions and 62 deletions
|
@ -8,6 +8,7 @@ import litellm
|
|||
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
|
||||
from litellm.llms.base_llm.base_utils import type_to_response_format_param
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.anthropic import (
|
||||
AllAnthropicToolsValues,
|
||||
|
@ -94,6 +95,13 @@ class AnthropicConfig(BaseConfig):
|
|||
"user",
|
||||
]
|
||||
|
||||
def get_json_schema_from_pydantic_object(
|
||||
self, response_format: Union[Any, Dict, None]
|
||||
) -> Optional[dict]:
|
||||
return type_to_response_format_param(
|
||||
response_format, ref_template="/$defs/{model}"
|
||||
) # Relevant issue: https://github.com/BerriAI/litellm/issues/7755
|
||||
|
||||
def get_cache_control_headers(self) -> dict:
|
||||
return {
|
||||
"anthropic-version": "2023-06-01",
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
from openai.lib import _parsing, _pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.types.utils import ModelInfoBase
|
||||
|
||||
|
@ -26,3 +29,39 @@ class BaseLLMModelInfo(ABC):
|
|||
@abstractmethod
|
||||
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
|
||||
pass
|
||||
|
||||
|
||||
def type_to_response_format_param(
|
||||
response_format: Optional[Union[Type[BaseModel], dict]],
|
||||
ref_template: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Re-implementation of openai's 'type_to_response_format_param' function
|
||||
|
||||
Used for converting pydantic object to api schema.
|
||||
"""
|
||||
if response_format is None:
|
||||
return None
|
||||
|
||||
if isinstance(response_format, dict):
|
||||
return response_format
|
||||
|
||||
# type checkers don't narrow the negation of a `TypeGuard` as it isn't
|
||||
# a safe default behaviour but we know that at this point the `response_format`
|
||||
# can only be a `type`
|
||||
if not _parsing._completions.is_basemodel_type(response_format):
|
||||
raise TypeError(f"Unsupported response_format type - {response_format}")
|
||||
|
||||
if ref_template is not None:
|
||||
schema = response_format.model_json_schema(ref_template=ref_template)
|
||||
else:
|
||||
schema = _pydantic.to_strict_json_schema(response_format)
|
||||
|
||||
return {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"schema": schema,
|
||||
"name": response_format.__name__,
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -4,13 +4,25 @@ Common base config for all LLM providers
|
|||
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ..base_utils import type_to_response_format_param
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
|
@ -71,6 +83,11 @@ class BaseConfig(ABC):
|
|||
and v is not None
|
||||
}
|
||||
|
||||
def get_json_schema_from_pydantic_object(
|
||||
self, response_format: Optional[Union[Type[BaseModel], dict]]
|
||||
) -> Optional[dict]:
|
||||
return type_to_response_format_param(response_format=response_format)
|
||||
|
||||
def should_fake_stream(
|
||||
self,
|
||||
model: Optional[str],
|
||||
|
|
|
@ -31,7 +31,14 @@ from litellm.types.llms.openai import (
|
|||
from litellm.types.utils import ModelResponse, Usage
|
||||
from litellm.utils import CustomStreamWrapper, add_dummy_tool, has_tool_call_blocks
|
||||
|
||||
from ..common_utils import BedrockError, get_bedrock_tool_name
|
||||
from ..common_utils import (
|
||||
AmazonBedrockGlobalConfig,
|
||||
BedrockError,
|
||||
get_bedrock_tool_name,
|
||||
)
|
||||
|
||||
global_config = AmazonBedrockGlobalConfig()
|
||||
all_global_regions = global_config.get_all_regions()
|
||||
|
||||
|
||||
class AmazonConverseConfig:
|
||||
|
@ -573,13 +580,24 @@ class AmazonConverseConfig:
|
|||
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||
"""
|
||||
|
||||
if model.startswith("bedrock/"):
|
||||
model = model.split("/")[1]
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
if model.startswith("converse/"):
|
||||
model = model.split("/")[1]
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
potential_region = model.split(".", 1)[0]
|
||||
|
||||
alt_potential_region = model.split("/", 1)[
|
||||
0
|
||||
] # in model cost map we store regional information like `/us-west-2/bedrock-model`
|
||||
|
||||
if potential_region in self._supported_cross_region_inference_region():
|
||||
return model.split(".", 1)[1]
|
||||
elif (
|
||||
alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
|
||||
):
|
||||
return model.split("/", 1)[1]
|
||||
|
||||
return model
|
||||
|
|
|
@ -42,16 +42,35 @@ class AmazonBedrockGlobalConfig:
|
|||
optional_params[mapped_params[param]] = value
|
||||
return optional_params
|
||||
|
||||
def get_all_regions(self) -> List[str]:
|
||||
return (
|
||||
self.get_us_regions()
|
||||
+ self.get_eu_regions()
|
||||
+ self.get_ap_regions()
|
||||
+ self.get_ca_regions()
|
||||
+ self.get_sa_regions()
|
||||
)
|
||||
|
||||
def get_ap_regions(self) -> List[str]:
|
||||
return ["ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-south-1"]
|
||||
|
||||
def get_sa_regions(self) -> List[str]:
|
||||
return ["sa-east-1"]
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://www.aws-services.info/bedrock.html
|
||||
"""
|
||||
return [
|
||||
"eu-west-1",
|
||||
"eu-west-2",
|
||||
"eu-west-3",
|
||||
"eu-central-1",
|
||||
]
|
||||
|
||||
def get_ca_regions(self) -> List[str]:
|
||||
return ["ca-central-1"]
|
||||
|
||||
def get_us_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://www.aws-services.info/bedrock.html
|
||||
|
@ -59,6 +78,7 @@ class AmazonBedrockGlobalConfig:
|
|||
return [
|
||||
"us-east-2",
|
||||
"us-east-1",
|
||||
"us-west-1",
|
||||
"us-west-2",
|
||||
"us-gov-west-1",
|
||||
]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue