Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
Fix calling claude via invoke route + response_format support for claude on invoke route (#8908)
* fix(anthropic_claude3_transformation.py): fix Amazon Anthropic Claude 3 tool-calling transformation on the invoke route; move to using the Anthropic config as the base
* fix(utils.py): expose the Anthropic config via ProviderConfigManager
* fix(llm_http_handler.py): support JSON mode on async completion calls
* fix(invoke_handler/make_call): support JSON mode for Anthropic called via Bedrock invoke
* fix(anthropic/): handle `response_format: {"type": "text"}` and migrate the Amazon Claude 3 invoke config to inherit from the Anthropic config. Prevents an error when passing in `response_format: {"type": "text"}`
* test: fix test
* fix(utils.py): fix base invoke provider check
* fix(anthropic_claude3_transformation.py): don't pass the 'stream' param
* fix: fix linting errors
* fix(converse_transformation.py): handle `response_format` type=text for converse
Commit a65bfab697 (parent 8f86959c32)
18 changed files with 444 additions and 139 deletions
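In user-facing terms, the fix means an invoke-route Claude call can now take an OpenAI-style response_format without erroring. A minimal sketch of the call pattern this commit targets (model ID is illustrative; assumes AWS credentials and Bedrock access are configured):

import litellm

# response_format={"type": "text"} previously raised on the invoke route;
# it is now treated as a no-op. {"type": "json_object"} turns on JSON mode.
response = litellm.completion(
    model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "hi"}],
    response_format={"type": "text"},
)
print(response.choices[0].message.content)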
@@ -53,6 +53,7 @@ from litellm.constants import (
     cohere_embedding_models,
     bedrock_embedding_models,
     known_tokenizer_config,
+    BEDROCK_INVOKE_PROVIDERS_LITERAL,
 )
 from litellm.types.guardrails import GuardrailItem
 from litellm.proxy._types import (
@@ -361,17 +362,7 @@ BEDROCK_CONVERSE_MODELS = [
     "meta.llama3-2-11b-instruct-v1:0",
     "meta.llama3-2-90b-instruct-v1:0",
 ]
-BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
-    "cohere",
-    "anthropic",
-    "mistral",
-    "amazon",
-    "meta",
-    "llama",
-    "ai21",
-    "nova",
-    "deepseek_r1",
-]
+
 ####### COMPLETION MODELS ###################
 open_ai_chat_completion_models: List = []
 open_ai_text_completion_models: List = []
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Literal

 ROUTER_MAX_FALLBACKS = 5
 DEFAULT_BATCH_SIZE = 512
@@ -320,6 +320,17 @@ baseten_models: List = [
     "31dxrj3",
 ]  # FALCON 7B # WizardLM # Mosaic ML

+BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
+    "cohere",
+    "anthropic",
+    "mistral",
+    "amazon",
+    "meta",
+    "llama",
+    "ai21",
+    "nova",
+    "deepseek_r1",
+]

 open_ai_embedding_models: List = ["text-embedding-ada-002"]
 cohere_embedding_models: List = [
@@ -308,7 +308,6 @@ class AnthropicConfig(BaseConfig):
         model: str,
         drop_params: bool,
     ) -> dict:
-
         for param, value in non_default_params.items():
             if param == "max_tokens":
                 optional_params["max_tokens"] = value
@@ -342,6 +341,10 @@ class AnthropicConfig(BaseConfig):
                 optional_params["top_p"] = value
             if param == "response_format" and isinstance(value, dict):
+
+                ignore_response_format_types = ["text"]
+                if value["type"] in ignore_response_format_types:  # value is a no-op
+                    continue

                 json_schema: Optional[dict] = None
                 if "response_schema" in value:
                     json_schema = value["response_schema"]
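The net effect: mapping response_format={"type": "text"} through AnthropicConfig no longer injects a JSON-mode tool call. A minimal sketch of the expected mapping, mirroring the test added further below:

from litellm.llms.anthropic.chat.transformation import AnthropicConfig

translated = AnthropicConfig().map_openai_params(
    non_default_params={"response_format": {"type": "text"}},
    optional_params={},
    model="claude-3-5-sonnet-20240620",
    drop_params=False,
)
# type=text is a no-op, so no tools/tool_choice should be emitted
assert "tools" not in translated and "tool_choice" not in translated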
@@ -317,6 +317,7 @@ class BaseConfig(ABC):
         data: dict,
         messages: list,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         raise NotImplementedError

@@ -330,6 +331,7 @@ class BaseConfig(ABC):
         data: dict,
         messages: list,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         raise NotImplementedError

@@ -2,13 +2,14 @@ import hashlib
 import json
 import os
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast, get_args

 import httpx
 from pydantic import BaseModel

 from litellm._logging import verbose_logger
 from litellm.caching.caching import DualCache
+from litellm.constants import BEDROCK_INVOKE_PROVIDERS_LITERAL
 from litellm.litellm_core_utils.dd_tracing import tracer
 from litellm.secret_managers.main import get_secret

@@ -223,6 +224,60 @@ class BaseAWSLLM:
             # Catch any unexpected errors and return None
             return None

+    @staticmethod
+    def _get_provider_from_model_path(
+        model_path: str,
+    ) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the provider from a model path with format: provider/model-name
+
+        Args:
+            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
+
+        Returns:
+            Optional[str]: The provider name, or None if no valid provider found
+        """
+        parts = model_path.split("/")
+        if len(parts) >= 1:
+            provider = parts[0]
+            if provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
+                return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
+        return None
+
+    @staticmethod
+    def get_bedrock_invoke_provider(
+        model: str,
+    ) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the bedrock provider from the model
+
+        handles 4 scenarios:
+        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
+        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
+        """
+        if model.startswith("invoke/"):
+            model = model.replace("invoke/", "", 1)
+
+        _split_model = model.split(".")[0]
+        if _split_model in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
+
+        # If not a known provider, check for pattern with two slashes
+        provider = BaseAWSLLM._get_provider_from_model_path(model)
+        if provider is not None:
+            return provider
+
+        # check if provider == "nova"
+        if "nova" in model:
+            return "nova"
+        else:
+            for provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
+                if provider in model:
+                    return provider
+        return None
+
     def _get_aws_region_name(
         self, optional_params: dict, model: Optional[str] = None
     ) -> str:
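The four docstring scenarios translate directly into the following checks (a sketch; the import path for BaseAWSLLM is an assumption based on this file's location):

from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM  # assumed path

get_provider = BaseAWSLLM.get_bedrock_invoke_provider
assert get_provider("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
assert get_provider("anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
assert get_provider("llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n") == "llama"
assert get_provider("us.amazon.nova-pro-v1:0") == "nova"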
@@ -206,7 +206,12 @@ class AmazonConverseConfig(BaseConfig):
         messages: Optional[List[AllMessageValues]] = None,
     ) -> dict:
         for param, value in non_default_params.items():
-            if param == "response_format":
+            if param == "response_format" and isinstance(value, dict):
+
+                ignore_response_format_types = ["text"]
+                if value["type"] in ignore_response_format_types:  # value is a no-op
+                    continue
+
                 json_schema: Optional[dict] = None
                 schema_name: str = ""
                 if "response_schema" in value:
@@ -226,6 +226,7 @@ async def make_call(
         decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
             model=model,
             sync_stream=False,
+            json_mode=json_mode,
         )
         completion_stream = decoder.aiter_bytes(
             response.aiter_bytes(chunk_size=1024)
@@ -311,6 +312,7 @@ def make_sync_call(
         decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
             model=model,
             sync_stream=True,
+            json_mode=json_mode,
         )
         completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
     elif bedrock_invoke_provider == "deepseek_r1":
@@ -1149,27 +1151,6 @@ class BedrockLLM(BaseAWSLLM):
         )
         return streaming_response

-    @staticmethod
-    def get_bedrock_invoke_provider(
-        model: str,
-    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
-        """
-        Helper function to get the bedrock provider from the model
-
-        handles 2 scenarions:
-        1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
-        2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
-        """
-        _split_model = model.split(".")[0]
-        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
-            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
-
-        # If not a known provider, check for pattern with two slashes
-        provider = BedrockLLM._get_provider_from_model_path(model)
-        if provider is not None:
-            return provider
-        return None
-
     @staticmethod
     def _get_provider_from_model_path(
         model_path: str,
@@ -1524,6 +1505,7 @@ class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
         self,
         model: str,
         sync_stream: bool,
+        json_mode: Optional[bool] = None,
     ) -> None:
         """
         Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models
@@ -1534,6 +1516,7 @@ class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
         self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
             streaming_response=None,
             sync_stream=sync_stream,
+            json_mode=json_mode,
         )

     def _chunk_parser(self, chunk_data: dict) -> ModelResponseStream:
@@ -3,8 +3,10 @@ from typing import Optional

 import litellm

+from .base_invoke_transformation import AmazonInvokeConfig
+
-class AmazonAnthropicConfig:
+class AmazonAnthropicConfig(AmazonInvokeConfig):
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

@@ -57,9 +59,7 @@ class AmazonAnthropicConfig:
             and v is not None
         }

-    def get_supported_openai_params(
-        self,
-    ):
+    def get_supported_openai_params(self, model: str):
         return [
             "max_tokens",
             "max_completion_tokens",
@@ -69,7 +69,13 @@ class AmazonAnthropicConfig:
             "stream",
         ]

-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ):
         for param, value in non_default_params.items():
             if param == "max_tokens" or param == "max_completion_tokens":
                 optional_params["max_tokens_to_sample"] = value
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, List, Optional

 import httpx

-import litellm
+from litellm.llms.anthropic.chat.transformation import AnthropicConfig
 from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
     AmazonInvokeConfig,
 )
@@ -17,7 +17,7 @@ else:
     LiteLLMLoggingObj = Any


-class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
+class AmazonAnthropicClaude3Config(AmazonInvokeConfig, AnthropicConfig):
     """
     Reference:
         https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
@@ -28,18 +28,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):

     anthropic_version: str = "bedrock-2023-05-31"

-    def get_supported_openai_params(self, model: str):
-        return [
-            "max_tokens",
-            "max_completion_tokens",
-            "tools",
-            "tool_choice",
-            "stream",
-            "stop",
-            "temperature",
-            "top_p",
-            "extra_headers",
-        ]
+    def get_supported_openai_params(self, model: str) -> List[str]:
+        return AnthropicConfig.get_supported_openai_params(self, model)

     def map_openai_params(
         self,
@@ -47,21 +37,14 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         optional_params: dict,
         model: str,
         drop_params: bool,
-    ):
-        for param, value in non_default_params.items():
-            if param == "max_tokens" or param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            if param == "tools":
-                optional_params["tools"] = value
-            if param == "stream":
-                optional_params["stream"] = value
-            if param == "stop":
-                optional_params["stop_sequences"] = value
-            if param == "temperature":
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-        return optional_params
+    ) -> dict:
+        return AnthropicConfig.map_openai_params(
+            self,
+            non_default_params,
+            optional_params,
+            model,
+            drop_params,
+        )

     def transform_request(
         self,
@@ -71,7 +54,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         litellm_params: dict,
         headers: dict,
     ) -> dict:
-        _anthropic_request = litellm.AnthropicConfig().transform_request(
+        _anthropic_request = AnthropicConfig.transform_request(
+            self,
             model=model,
             messages=messages,
             optional_params=optional_params,
@@ -80,6 +64,7 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         )

         _anthropic_request.pop("model", None)
+        _anthropic_request.pop("stream", None)
         if "anthropic_version" not in _anthropic_request:
             _anthropic_request["anthropic_version"] = self.anthropic_version

@@ -99,7 +84,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        return litellm.AnthropicConfig().transform_response(
+        return AnthropicConfig.transform_response(
+            self,
             model=model,
             raw_response=raw_response,
             model_response=model_response,
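A note on the inheritance design used throughout this file: AmazonAnthropicClaude3Config now subclasses both AmazonInvokeConfig and AnthropicConfig, with AmazonInvokeConfig first in the MRO. The explicit unbound calls (AnthropicConfig.transform_request(self, ...)) therefore pick the Anthropic behavior deliberately rather than whatever the MRO would resolve. A toy illustration with hypothetical class names:

class Base:
    def transform(self) -> str:
        raise NotImplementedError

class Anthropic(Base):
    def transform(self) -> str:
        return "anthropic request shape"

class Invoke(Base):
    def transform(self) -> str:
        return "generic invoke shape"

class Claude3(Invoke, Anthropic):
    def transform(self) -> str:
        # MRO is Claude3 -> Invoke -> Anthropic -> Base, so an implicit
        # super().transform() would pick Invoke's version; the explicit
        # unbound call selects the Anthropic one, as the diff does.
        return Anthropic.transform(self)

assert Claude3().transform() == "anthropic request shape"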
@@ -3,7 +3,7 @@ import json
 import time
 import urllib.parse
 from functools import partial
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

 import httpx

@@ -461,6 +461,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         data: dict,
         messages: list,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         streaming_response = CustomStreamWrapper(
             completion_stream=None,
@@ -475,6 +476,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
                 bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
+                json_mode=json_mode,
             ),
             model=model,
             custom_llm_provider="bedrock",
@@ -493,6 +495,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         data: dict,
         messages: list,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         if client is None or isinstance(client, AsyncHTTPHandler):
             client = _get_httpx_client(params={})
@@ -509,6 +512,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
                 bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
+                json_mode=json_mode,
             ),
             model=model,
             custom_llm_provider="bedrock",
@@ -527,56 +531,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         """
         return False

-    @staticmethod
-    def get_bedrock_invoke_provider(
-        model: str,
-    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
-        """
-        Helper function to get the bedrock provider from the model
-
-        handles 3 scenarions:
-        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
-        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
-        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
-        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
-        """
-        if model.startswith("invoke/"):
-            model = model.replace("invoke/", "", 1)
-
-        _split_model = model.split(".")[0]
-        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
-            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
-
-        # If not a known provider, check for pattern with two slashes
-        provider = AmazonInvokeConfig._get_provider_from_model_path(model)
-        if provider is not None:
-            return provider
-
-        # check if provider == "nova"
-        if "nova" in model:
-            return "nova"
-        return None
-
-    @staticmethod
-    def _get_provider_from_model_path(
-        model_path: str,
-    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
-        """
-        Helper function to get the provider from a model path with format: provider/model-name
-
-        Args:
-            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
-
-        Returns:
-            Optional[str]: The provider name, or None if no valid provider found
-        """
-        parts = model_path.split("/")
-        if len(parts) >= 1:
-            provider = parts[0]
-            if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
-                return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
-        return None
-
     def get_bedrock_model_id(
         self,
         optional_params: dict,
@@ -159,6 +159,7 @@ class BaseLLMHTTPHandler:
         encoding: Any,
         api_key: Optional[str] = None,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: bool = False,
     ):
         if client is None:
             async_httpx_client = get_async_httpx_client(
@@ -190,6 +191,7 @@ class BaseLLMHTTPHandler:
             optional_params=optional_params,
             litellm_params=litellm_params,
             encoding=encoding,
+            json_mode=json_mode,
         )

     def completion(
@@ -211,6 +213,7 @@ class BaseLLMHTTPHandler:
         headers: Optional[dict] = {},
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
+        json_mode: bool = optional_params.pop("json_mode", False)

         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
@@ -286,6 +289,7 @@ class BaseLLMHTTPHandler:
                     else None
                 ),
                 litellm_params=litellm_params,
+                json_mode=json_mode,
             )

         else:
@@ -309,6 +313,7 @@ class BaseLLMHTTPHandler:
                     if client is not None and isinstance(client, AsyncHTTPHandler)
                     else None
                 ),
+                json_mode=json_mode,
             )

         if stream is True:
@@ -327,6 +332,7 @@ class BaseLLMHTTPHandler:
                     data=data,
                     messages=messages,
                     client=client,
+                    json_mode=json_mode,
                 )
             completion_stream, headers = self.make_sync_call(
                 provider_config=provider_config,
@@ -380,6 +386,7 @@ class BaseLLMHTTPHandler:
             optional_params=optional_params,
             litellm_params=litellm_params,
             encoding=encoding,
+            json_mode=json_mode,
         )

     def make_sync_call(
@@ -453,6 +460,7 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ):
         if provider_config.has_custom_stream_wrapper is True:
             return provider_config.get_async_custom_stream_wrapper(
@@ -464,6 +472,7 @@ class BaseLLMHTTPHandler:
                 data=data,
                 messages=messages,
                 client=client,
+                json_mode=json_mode,
             )

         completion_stream, _response_headers = await self.make_async_call_stream_helper(
@@ -720,7 +729,7 @@ class BaseLLMHTTPHandler:
         api_base: Optional[str] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ) -> RerankResponse:

        # get config from model, custom llm provider
        headers = provider_config.validate_environment(
            api_key=api_key,
File diff suppressed because one or more lines are too long
@@ -16,7 +16,10 @@ model_list:
       api_key: os.environ/COHERE_API_KEY
   - model_name: bedrock-claude-3-7
     litellm_params:
-      model: bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0
+      model: bedrock/invoke/us.anthropic.claude-3-7-sonnet-20250219-v1:0
+  - model_name: bedrock-claude-3-5-sonnet
+    litellm_params:
+      model: bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0

 litellm_settings:
   callbacks: ["langfuse"]
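With these entries registered, the invoke route can be exercised end-to-end through the proxy with any OpenAI-compatible client. A sketch, assuming the proxy runs on the default http://localhost:4000 and "sk-1234" stands in for the configured master key:

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")
resp = client.chat.completions.create(
    model="bedrock-claude-3-5-sonnet",  # routed to bedrock/invoke/... above
    messages=[{"role": "user", "content": "hi"}],
    response_format={"type": "text"},  # previously errored on the invoke route
)
print(resp.choices[0].message.content)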
@@ -3240,6 +3240,7 @@ def get_optional_params(  # noqa: PLR0915
         )
     elif custom_llm_provider == "bedrock":
         bedrock_route = BedrockModelInfo.get_bedrock_route(model)
+        bedrock_base_model = BedrockModelInfo.get_base_model(model)
         if bedrock_route == "converse" or bedrock_route == "converse_like":
             optional_params = litellm.AmazonConverseConfig().map_openai_params(
                 model=model,
@@ -3253,8 +3254,9 @@ def get_optional_params(  # noqa: PLR0915
                 messages=messages,
             )

-        elif "anthropic" in model and bedrock_route == "invoke":
-            if model.startswith("anthropic.claude-3"):
+        elif "anthropic" in bedrock_base_model and bedrock_route == "invoke":
+            if bedrock_base_model.startswith("anthropic.claude-3"):
+
                 optional_params = (
                     litellm.AmazonAnthropicClaude3Config().map_openai_params(
                         non_default_params=non_default_params,
@@ -3267,10 +3269,17 @@ def get_optional_params(  # noqa: PLR0915
                         ),
                     )
                 )

             else:
                 optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                     non_default_params=non_default_params,
                     optional_params=optional_params,
+                    model=model,
+                    drop_params=(
+                        drop_params
+                        if drop_params is not None and isinstance(drop_params, bool)
+                        else False
+                    ),
                 )
     elif provider_config is not None:
         optional_params = provider_config.map_openai_params(
@@ -6158,13 +6167,19 @@ class ProviderConfigManager:
         elif litellm.LlmProviders.BEDROCK == provider:
             bedrock_route = BedrockModelInfo.get_bedrock_route(model)
             bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(
-                model
+                model=model
             )
+            base_model = BedrockModelInfo.get_base_model(model)

             if bedrock_route == "converse" or bedrock_route == "converse_like":
                 return litellm.AmazonConverseConfig()
             elif bedrock_invoke_provider == "amazon":  # amazon titan llms
                 return litellm.AmazonTitanConfig()
+            elif bedrock_invoke_provider == "anthropic":
+                if base_model.startswith("anthropic.claude-3"):
+                    return litellm.AmazonAnthropicClaude3Config()
+                else:
+                    return litellm.AmazonAnthropicConfig()
             elif (
                 bedrock_invoke_provider == "meta" or bedrock_invoke_provider == "llama"
             ):  # amazon / meta llms
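Roughly, the new branch means an invoke-route Claude 3 model now resolves to the Claude 3 config (and older Claude models to AmazonAnthropicConfig). A sketch of the expected resolution, not a guaranteed API contract:

import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
    litellm.LlmProviders.BEDROCK,
)
assert isinstance(config, litellm.AmazonAnthropicClaude3Config)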
@@ -18,6 +18,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     get_supported_openai_params,
     get_optional_params,
+    ProviderConfigManager,
 )
 from typing import Union

@@ -247,12 +248,45 @@ class BaseLLMChatTest(ABC):
             response_format=response_format,
         )

-        print(response)
+        print(f"response={response}")

         # OpenAI guarantees that the JSON schema is returned in the content
         # relevant issue: https://github.com/BerriAI/litellm/issues/6741
         assert response.choices[0].message.content is not None

+    def test_response_format_type_text(self):
+        """
+        Test that the response format type text does not lead to tool calls
+        """
+        from litellm import LlmProviders
+
+        base_completion_call_args = self.get_base_completion_call_args()
+        litellm.set_verbose = True
+
+        _, provider, _, _ = litellm.get_llm_provider(
+            model=base_completion_call_args["model"]
+        )
+
+        provider_config = ProviderConfigManager.get_provider_chat_config(
+            base_completion_call_args["model"], LlmProviders(provider)
+        )
+
+        print(f"provider_config={provider_config}")
+
+        translated_params = provider_config.map_openai_params(
+            non_default_params={"response_format": {"type": "text"}},
+            optional_params={},
+            model=base_completion_call_args["model"],
+            drop_params=False,
+        )
+
+        assert "tool_choice" not in translated_params
+        assert (
+            "tools" not in translated_params
+        ), f"Got tools={translated_params['tools']}, expected no tools"
+
+        print(f"translated_params={translated_params}")
+
     @pytest.mark.flaky(retries=6, delay=1)
     def test_json_response_pydantic_obj(self):
         litellm.set_verbose = True
@@ -22,17 +22,9 @@ class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass

-    @pytest.fixture(autouse=True)
-    def skip_non_json_tests(self, request):
-        if not "json" in request.function.__name__.lower():
-            pytest.skip(
-                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
-            )
-

 class TestBedrockInvokeNovaJson(BaseLLMChatTest):
     def get_base_completion_call_args(self) -> dict:
-        litellm._turn_on_debug()
         return {
             "model": "bedrock/invoke/us.amazon.nova-micro-v1:0",
         }
@@ -319,12 +319,15 @@ def test_all_model_configs():
     ) == {"max_tokens": 10}

     assert (
-        "max_completion_tokens" in AmazonAnthropicConfig().get_supported_openai_params()
+        "max_completion_tokens"
+        in AmazonAnthropicConfig().get_supported_openai_params(model="")
     )

     assert AmazonAnthropicConfig().map_openai_params(
         non_default_params={"max_completion_tokens": 10},
         optional_params={},
+        model="",
+        drop_params=False,
     ) == {"max_tokens_to_sample": 10}

     from litellm.llms.databricks.chat.handler import DatabricksConfig
@@ -1114,3 +1114,257 @@ def test_anthropic_thinking_param(model, expected_thinking):
         assert "thinking" in optional_params
     else:
         assert "thinking" not in optional_params
+
+
+def test_bedrock_invoke_anthropic_max_tokens():
+    passed_params = {
+        "model": "invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "functions": None,
+        "function_call": None,
+        "temperature": 0.8,
+        "top_p": None,
+        "n": 1,
+        "stream": False,
+        "stream_options": None,
+        "stop": None,
+        "max_tokens": None,
+        "max_completion_tokens": 1024,
+        "modalities": None,
+        "prediction": None,
+        "audio": None,
+        "presence_penalty": None,
+        "frequency_penalty": None,
+        "logit_bias": None,
+        "user": None,
+        "custom_llm_provider": "bedrock",
+        "response_format": {"type": "text"},
+        "seed": None,
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "generate_plan",
+                    "description": "Generate a plan to execute the task using only the tools outlined in your context.",
+                    "input_schema": {
+                        "type": "object",
+                        "properties": {
+                            "steps": {
+                                "type": "array",
+                                "items": {
+                                    "type": "object",
+                                    "properties": {
+                                        "type": {
+                                            "type": "string",
+                                            "description": "The type of step to execute",
+                                        },
+                                        "tool_name": {
+                                            "type": "string",
+                                            "description": "The name of the tool to use for this step",
+                                        },
+                                        "tool_input": {
+                                            "type": "object",
+                                            "description": "The input to pass to the tool. Make sure this complies with the schema for the tool.",
+                                        },
+                                        "tool_output": {
+                                            "type": "object",
+                                            "description": "(Optional) The output from the tool if needed for future steps. Make sure this complies with the schema for the tool.",
+                                        },
+                                    },
+                                    "required": ["type"],
+                                },
+                            }
+                        },
+                    },
+                },
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "generate_wire_tool",
+                    "description": "Create a wire transfer with complete wire instructions",
+                    "input_schema": {
+                        "type": "object",
+                        "properties": {
+                            "company_id": {
+                                "type": "integer",
+                                "description": "The ID of the company receiving the investment",
+                            },
+                            "investment_id": {
+                                "type": "integer",
+                                "description": "The ID of the investment memo",
+                            },
+                            "dollar_amount": {
+                                "type": "number",
+                                "description": "The amount to wire in USD",
+                            },
+                            "wiring_instructions": {
+                                "type": "object",
+                                "description": "Complete bank account and routing information for the wire",
+                                "properties": {
+                                    "account_name": {
+                                        "type": "string",
+                                        "description": "Name on the bank account",
+                                    },
+                                    "address_1": {
+                                        "type": "string",
+                                        "description": "Primary address line",
+                                    },
+                                    "address_2": {
+                                        "type": "string",
+                                        "description": "Secondary address line (optional)",
+                                    },
+                                    "city": {"type": "string"},
+                                    "state": {"type": "string"},
+                                    "zip": {"type": "string"},
+                                    "country": {"type": "string", "default": "US"},
+                                    "bank_name": {"type": "string"},
+                                    "account_number": {"type": "string"},
+                                    "routing_number": {"type": "string"},
+                                    "account_type": {
+                                        "type": "string",
+                                        "enum": ["checking", "savings"],
+                                        "default": "checking",
+                                    },
+                                    "swift_code": {
+                                        "type": "string",
+                                        "description": "Required for international wires",
+                                    },
+                                    "iban": {
+                                        "type": "string",
+                                        "description": "Required for some international wires",
+                                    },
+                                    "bank_city": {"type": "string"},
+                                    "bank_state": {"type": "string"},
+                                    "bank_country": {"type": "string", "default": "US"},
+                                    "bank_to_bank_instructions": {
+                                        "type": "string",
+                                        "description": "Additional instructions for the bank (optional)",
+                                    },
+                                    "intermediary_bank_name": {
+                                        "type": "string",
+                                        "description": "Name of intermediary bank if required (optional)",
+                                    },
+                                },
+                                "required": [
+                                    "account_name",
+                                    "address_1",
+                                    "country",
+                                    "bank_name",
+                                    "account_number",
+                                    "routing_number",
+                                    "account_type",
+                                    "bank_country",
+                                ],
+                            },
+                        },
+                        "required": [
+                            "company_id",
+                            "investment_id",
+                            "dollar_amount",
+                            "wiring_instructions",
+                        ],
+                    },
+                },
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "search_companies",
+                    "description": "Search for companies by name or other criteria to get their IDs",
+                    "input_schema": {
+                        "type": "object",
+                        "properties": {
+                            "query": {
+                                "type": "string",
+                                "description": "Name or part of name to search for",
+                            },
+                            "batch": {
+                                "type": "string",
+                                "description": 'Optional batch filter (e.g., "W21", "S22")',
+                            },
+                            "status": {
+                                "type": "string",
+                                "enum": [
+                                    "live",
+                                    "dead",
+                                    "adrift",
+                                    "exited",
+                                    "went_public",
+                                    "all",
+                                ],
+                                "description": "Filter by company status",
+                                "default": "live",
+                            },
+                            "limit": {
+                                "type": "integer",
+                                "description": "Maximum number of results to return",
+                                "default": 10,
+                            },
+                        },
+                        "required": ["query"],
+                    },
+                    "output_schema": {
+                        "type": "object",
+                        "properties": {
+                            "status": {
+                                "type": "string",
+                                "description": "Success or error status",
+                            },
+                            "results": {
+                                "type": "array",
+                                "description": "List of companies matching the search criteria",
+                                "items": {
+                                    "type": "object",
+                                    "properties": {
+                                        "id": {
+                                            "type": "integer",
+                                            "description": "Company ID to use in other API calls",
+                                        },
+                                        "name": {"type": "string"},
+                                        "batch": {"type": "string"},
+                                        "status": {"type": "string"},
+                                        "valuation": {"type": "string"},
+                                        "url": {"type": "string"},
+                                        "description": {"type": "string"},
+                                        "founders": {"type": "string"},
+                                    },
+                                },
+                            },
+                            "results_count": {
+                                "type": "integer",
+                                "description": "Number of companies returned",
+                            },
+                            "total_matches": {
+                                "type": "integer",
+                                "description": "Total number of matches found",
+                            },
+                        },
+                    },
+                },
+            },
+        ],
+        "tool_choice": None,
+        "max_retries": 0,
+        "logprobs": None,
+        "top_logprobs": None,
+        "extra_headers": None,
+        "api_version": None,
+        "parallel_tool_calls": None,
+        "drop_params": True,
+        "reasoning_effort": None,
+        "additional_drop_params": None,
+        "messages": [
+            {
+                "role": "system",
+                "content": "You are an AI assistant that helps prepare a wire for a pro rata investment.",
+            },
+            {"role": "user", "content": [{"type": "text", "text": "hi"}]},
+        ],
+        "thinking": None,
+        "kwargs": {},
+    }
+    optional_params = get_optional_params(**passed_params)
+    print(f"optional_params: {optional_params}")
+
+    assert "max_tokens_to_sample" not in optional_params
+    assert optional_params["max_tokens"] == 1024
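Stripped of the large tool schemas, the regression this test guards against boils down to a few lines: on the Claude 3 invoke route, max_completion_tokens must map to max_tokens rather than the legacy max_tokens_to_sample, and response_format={"type": "text"} must not inject tool-call params. A condensed sketch:

from litellm.utils import get_optional_params

optional_params = get_optional_params(
    model="invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
    custom_llm_provider="bedrock",
    max_completion_tokens=1024,
    response_format={"type": "text"},
)
assert optional_params["max_tokens"] == 1024
assert "max_tokens_to_sample" not in optional_params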