mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Fix calling claude via invoke route + response_format support for claude on invoke route (#8908)

* fix(anthropic_claude3_transformation.py): fix amazon anthropic claude 3 tool calling transformation on invoke route; move to using anthropic config as base
* fix(utils.py): expose anthropic config via ProviderConfigManager
* fix(llm_http_handler.py): support json mode on async completion calls
* fix(invoke_handler/make_call): support json mode for anthropic called via bedrock invoke
* fix(anthropic/): handle 'response_format: {"type": "text"}' + migrate amazon claude 3 invoke config to inherit from anthropic config; prevents an error when passing in 'response_format: {"type": "text"}'
* test: fix test
* fix(utils.py): fix base invoke provider check
* fix(anthropic_claude3_transformation.py): don't pass 'stream' param
* fix: fix linting errors
* fix(converse_transformation.py): handle response_format type=text for converse

This commit is contained in:
parent 8f86959c32
commit a65bfab697

18 changed files with 444 additions and 139 deletions
@@ -53,6 +53,7 @@ from litellm.constants import (
     cohere_embedding_models,
     bedrock_embedding_models,
     known_tokenizer_config,
+    BEDROCK_INVOKE_PROVIDERS_LITERAL,
 )
 from litellm.types.guardrails import GuardrailItem
 from litellm.proxy._types import (

@@ -361,17 +362,7 @@ BEDROCK_CONVERSE_MODELS = [
     "meta.llama3-2-11b-instruct-v1:0",
     "meta.llama3-2-90b-instruct-v1:0",
 ]
-BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
-    "cohere",
-    "anthropic",
-    "mistral",
-    "amazon",
-    "meta",
-    "llama",
-    "ai21",
-    "nova",
-    "deepseek_r1",
-]
+
 ####### COMPLETION MODELS ###################
 open_ai_chat_completion_models: List = []
 open_ai_text_completion_models: List = []

@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Literal

 ROUTER_MAX_FALLBACKS = 5
 DEFAULT_BATCH_SIZE = 512

@@ -320,6 +320,17 @@ baseten_models: List = [
     "31dxrj3",
 ] # FALCON 7B # WizardLM # Mosaic ML

+BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
+    "cohere",
+    "anthropic",
+    "mistral",
+    "amazon",
+    "meta",
+    "llama",
+    "ai21",
+    "nova",
+    "deepseek_r1",
+]
+
 open_ai_embedding_models: List = ["text-embedding-ada-002"]
 cohere_embedding_models: List = [

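With the constant moved here, constants.py becomes the single definition that both litellm/__init__.py and the AWS handlers import. A minimal sketch (illustrative; validate_provider is a hypothetical helper, not litellm code) of the Literal-as-allow-list pattern this enables:

```python
# typing.Literal doubles as a runtime allow-list via get_args(), so the
# membership check can never drift from the type annotation.
from typing import Literal, Optional, cast, get_args

BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
    "cohere", "anthropic", "mistral", "amazon",
    "meta", "llama", "ai21", "nova", "deepseek_r1",
]

def validate_provider(name: str) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
    # Hypothetical helper for illustration only.
    if name in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
        return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, name)
    return None

assert validate_provider("anthropic") == "anthropic"
assert validate_provider("openai") is None
```
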
@@ -308,7 +308,6 @@ class AnthropicConfig(BaseConfig):
         model: str,
         drop_params: bool,
     ) -> dict:
-
         for param, value in non_default_params.items():
             if param == "max_tokens":
                 optional_params["max_tokens"] = value

@@ -342,6 +341,10 @@
                 optional_params["top_p"] = value
             if param == "response_format" and isinstance(value, dict):
+
+                ignore_response_format_types = ["text"]
+                if value["type"] in ignore_response_format_types:  # value is a no-op
+                    continue

                 json_schema: Optional[dict] = None
                 if "response_schema" in value:
                     json_schema = value["response_schema"]

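Since {"type": "text"} is the OpenAI default and maps to no Anthropic parameter, treating it as a no-op avoids injecting the JSON-mode tool machinery. A sketch of the resulting behavior, assuming the import path from this diff:

```python
from litellm.llms.anthropic.chat.transformation import AnthropicConfig

mapped = AnthropicConfig().map_openai_params(
    non_default_params={"response_format": {"type": "text"}},
    optional_params={},
    model="claude-3-5-sonnet-20240620",
    drop_params=False,
)
# Plain text mode should add no forced tool call or schema.
assert "tools" not in mapped and "tool_choice" not in mapped
```

The test_response_format_type_text case later in this diff asserts exactly this.
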
@@ -317,6 +317,7 @@ class BaseConfig(ABC):
         data: dict,
         messages: list,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         raise NotImplementedError

@@ -330,6 +331,7 @@
         data: dict,
         messages: list,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         raise NotImplementedError

@@ -2,13 +2,14 @@ import hashlib
 import json
 import os
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast, get_args

 import httpx
 from pydantic import BaseModel

 from litellm._logging import verbose_logger
 from litellm.caching.caching import DualCache
+from litellm.constants import BEDROCK_INVOKE_PROVIDERS_LITERAL
 from litellm.litellm_core_utils.dd_tracing import tracer
 from litellm.secret_managers.main import get_secret

@@ -223,6 +224,60 @@ class BaseAWSLLM:
         # Catch any unexpected errors and return None
         return None

+    @staticmethod
+    def _get_provider_from_model_path(
+        model_path: str,
+    ) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the provider from a model path with format: provider/model-name
+
+        Args:
+            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
+
+        Returns:
+            Optional[str]: The provider name, or None if no valid provider found
+        """
+        parts = model_path.split("/")
+        if len(parts) >= 1:
+            provider = parts[0]
+            if provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
+                return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
+        return None
+
+    @staticmethod
+    def get_bedrock_invoke_provider(
+        model: str,
+    ) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the bedrock provider from the model
+
+        handles 4 scenarios:
+        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
+        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
+        """
+        if model.startswith("invoke/"):
+            model = model.replace("invoke/", "", 1)
+
+        _split_model = model.split(".")[0]
+        if _split_model in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
+
+        # If not a known provider, check for pattern with two slashes
+        provider = BaseAWSLLM._get_provider_from_model_path(model)
+        if provider is not None:
+            return provider
+
+        # check if provider == "nova"
+        if "nova" in model:
+            return "nova"
+        else:
+            for provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
+                if provider in model:
+                    return provider
+        return None
+
     def _get_aws_region_name(
         self, optional_params: dict, model: Optional[str] = None
     ) -> str:

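Resolution order matters here: strip an explicit invoke/ prefix, try the dotted model id, then the provider/model-path form, and finally the substring fallbacks (nova before the generic loop). Illustrative checks covering the docstring's four scenarios (the module path is an assumption and may differ by litellm version):

```python
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM  # path assumed

resolve = BaseAWSLLM.get_bedrock_invoke_provider
assert resolve("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
assert resolve("anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
assert (
    resolve("llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n")
    == "llama"
)
assert resolve("us.amazon.nova-pro-v1:0") == "nova"
```
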
@@ -206,7 +206,12 @@ class AmazonConverseConfig(BaseConfig):
         messages: Optional[List[AllMessageValues]] = None,
     ) -> dict:
         for param, value in non_default_params.items():
-            if param == "response_format":
+            if param == "response_format" and isinstance(value, dict):
+
+                ignore_response_format_types = ["text"]
+                if value["type"] in ignore_response_format_types:  # value is a no-op
+                    continue
+
                 json_schema: Optional[dict] = None
                 schema_name: str = ""
                 if "response_schema" in value:

@@ -226,6 +226,7 @@ async def make_call(
         decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
             model=model,
             sync_stream=False,
+            json_mode=json_mode,
         )
         completion_stream = decoder.aiter_bytes(
             response.aiter_bytes(chunk_size=1024)

@@ -311,6 +312,7 @@ def make_sync_call(
         decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
             model=model,
             sync_stream=True,
+            json_mode=json_mode,
         )
         completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
     elif bedrock_invoke_provider == "deepseek_r1":

@@ -1149,27 +1151,6 @@ class BedrockLLM(BaseAWSLLM):
         )
         return streaming_response

-    @staticmethod
-    def get_bedrock_invoke_provider(
-        model: str,
-    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
-        """
-        Helper function to get the bedrock provider from the model
-
-        handles 2 scenarios:
-        1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
-        2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
-        """
-        _split_model = model.split(".")[0]
-        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
-            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
-
-        # If not a known provider, check for pattern with two slashes
-        provider = BedrockLLM._get_provider_from_model_path(model)
-        if provider is not None:
-            return provider
-        return None
-
     @staticmethod
     def _get_provider_from_model_path(
         model_path: str,

@@ -1524,6 +1505,7 @@ class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
         self,
         model: str,
         sync_stream: bool,
+        json_mode: Optional[bool] = None,
     ) -> None:
         """
         Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models

@@ -1534,6 +1516,7 @@
         self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
             streaming_response=None,
             sync_stream=sync_stream,
+            json_mode=json_mode,
        )

     def _chunk_parser(self, chunk_data: dict) -> ModelResponseStream:

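Passing json_mode down to the decoder matters because, in JSON mode, Anthropic streams the structured output as tool-call argument deltas (input_json_delta events), which the iterator must surface as ordinary content. A simplified stand-in for that conversion (not litellm's actual code):

```python
def convert_chunk(chunk: dict, json_mode: bool) -> dict:
    """Map a raw Anthropic stream delta to an OpenAI-style delta (sketch)."""
    delta = chunk.get("delta", {})
    if json_mode and delta.get("type") == "input_json_delta":
        # The model is "calling" the JSON tool; re-emit its argument
        # fragments as assistant content instead of a tool call.
        return {"content": delta.get("partial_json", ""), "tool_calls": None}
    return {"content": delta.get("text", ""), "tool_calls": None}

print(convert_chunk(
    {"delta": {"type": "input_json_delta", "partial_json": '{"answer": 4'}},
    json_mode=True,
))  # {'content': '{"answer": 4', 'tool_calls': None}
```
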
@@ -3,8 +3,10 @@ from typing import Optional

 import litellm

+from .base_invoke_transformation import AmazonInvokeConfig
+
-class AmazonAnthropicConfig:
+class AmazonAnthropicConfig(AmazonInvokeConfig):
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

@@ -57,9 +59,7 @@ class AmazonAnthropicConfig:
             and v is not None
         }

-    def get_supported_openai_params(
-        self,
-    ):
+    def get_supported_openai_params(self, model: str):
         return [
             "max_tokens",
             "max_completion_tokens",

@@ -69,7 +69,13 @@
             "stream",
         ]

-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ):
         for param, value in non_default_params.items():
             if param == "max_tokens" or param == "max_completion_tokens":
                 optional_params["max_tokens_to_sample"] = value

@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, List, Optional

 import httpx

-import litellm
+from litellm.llms.anthropic.chat.transformation import AnthropicConfig
 from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
     AmazonInvokeConfig,
 )

@@ -17,7 +17,7 @@ else:
     LiteLLMLoggingObj = Any


-class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
+class AmazonAnthropicClaude3Config(AmazonInvokeConfig, AnthropicConfig):
     """
     Reference:
         https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

@@ -28,18 +28,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):

     anthropic_version: str = "bedrock-2023-05-31"

-    def get_supported_openai_params(self, model: str):
-        return [
-            "max_tokens",
-            "max_completion_tokens",
-            "tools",
-            "tool_choice",
-            "stream",
-            "stop",
-            "temperature",
-            "top_p",
-            "extra_headers",
-        ]
+    def get_supported_openai_params(self, model: str) -> List[str]:
+        return AnthropicConfig.get_supported_openai_params(self, model)

     def map_openai_params(
         self,

@@ -47,21 +37,14 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         optional_params: dict,
         model: str,
         drop_params: bool,
-    ):
-        for param, value in non_default_params.items():
-            if param == "max_tokens" or param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            if param == "tools":
-                optional_params["tools"] = value
-            if param == "stream":
-                optional_params["stream"] = value
-            if param == "stop":
-                optional_params["stop_sequences"] = value
-            if param == "temperature":
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-        return optional_params
+    ) -> dict:
+        return AnthropicConfig.map_openai_params(
+            self,
+            non_default_params,
+            optional_params,
+            model,
+            drop_params,
+        )

     def transform_request(
         self,

@@ -71,7 +54,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         litellm_params: dict,
         headers: dict,
     ) -> dict:
-        _anthropic_request = litellm.AnthropicConfig().transform_request(
+        _anthropic_request = AnthropicConfig.transform_request(
+            self,
             model=model,
             messages=messages,
             optional_params=optional_params,

@@ -80,6 +64,7 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         )

         _anthropic_request.pop("model", None)
+        _anthropic_request.pop("stream", None)
         if "anthropic_version" not in _anthropic_request:
             _anthropic_request["anthropic_version"] = self.anthropic_version

@@ -99,7 +84,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        return litellm.AnthropicConfig().transform_response(
+        return AnthropicConfig.transform_response(
+            self,
             model=model,
             raw_response=raw_response,
             model_response=model_response,

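The unbound AnthropicConfig.method(self, ...) calls above are deliberate: with two bases, AmazonInvokeConfig wins the MRO, so naming AnthropicConfig explicitly pins which implementation runs. A toy sketch of the pattern (classes here are illustrative, not litellm's):

```python
class InvokeBase:
    def transform_request(self, data: dict) -> dict:
        return {"route": "invoke", **data}

class AnthropicBase:
    def transform_request(self, data: dict) -> dict:
        return {"format": "anthropic", **data}

class Claude3Config(InvokeBase, AnthropicBase):
    def transform_request(self, data: dict) -> dict:
        # Plain super() would resolve to InvokeBase; we want the Anthropic
        # body shape, so name the base explicitly and pass self by hand.
        request = AnthropicBase.transform_request(self, data)
        request.pop("stream", None)  # the invoke body must not carry 'stream'
        return request

print(Claude3Config().transform_request({"stream": True}))
# {'format': 'anthropic'}
```
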
@@ -3,7 +3,7 @@ import json
 import time
 import urllib.parse
 from functools import partial
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

 import httpx

@@ -461,6 +461,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         data: dict,
         messages: list,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         streaming_response = CustomStreamWrapper(
             completion_stream=None,

@@ -475,6 +476,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
                 bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
+                json_mode=json_mode,
             ),
             model=model,
             custom_llm_provider="bedrock",

@@ -493,6 +495,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         data: dict,
         messages: list,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         if client is None or isinstance(client, AsyncHTTPHandler):
             client = _get_httpx_client(params={})

@@ -509,6 +512,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
                 bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
+                json_mode=json_mode,
             ),
             model=model,
             custom_llm_provider="bedrock",

@@ -527,56 +531,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         """
         return False

-    @staticmethod
-    def get_bedrock_invoke_provider(
-        model: str,
-    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
-        """
-        Helper function to get the bedrock provider from the model
-
-        handles 4 scenarios:
-        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
-        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
-        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
-        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
-        """
-        if model.startswith("invoke/"):
-            model = model.replace("invoke/", "", 1)
-
-        _split_model = model.split(".")[0]
-        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
-            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
-
-        # If not a known provider, check for pattern with two slashes
-        provider = AmazonInvokeConfig._get_provider_from_model_path(model)
-        if provider is not None:
-            return provider
-
-        # check if provider == "nova"
-        if "nova" in model:
-            return "nova"
-        return None
-
-    @staticmethod
-    def _get_provider_from_model_path(
-        model_path: str,
-    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
-        """
-        Helper function to get the provider from a model path with format: provider/model-name
-
-        Args:
-            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
-
-        Returns:
-            Optional[str]: The provider name, or None if no valid provider found
-        """
-        parts = model_path.split("/")
-        if len(parts) >= 1:
-            provider = parts[0]
-            if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
-                return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
-        return None
-
     def get_bedrock_model_id(
         self,
         optional_params: dict,

@@ -159,6 +159,7 @@ class BaseLLMHTTPHandler:
         encoding: Any,
         api_key: Optional[str] = None,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: bool = False,
     ):
         if client is None:
             async_httpx_client = get_async_httpx_client(

@@ -190,6 +191,7 @@ class BaseLLMHTTPHandler:
             optional_params=optional_params,
             litellm_params=litellm_params,
             encoding=encoding,
+            json_mode=json_mode,
         )

     def completion(

@@ -211,6 +213,7 @@ class BaseLLMHTTPHandler:
         headers: Optional[dict] = {},
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
+        json_mode: bool = optional_params.pop("json_mode", False)

         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)

@@ -286,6 +289,7 @@ class BaseLLMHTTPHandler:
                     else None
                 ),
                 litellm_params=litellm_params,
+                json_mode=json_mode,
             )

         else:

@@ -309,6 +313,7 @@ class BaseLLMHTTPHandler:
                     if client is not None and isinstance(client, AsyncHTTPHandler)
                     else None
                 ),
+                json_mode=json_mode,
             )

         if stream is True:

@@ -327,6 +332,7 @@ class BaseLLMHTTPHandler:
                     data=data,
                     messages=messages,
                     client=client,
+                    json_mode=json_mode,
                 )
             completion_stream, headers = self.make_sync_call(
                 provider_config=provider_config,

@@ -380,6 +386,7 @@ class BaseLLMHTTPHandler:
             optional_params=optional_params,
             litellm_params=litellm_params,
             encoding=encoding,
+            json_mode=json_mode,
         )

     def make_sync_call(

@@ -453,6 +460,7 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ):
         if provider_config.has_custom_stream_wrapper is True:
             return provider_config.get_async_custom_stream_wrapper(

@@ -464,6 +472,7 @@ class BaseLLMHTTPHandler:
                 data=data,
                 messages=messages,
                 client=client,
+                json_mode=json_mode,
             )

         completion_stream, _response_headers = await self.make_async_call_stream_helper(

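All of these hunks do one thing: pop json_mode out of optional_params once at the top of completion() and forward it explicitly through every sync, async, and streaming path. A toy sketch of that pop-and-forward pattern (signatures simplified, not litellm's):

```python
from typing import Any, Dict

def make_call(request_body: dict, json_mode: bool = False) -> dict:
    # Downstream consumers (stream decoder, response transformer) get the
    # flag as an explicit kwarg instead of rediscovering it from params.
    return {"body": request_body, "json_mode": json_mode}

def completion(optional_params: Dict[str, Any]) -> dict:
    # Pop once so the flag never leaks into the provider request body.
    json_mode: bool = optional_params.pop("json_mode", False)
    return make_call(request_body=optional_params, json_mode=json_mode)

print(completion({"max_tokens": 1024, "json_mode": True}))
# {'body': {'max_tokens': 1024}, 'json_mode': True}
```
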
File diff suppressed because one or more lines are too long

@@ -16,7 +16,10 @@ model_list:
       api_key: os.environ/COHERE_API_KEY
   - model_name: bedrock-claude-3-7
     litellm_params:
-      model: bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0
+      model: bedrock/invoke/us.anthropic.claude-3-7-sonnet-20250219-v1:0
+  - model_name: bedrock-claude-3-5-sonnet
+    litellm_params:
+      model: bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0

 litellm_settings:
   callbacks: ["langfuse"]

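With the invoke/ segment in the model string, requests for bedrock-claude-3-7 now go through the Bedrock Invoke API instead of Converse. An illustrative call against a proxy serving this config (base_url and api_key are placeholders):

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

response = client.chat.completions.create(
    model="bedrock-claude-3-7",  # routed to bedrock/invoke/us.anthropic.claude-3-7-...
    messages=[{"role": "user", "content": "hi"}],
    response_format={"type": "text"},  # now a no-op rather than an error
)
print(response.choices[0].message.content)
```
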
@@ -3240,6 +3240,7 @@ def get_optional_params( # noqa: PLR0915
         )
     elif custom_llm_provider == "bedrock":
         bedrock_route = BedrockModelInfo.get_bedrock_route(model)
+        bedrock_base_model = BedrockModelInfo.get_base_model(model)
         if bedrock_route == "converse" or bedrock_route == "converse_like":
             optional_params = litellm.AmazonConverseConfig().map_openai_params(
                 model=model,

@@ -3253,8 +3254,9 @@
                 messages=messages,
             )

-        elif "anthropic" in model and bedrock_route == "invoke":
-            if model.startswith("anthropic.claude-3"):
+        elif "anthropic" in bedrock_base_model and bedrock_route == "invoke":
+            if bedrock_base_model.startswith("anthropic.claude-3"):
+
                 optional_params = (
                     litellm.AmazonAnthropicClaude3Config().map_openai_params(
                         non_default_params=non_default_params,

@@ -3267,10 +3269,17 @@
                         ),
                     )
                 )

            else:
                optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                    non_default_params=non_default_params,
                    optional_params=optional_params,
+                   model=model,
+                   drop_params=(
+                       drop_params
+                       if drop_params is not None and isinstance(drop_params, bool)
+                       else False
+                   ),
                )
     elif provider_config is not None:
         optional_params = provider_config.map_openai_params(

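Checking against bedrock_base_model (rather than the raw model string) is what lets invoke/us.anthropic.claude-3-... reach the Claude 3 config. A trimmed sketch of the effect (the full test at the end of this diff passes many more kwargs):

```python
from litellm.utils import get_optional_params

optional_params = get_optional_params(
    model="invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
    custom_llm_provider="bedrock",
    max_completion_tokens=1024,
    response_format={"type": "text"},
)
# Claude 3 on the invoke route maps to the modern parameter name,
assert optional_params["max_tokens"] == 1024
# not the legacy claude-2 one, and {"type": "text"} is silently dropped.
assert "max_tokens_to_sample" not in optional_params
```
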
@@ -6158,13 +6167,19 @@
         elif litellm.LlmProviders.BEDROCK == provider:
             bedrock_route = BedrockModelInfo.get_bedrock_route(model)
             bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(
-                model
+                model=model
             )
+            base_model = BedrockModelInfo.get_base_model(model)

             if bedrock_route == "converse" or bedrock_route == "converse_like":
                 return litellm.AmazonConverseConfig()
             elif bedrock_invoke_provider == "amazon":  # amazon titan llms
                 return litellm.AmazonTitanConfig()
+            elif bedrock_invoke_provider == "anthropic":
+                if base_model.startswith("anthropic.claude-3"):
+                    return litellm.AmazonAnthropicClaude3Config()
+                else:
+                    return litellm.AmazonAnthropicConfig()
             elif (
                 bedrock_invoke_provider == "meta" or bedrock_invoke_provider == "llama"
             ):  # amazon / meta llms

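An illustrative resolution check mirroring this branching (call shape follows the tests earlier in the diff):

```python
import litellm
from litellm import LlmProviders
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
    LlmProviders.BEDROCK,
)
assert isinstance(config, litellm.AmazonAnthropicClaude3Config)
```
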
@@ -18,6 +18,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     get_supported_openai_params,
     get_optional_params,
+    ProviderConfigManager,
 )
 from typing import Union

@@ -247,12 +248,45 @@
             response_format=response_format,
         )

-        print(response)
+        print(f"response={response}")

         # OpenAI guarantees that the JSON schema is returned in the content
         # relevant issue: https://github.com/BerriAI/litellm/issues/6741
         assert response.choices[0].message.content is not None

+    def test_response_format_type_text(self):
+        """
+        Test that the response format type text does not lead to tool calls
+        """
+        from litellm import LlmProviders
+
+        base_completion_call_args = self.get_base_completion_call_args()
+        litellm.set_verbose = True
+
+        _, provider, _, _ = litellm.get_llm_provider(
+            model=base_completion_call_args["model"]
+        )
+
+        provider_config = ProviderConfigManager.get_provider_chat_config(
+            base_completion_call_args["model"], LlmProviders(provider)
+        )
+
+        print(f"provider_config={provider_config}")
+
+        translated_params = provider_config.map_openai_params(
+            non_default_params={"response_format": {"type": "text"}},
+            optional_params={},
+            model=base_completion_call_args["model"],
+            drop_params=False,
+        )
+
+        assert "tool_choice" not in translated_params
+        assert (
+            "tools" not in translated_params
+        ), f"Got tools={translated_params['tools']}, expected no tools"
+
+        print(f"translated_params={translated_params}")
+
     @pytest.mark.flaky(retries=6, delay=1)
     def test_json_response_pydantic_obj(self):
         litellm.set_verbose = True

@@ -22,17 +22,9 @@ class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass

-    @pytest.fixture(autouse=True)
-    def skip_non_json_tests(self, request):
-        if not "json" in request.function.__name__.lower():
-            pytest.skip(
-                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
-            )
-

 class TestBedrockInvokeNovaJson(BaseLLMChatTest):
     def get_base_completion_call_args(self) -> dict:
+        litellm._turn_on_debug()
         return {
             "model": "bedrock/invoke/us.amazon.nova-micro-v1:0",
         }

@@ -319,12 +319,15 @@ def test_all_model_configs():
     ) == {"max_tokens": 10}

     assert (
-        "max_completion_tokens" in AmazonAnthropicConfig().get_supported_openai_params()
+        "max_completion_tokens"
+        in AmazonAnthropicConfig().get_supported_openai_params(model="")
     )

     assert AmazonAnthropicConfig().map_openai_params(
         non_default_params={"max_completion_tokens": 10},
         optional_params={},
+        model="",
+        drop_params=False,
     ) == {"max_tokens_to_sample": 10}

     from litellm.llms.databricks.chat.handler import DatabricksConfig

@@ -1114,3 +1114,257 @@ def test_anthropic_thinking_param(model, expected_thinking):
         assert "thinking" in optional_params
     else:
         assert "thinking" not in optional_params
+
+
+def test_bedrock_invoke_anthropic_max_tokens():
+    passed_params = {
+        "model": "invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "functions": None,
+        "function_call": None,
+        "temperature": 0.8,
+        "top_p": None,
+        "n": 1,
+        "stream": False,
+        "stream_options": None,
+        "stop": None,
+        "max_tokens": None,
+        "max_completion_tokens": 1024,
+        "modalities": None,
+        "prediction": None,
+        "audio": None,
+        "presence_penalty": None,
+        "frequency_penalty": None,
+        "logit_bias": None,
+        "user": None,
+        "custom_llm_provider": "bedrock",
+        "response_format": {"type": "text"},
+        "seed": None,
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "generate_plan",
+                    "description": "Generate a plan to execute the task using only the tools outlined in your context.",
+                    "input_schema": {
+                        "type": "object",
+                        "properties": {
+                            "steps": {
+                                "type": "array",
+                                "items": {
+                                    "type": "object",
+                                    "properties": {
+                                        "type": {
+                                            "type": "string",
+                                            "description": "The type of step to execute",
+                                        },
+                                        "tool_name": {
+                                            "type": "string",
+                                            "description": "The name of the tool to use for this step",
+                                        },
+                                        "tool_input": {
+                                            "type": "object",
+                                            "description": "The input to pass to the tool. Make sure this complies with the schema for the tool.",
+                                        },
+                                        "tool_output": {
+                                            "type": "object",
+                                            "description": "(Optional) The output from the tool if needed for future steps. Make sure this complies with the schema for the tool.",
+                                        },
+                                    },
+                                    "required": ["type"],
+                                },
+                            }
+                        },
+                    },
+                },
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "generate_wire_tool",
+                    "description": "Create a wire transfer with complete wire instructions",
+                    "input_schema": {
+                        "type": "object",
+                        "properties": {
+                            "company_id": {
+                                "type": "integer",
+                                "description": "The ID of the company receiving the investment",
+                            },
+                            "investment_id": {
+                                "type": "integer",
+                                "description": "The ID of the investment memo",
+                            },
+                            "dollar_amount": {
+                                "type": "number",
+                                "description": "The amount to wire in USD",
+                            },
+                            "wiring_instructions": {
+                                "type": "object",
+                                "description": "Complete bank account and routing information for the wire",
+                                "properties": {
+                                    "account_name": {
+                                        "type": "string",
+                                        "description": "Name on the bank account",
+                                    },
+                                    "address_1": {
+                                        "type": "string",
+                                        "description": "Primary address line",
+                                    },
+                                    "address_2": {
+                                        "type": "string",
+                                        "description": "Secondary address line (optional)",
+                                    },
+                                    "city": {"type": "string"},
+                                    "state": {"type": "string"},
+                                    "zip": {"type": "string"},
+                                    "country": {"type": "string", "default": "US"},
+                                    "bank_name": {"type": "string"},
+                                    "account_number": {"type": "string"},
+                                    "routing_number": {"type": "string"},
+                                    "account_type": {
+                                        "type": "string",
+                                        "enum": ["checking", "savings"],
+                                        "default": "checking",
+                                    },
+                                    "swift_code": {
+                                        "type": "string",
+                                        "description": "Required for international wires",
+                                    },
+                                    "iban": {
+                                        "type": "string",
+                                        "description": "Required for some international wires",
+                                    },
+                                    "bank_city": {"type": "string"},
+                                    "bank_state": {"type": "string"},
+                                    "bank_country": {"type": "string", "default": "US"},
+                                    "bank_to_bank_instructions": {
+                                        "type": "string",
+                                        "description": "Additional instructions for the bank (optional)",
+                                    },
+                                    "intermediary_bank_name": {
+                                        "type": "string",
+                                        "description": "Name of intermediary bank if required (optional)",
+                                    },
+                                },
+                                "required": [
+                                    "account_name",
+                                    "address_1",
+                                    "country",
+                                    "bank_name",
+                                    "account_number",
+                                    "routing_number",
+                                    "account_type",
+                                    "bank_country",
+                                ],
+                            },
+                        },
+                        "required": [
+                            "company_id",
+                            "investment_id",
+                            "dollar_amount",
+                            "wiring_instructions",
+                        ],
+                    },
+                },
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "search_companies",
+                    "description": "Search for companies by name or other criteria to get their IDs",
+                    "input_schema": {
+                        "type": "object",
+                        "properties": {
+                            "query": {
+                                "type": "string",
+                                "description": "Name or part of name to search for",
+                            },
+                            "batch": {
+                                "type": "string",
+                                "description": 'Optional batch filter (e.g., "W21", "S22")',
+                            },
+                            "status": {
+                                "type": "string",
+                                "enum": [
+                                    "live",
+                                    "dead",
+                                    "adrift",
+                                    "exited",
+                                    "went_public",
+                                    "all",
+                                ],
+                                "description": "Filter by company status",
+                                "default": "live",
+                            },
+                            "limit": {
+                                "type": "integer",
+                                "description": "Maximum number of results to return",
+                                "default": 10,
+                            },
+                        },
+                        "required": ["query"],
+                    },
+                    "output_schema": {
+                        "type": "object",
+                        "properties": {
+                            "status": {
+                                "type": "string",
+                                "description": "Success or error status",
+                            },
+                            "results": {
+                                "type": "array",
+                                "description": "List of companies matching the search criteria",
+                                "items": {
+                                    "type": "object",
+                                    "properties": {
+                                        "id": {
+                                            "type": "integer",
+                                            "description": "Company ID to use in other API calls",
+                                        },
+                                        "name": {"type": "string"},
+                                        "batch": {"type": "string"},
+                                        "status": {"type": "string"},
+                                        "valuation": {"type": "string"},
+                                        "url": {"type": "string"},
+                                        "description": {"type": "string"},
+                                        "founders": {"type": "string"},
+                                    },
+                                },
+                            },
+                            "results_count": {
+                                "type": "integer",
+                                "description": "Number of companies returned",
+                            },
+                            "total_matches": {
+                                "type": "integer",
+                                "description": "Total number of matches found",
+                            },
+                        },
+                    },
+                },
+            },
+        ],
+        "tool_choice": None,
+        "max_retries": 0,
+        "logprobs": None,
+        "top_logprobs": None,
+        "extra_headers": None,
+        "api_version": None,
+        "parallel_tool_calls": None,
+        "drop_params": True,
+        "reasoning_effort": None,
+        "additional_drop_params": None,
+        "messages": [
+            {
+                "role": "system",
+                "content": "You are an AI assistant that helps prepare a wire for a pro rata investment.",
+            },
+            {"role": "user", "content": [{"type": "text", "text": "hi"}]},
+        ],
+        "thinking": None,
+        "kwargs": {},
+    }
+    optional_params = get_optional_params(**passed_params)
+    print(f"optional_params: {optional_params}")
+
+    assert "max_tokens_to_sample" not in optional_params
+    assert optional_params["max_tokens"] == 1024