LiteLLM Minor Fixes & Improvements (10/28/2024) (#6475)

* fix(anthropic/chat/transformation.py): support anthropic disable_parallel_tool_use param

Fixes https://github.com/BerriAI/litellm/issues/6456
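
A rough sketch of the new mapping, based on the unit test added in this PR (the model name is illustrative; `get_optional_params` lives in `litellm.utils`):

```python
from litellm.utils import get_optional_params

# The OpenAI-style `parallel_tool_calls` flag is now forwarded to Anthropic
# as `tool_choice.disable_parallel_tool_use`.
optional_params = get_optional_params(
    model="claude-3-5-sonnet-20241022",
    custom_llm_provider="anthropic",
    parallel_tool_calls=True,
)
print(optional_params["tool_choice"]["disable_parallel_tool_use"])
```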

* feat(anthropic/chat/transformation.py): support anthropic computer tool use

Closes https://github.com/BerriAI/litellm/issues/6427

* fix(vertex_ai/common_utils.py): parse out '$schema' when calling vertex ai

Fixes an issue when calling Vertex AI from the Vercel AI SDK
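
A minimal sketch of the behavior, modeled on the new `test_vertex_schema_field` unit test (schema trimmed down for brevity):

```python
from litellm.utils import get_optional_params

tools = [
    {
        "type": "function",
        "function": {
            "name": "json",
            "description": "Respond with a JSON object.",
            "parameters": {
                "type": "object",
                "properties": {"thinking": {"type": "string"}},
                # Emitted by the Vercel AI SDK; now stripped before calling Vertex.
                "$schema": "http://json-schema.org/draft-07/schema#",
            },
        },
    }
]

optional_params = get_optional_params(
    model="gemini-1.5-flash",
    custom_llm_provider="vertex_ai",
    tools=tools,
)
assert (
    "$schema"
    not in optional_params["tools"][0]["function_declarations"][0]["parameters"]
)
```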

* fix(main.py): add 'extra_headers' support for azure on all translation endpoints

Fixes https://github.com/BerriAI/litellm/issues/6465
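
A hedged usage sketch (deployment name and endpoint are illustrative; the pattern follows the new `test_azure_extra_headers` test):

```python
import litellm

# extra_headers is now forwarded on Azure embedding, image_generation,
# transcription, and speech calls, not just chat completions.
response = litellm.image_generation(
    model="azure/dall-e-3",
    prompt="A cute baby sea otter",
    api_base="https://my-endpoint.openai.azure.com",
    api_version="2023-07-01-preview",
    api_key="my-azure-api-key",
    extra_headers={"Ocp-Apim-Subscription-Key": "my-apim-key"},
)
```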

* fix: fix linting errors

* fix(transformation.py): handle no beta headers for anthropic

* test: cleanup test

* fix: fix linting error

* fix: fix linting errors

* fix: fix linting errors

* fix(transformation.py): handle dummy tool call

* fix(main.py): fix linting error

* fix(azure.py): pass required param

* LiteLLM Minor Fixes & Improvements (10/24/2024) (#6441)

* fix(azure.py): handle /openai/deployment in azure api base

* fix(factory.py): fix faulty anthropic tool result translation check

Fixes https://github.com/BerriAI/litellm/issues/6422

* fix(gpt_transformation.py): add support for parallel_tool_calls to azure

Fixes https://github.com/BerriAI/litellm/issues/6440

* fix(factory.py): support anthropic prompt caching for tool results

* fix(vertex_ai/common_utils): don't pop non-null required field

Fixes https://github.com/BerriAI/litellm/issues/6426

* feat(vertex_ai.py): support code_execution tool call for vertex ai + gemini

Closes https://github.com/BerriAI/litellm/issues/6434

* build(model_prices_and_context_window.json): Add 'supports_assistant_prefill' for bedrock claude-3-5-sonnet v2 models

Closes https://github.com/BerriAI/litellm/issues/6437

* fix(types/utils.py): fix linting

* test: update test to include required fields

* test: fix test

* test: handle flaky test

* test: remove e2e test - hitting gemini rate limits

* Litellm dev 10 26 2024 (#6472)

* docs(exception_mapping.md): add missing exception types

Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183

* fix(main.py): register custom model pricing with specific key

Ensure custom model pricing is registered to the specific model+provider key combination
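
A sketch of the intended usage, assuming the documented `litellm.register_model` entry point (model name and prices are illustrative):

```python
import litellm

# Custom pricing is now stored under the provider-qualified key
# ("openai/<model>"), so it does not leak across providers.
litellm.register_model(
    {
        "openai/my-custom-gpt-4o": {
            "input_cost_per_token": 0.0000025,
            "output_cost_per_token": 0.00001,
            "litellm_provider": "openai",
            "mode": "chat",
        }
    }
)
```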

* test: make testing more robust for custom pricing

* fix(redis_cache.py): instrument otel logging for sync redis calls

ensures complete coverage for all redis cache calls

* (Testing) Add unit testing for DualCache - ensure in memory cache is used when expected  (#6471)

* test test_dual_cache_get_set

* unit testing for dual cache

* fix async_set_cache_sadd

* test_dual_cache_local_only

* redis otel tracing + async support for latency routing (#6452)

* docs(exception_mapping.md): add missing exception types

Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183

* fix(main.py): register custom model pricing with specific key

Ensure custom model pricing is registered to the specific model+provider key combination

* test: make testing more robust for custom pricing

* fix(redis_cache.py): instrument otel logging for sync redis calls

ensures complete coverage for all redis cache calls

* refactor: pass parent_otel_span for redis caching calls in router

allows for more observability into what calls are causing latency issues

* test: update tests with new params

* refactor: ensure e2e otel tracing for router

* refactor(router.py): add more otel tracing across router

catch all latency issues for router requests

* fix: fix linting error

* fix(router.py): fix linting error

* fix: fix test

* test: fix tests

* fix(dual_cache.py): pass ttl to redis cache

* fix: fix param

* fix(dual_cache.py): set default value for parent_otel_span

* fix(transformation.py): support 'response_format' for anthropic calls
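
A rough sketch of the call shape (the schema is illustrative; internally this is translated into a forced `json_tool_call` tool, per the Anthropic JSON-mode tool-use pattern):

```python
from litellm import completion

response = completion(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "List three colors."}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "colors",
            "schema": {
                "type": "object",
                "properties": {"colors": {"type": "array", "items": {"type": "string"}}},
            },
        },
    },
)
```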

* fix(transformation.py): check for cache_control inside 'function' block
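
A minimal sketch, lifted from the new `test_anthropic_tool_helper` unit test:

```python
from litellm.llms.anthropic.chat.transformation import AnthropicConfig

tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "parameters": {"type": "object", "properties": {}},
        # cache_control nested inside the "function" block is now picked up too.
        "cache_control": {"type": "ephemeral"},
    },
}

anthropic_tool = AnthropicConfig()._map_tool_helper(tool=tool)
assert anthropic_tool["cache_control"] == {"type": "ephemeral"}
```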

* fix: fix linting error

* fix: fix linting errors

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Author: Krish Dholakia
Date: 2024-10-29 17:20:24 -07:00 (committed by GitHub)
Commit: 6b9be5092f (parent: 44e7ffd05c)

19 changed files with 684 additions and 253 deletions
@@ -721,6 +721,37 @@ except Exception as e:
 s/o @[Shekhar Patnaik](https://www.linkedin.com/in/patnaikshekhar) for requesting this!
+
+### Computer Tools
+
+```python
+from litellm import completion
+
+tools = [
+    {
+        "type": "computer_20241022",
+        "function": {
+            "name": "computer",
+            "parameters": {
+                "display_height_px": 100,
+                "display_width_px": 100,
+                "display_number": 1,
+            },
+        },
+    }
+]
+model = "claude-3-5-sonnet-20241022"
+messages = [{"role": "user", "content": "Save a picture of a cat to my desktop."}]
+
+resp = completion(
+    model=model,
+    messages=messages,
+    tools=tools,
+    # headers={"anthropic-beta": "computer-use-2024-10-22"},
+)
+
+print(resp)
+```
+
 ## Usage - Vision

 ```python

@@ -152,7 +152,7 @@ class DualCache(BaseCache):
     def batch_get_cache(
         self,
         keys: list,
-        parent_otel_span: Optional[Span],
+        parent_otel_span: Optional[Span] = None,
         local_only: bool = False,
         **kwargs,
     ):
@@ -343,7 +343,7 @@ class DualCache(BaseCache):
         self,
         key,
         value: float,
-        parent_otel_span: Optional[Span],
+        parent_otel_span: Optional[Span] = None,
         local_only: bool = False,
         **kwargs,
     ) -> float:

@@ -961,6 +961,7 @@ class AzureChatCompletion(BaseLLM):
         api_version: str,
         api_key: str,
         data: dict,
+        headers: dict,
     ) -> httpx.Response:
         """
         Implemented for azure dall-e-2 image gen calls
@@ -1002,10 +1003,7 @@ class AzureChatCompletion(BaseLLM):
         response = await async_handler.post(
             url=api_base,
             data=json.dumps(data),
-            headers={
-                "Content-Type": "application/json",
-                "api-key": api_key,
-            },
+            headers=headers,
         )
         if "operation-location" in response.headers:
             operation_location_url = response.headers["operation-location"]
@@ -1013,9 +1011,7 @@
             raise AzureOpenAIError(status_code=500, message=response.text)
         response = await async_handler.get(
             url=operation_location_url,
-            headers={
-                "api-key": api_key,
-            },
+            headers=headers,
         )
         await response.aread()
@@ -1036,9 +1032,7 @@
             await asyncio.sleep(int(response.headers.get("retry-after") or 10))
             response = await async_handler.get(
                 url=operation_location_url,
-                headers={
-                    "api-key": api_key,
-                },
+                headers=headers,
             )
             await response.aread()
@@ -1056,10 +1050,7 @@
         return await async_handler.post(
             url=api_base,
             json=data,
-            headers={
-                "Content-Type": "application/json;",
-                "api-key": api_key,
-            },
+            headers=headers,
         )

     def make_sync_azure_httpx_request(
@@ -1070,6 +1061,7 @@ class AzureChatCompletion(BaseLLM):
         api_version: str,
         api_key: str,
         data: dict,
+        headers: dict,
     ) -> httpx.Response:
         """
         Implemented for azure dall-e-2 image gen calls
@@ -1085,7 +1077,7 @@
             else:
                 _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

-            sync_handler = HTTPHandler(**_params)  # type: ignore
+            sync_handler = HTTPHandler(**_params, client=litellm.client_session)  # type: ignore
         else:
             sync_handler = client  # type: ignore
@@ -1111,10 +1103,7 @@
         response = sync_handler.post(
             url=api_base,
             data=json.dumps(data),
-            headers={
-                "Content-Type": "application/json",
-                "api-key": api_key,
-            },
+            headers=headers,
         )
         if "operation-location" in response.headers:
             operation_location_url = response.headers["operation-location"]
@@ -1122,9 +1111,7 @@
             raise AzureOpenAIError(status_code=500, message=response.text)
         response = sync_handler.get(
             url=operation_location_url,
-            headers={
-                "api-key": api_key,
-            },
+            headers=headers,
         )
         response.read()
@@ -1144,9 +1131,7 @@
             time.sleep(int(response.headers.get("retry-after") or 10))
             response = sync_handler.get(
                 url=operation_location_url,
-                headers={
-                    "api-key": api_key,
-                },
+                headers=headers,
             )
             response.read()
@@ -1164,10 +1149,7 @@
         return sync_handler.post(
             url=api_base,
             json=data,
-            headers={
-                "Content-Type": "application/json;",
-                "api-key": api_key,
-            },
+            headers=headers,
         )

     def create_azure_base_url(
@@ -1200,6 +1182,7 @@
         api_key: str,
         input: list,
         logging_obj: LiteLLMLoggingObj,
+        headers: dict,
         client=None,
         timeout=None,
     ) -> litellm.ImageResponse:
@@ -1223,7 +1206,7 @@
             additional_args={
                 "complete_input_dict": data,
                 "api_base": img_gen_api_base,
-                "headers": {"api_key": api_key},
+                "headers": headers,
             },
         )
         httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
@@ -1233,6 +1216,7 @@
             api_version=api_version,
             api_key=api_key,
             data=data,
+            headers=headers,
         )
         response = httpx_response.json()
@@ -1265,6 +1249,7 @@
         timeout: float,
         optional_params: dict,
         logging_obj: LiteLLMLoggingObj,
+        headers: dict,
         model: Optional[str] = None,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
@@ -1315,7 +1300,7 @@
                 azure_client_params["azure_ad_token"] = azure_ad_token

             if aimg_generation is True:
-                return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout)  # type: ignore
+                return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers)  # type: ignore

             img_gen_api_base = self.create_azure_base_url(
                 azure_client_params=azure_client_params, model=data.get("model", "")
@@ -1328,7 +1313,7 @@
                 additional_args={
                     "complete_input_dict": data,
                     "api_base": img_gen_api_base,
-                    "headers": {"api_key": api_key},
+                    "headers": headers,
                 },
             )
             httpx_response: httpx.Response = self.make_sync_azure_httpx_request(
@@ -1338,6 +1323,7 @@
                 api_version=api_version or "",
                 api_key=api_key or "",
                 data=data,
+                headers=headers,
             )
             response = httpx_response.json()

@@ -29,6 +29,7 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
 )
 from litellm.types.llms.anthropic import (
+    AllAnthropicToolsValues,
     AnthropicChatCompletionUsageBlock,
     ContentBlockDelta,
     ContentBlockStart,
@@ -53,9 +54,14 @@ from .transformation import AnthropicConfig

 # makes headers for API call
 def validate_environment(
-    api_key, user_headers, model, messages: List[AllMessageValues]
+    api_key,
+    user_headers,
+    model,
+    messages: List[AllMessageValues],
+    tools: Optional[List[AllAnthropicToolsValues]],
+    anthropic_version: Optional[str] = None,
 ):
-    cache_headers = {}
     if api_key is None:
         raise litellm.AuthenticationError(
             message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
@@ -63,17 +69,15 @@
             model=model,
         )

-    if AnthropicConfig().is_cache_control_set(messages=messages):
-        cache_headers = AnthropicConfig().get_cache_control_headers()
-
-    headers = {
-        "accept": "application/json",
-        "anthropic-version": "2023-06-01",
-        "content-type": "application/json",
-        "x-api-key": api_key,
-    }
-
-    headers.update(cache_headers)
+    prompt_caching_set = AnthropicConfig().is_cache_control_set(messages=messages)
+    computer_tool_used = AnthropicConfig().is_computer_tool_used(tools=tools)
+    headers = AnthropicConfig().get_anthropic_headers(
+        anthropic_version=anthropic_version,
+        computer_tool_used=computer_tool_used,
+        prompt_caching_set=prompt_caching_set,
+        api_key=api_key,
+    )

     if user_headers is not None and isinstance(user_headers, dict):
         headers = {**headers, **user_headers}
@@ -441,7 +445,13 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
         client=None,
     ):
-        headers = validate_environment(api_key, headers, model, messages=messages)
+        headers = validate_environment(
+            api_key,
+            headers,
+            model,
+            messages=messages,
+            tools=optional_params.get("tools"),
+        )
         _is_function_call = False
         messages = copy.deepcopy(messages)
         optional_params = copy.deepcopy(optional_params)

@@ -4,6 +4,9 @@ from typing import List, Literal, Optional, Tuple, Union
 import litellm
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt
 from litellm.types.llms.anthropic import (
+    AllAnthropicToolsValues,
+    AnthropicComputerTool,
+    AnthropicHostedTools,
     AnthropicMessageRequestBase,
     AnthropicMessagesRequest,
     AnthropicMessagesTool,
@@ -12,6 +15,7 @@ from litellm.types.llms.anthropic import (
 )
 from litellm.types.llms.openai import (
     AllMessageValues,
+    ChatCompletionCachedContent,
     ChatCompletionSystemMessage,
     ChatCompletionToolParam,
     ChatCompletionToolParamFunctionChunk,
@@ -84,6 +88,8 @@ class AnthropicConfig:
             "tools",
             "tool_choice",
             "extra_headers",
+            "parallel_tool_calls",
+            "response_format",
         ]

     def get_cache_control_headers(self) -> dict:
@@ -92,6 +98,146 @@ class AnthropicConfig:
             "anthropic-beta": "prompt-caching-2024-07-31",
         }

+    def get_anthropic_headers(
+        self,
+        api_key: str,
+        anthropic_version: Optional[str] = None,
+        computer_tool_used: bool = False,
+        prompt_caching_set: bool = False,
+    ) -> dict:
+        import json
+
+        betas = []
+        if prompt_caching_set:
+            betas.append("prompt-caching-2024-07-31")
+        if computer_tool_used:
+            betas.append("computer-use-2024-10-22")
+        headers = {
+            "anthropic-version": anthropic_version or "2023-06-01",
+            "x-api-key": api_key,
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+        if len(betas) > 0:
+            headers["anthropic-beta"] = ",".join(betas)
+        return headers
+
+    def _map_tool_choice(
+        self, tool_choice: Optional[str], disable_parallel_tool_use: Optional[bool]
+    ) -> Optional[AnthropicMessagesToolChoice]:
+        _tool_choice: Optional[AnthropicMessagesToolChoice] = None
+        if tool_choice == "auto":
+            _tool_choice = AnthropicMessagesToolChoice(
+                type="auto",
+            )
+        elif tool_choice == "required":
+            _tool_choice = AnthropicMessagesToolChoice(type="any")
+        elif isinstance(tool_choice, dict):
+            _tool_name = tool_choice.get("function", {}).get("name")
+            _tool_choice = AnthropicMessagesToolChoice(type="tool")
+            if _tool_name is not None:
+                _tool_choice["name"] = _tool_name
+
+        if disable_parallel_tool_use is not None:
+            if _tool_choice is not None:
+                _tool_choice["disable_parallel_tool_use"] = disable_parallel_tool_use
+            else:  # use anthropic defaults and make sure to send the disable_parallel_tool_use flag
+                _tool_choice = AnthropicMessagesToolChoice(
+                    type="auto",
+                    disable_parallel_tool_use=disable_parallel_tool_use,
+                )
+        return _tool_choice
+
+    def _map_tool_helper(
+        self, tool: ChatCompletionToolParam
+    ) -> AllAnthropicToolsValues:
+        returned_tool: Optional[AllAnthropicToolsValues] = None
+
+        if tool["type"] == "function" or tool["type"] == "custom":
+            _tool = AnthropicMessagesTool(
+                name=tool["function"]["name"],
+                input_schema=tool["function"].get(
+                    "parameters",
+                    {
+                        "type": "object",
+                        "properties": {},
+                    },
+                ),
+            )
+            _description = tool["function"].get("description")
+            if _description is not None:
+                _tool["description"] = _description
+
+            returned_tool = _tool
+
+        elif tool["type"].startswith("computer_"):
+            ## check if all required 'display_' params are given
+            if "parameters" not in tool["function"]:
+                raise ValueError("Missing required parameter: parameters")
+
+            _display_width_px: Optional[int] = tool["function"]["parameters"].get(
+                "display_width_px"
+            )
+            _display_height_px: Optional[int] = tool["function"]["parameters"].get(
+                "display_height_px"
+            )
+            if _display_width_px is None or _display_height_px is None:
+                raise ValueError(
+                    "Missing required parameter: display_width_px or display_height_px"
+                )
+
+            _computer_tool = AnthropicComputerTool(
+                type=tool["type"],
+                name=tool["function"].get("name", "computer"),
+                display_width_px=_display_width_px,
+                display_height_px=_display_height_px,
+            )
+
+            _display_number = tool["function"]["parameters"].get("display_number")
+            if _display_number is not None:
+                _computer_tool["display_number"] = _display_number
+
+            returned_tool = _computer_tool
+
+        elif tool["type"].startswith("bash_") or tool["type"].startswith(
+            "text_editor_"
+        ):
+            function_name = tool["function"].get("name")
+            if function_name is None:
+                raise ValueError("Missing required parameter: name")
+
+            returned_tool = AnthropicHostedTools(
+                type=tool["type"],
+                name=function_name,
+            )
+        if returned_tool is None:
+            raise ValueError(f"Unsupported tool type: {tool['type']}")
+
+        ## check if cache_control is set in the tool
+        _cache_control = tool.get("cache_control", None)
+        _cache_control_function = tool.get("function", {}).get("cache_control", None)
+        if _cache_control is not None:
+            returned_tool["cache_control"] = _cache_control
+        elif _cache_control_function is not None and isinstance(
+            _cache_control_function, dict
+        ):
+            returned_tool["cache_control"] = ChatCompletionCachedContent(
+                **_cache_control_function  # type: ignore
+            )
+
+        return returned_tool
+
+    def _map_tools(self, tools: List) -> List[AllAnthropicToolsValues]:
+        anthropic_tools = []
+        for tool in tools:
+            if "input_schema" in tool:  # assume in anthropic format
+                anthropic_tools.append(tool)
+            else:  # assume openai tool call
+                new_tool = self._map_tool_helper(tool)
+
+                anthropic_tools.append(new_tool)
+        return anthropic_tools
+
     def map_openai_params(
         self,
         non_default_params: dict,
@@ -104,15 +250,16 @@ class AnthropicConfig:
             if param == "max_completion_tokens":
                 optional_params["max_tokens"] = value
             if param == "tools":
-                optional_params["tools"] = value
-            if param == "tool_choice":
-                _tool_choice: Optional[AnthropicMessagesToolChoice] = None
-                if value == "auto":
-                    _tool_choice = {"type": "auto"}
-                elif value == "required":
-                    _tool_choice = {"type": "any"}
-                elif isinstance(value, dict):
-                    _tool_choice = {"type": "tool", "name": value["function"]["name"]}
+                optional_params["tools"] = self._map_tools(value)
+            if param == "tool_choice" or param == "parallel_tool_calls":
+                _tool_choice: Optional[AnthropicMessagesToolChoice] = (
+                    self._map_tool_choice(
+                        tool_choice=non_default_params.get("tool_choice"),
+                        disable_parallel_tool_use=non_default_params.get(
+                            "parallel_tool_calls"
+                        ),
+                    )
+                )
                 if _tool_choice is not None:
                     optional_params["tool_choice"] = _tool_choice
@@ -142,6 +289,32 @@
                 optional_params["temperature"] = value
             if param == "top_p":
                 optional_params["top_p"] = value
+            if param == "response_format" and isinstance(value, dict):
+                json_schema: Optional[dict] = None
+                if "response_schema" in value:
+                    json_schema = value["response_schema"]
+                elif "json_schema" in value:
+                    json_schema = value["json_schema"]["schema"]
+                """
+                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+                - You usually want to provide a single tool
+                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
+                """
+                _tool_choice = None
+                _tool_choice = {"name": "json_tool_call", "type": "tool"}
+
+                _tool = AnthropicMessagesTool(
+                    name="json_tool_call",
+                    input_schema={
+                        "type": "object",
+                        "properties": {"values": json_schema},  # type: ignore
+                    },
+                )
+
+                optional_params["tools"] = [_tool]
+                optional_params["tool_choice"] = _tool_choice
+                optional_params["json_mode"] = True
+
         ## VALIDATE REQUEST
         """
@@ -153,8 +326,8 @@
             and has_tool_call_blocks(messages)
         ):
             if litellm.modify_params:
-                optional_params["tools"] = add_dummy_tool(
-                    custom_llm_provider="bedrock_converse"
+                optional_params["tools"] = self._map_tools(
+                    add_dummy_tool(custom_llm_provider="anthropic")
                 )
             else:
                 raise litellm.UnsupportedParamsError(
@@ -182,6 +355,16 @@
         return False

+    def is_computer_tool_used(
+        self, tools: Optional[List[AllAnthropicToolsValues]]
+    ) -> bool:
+        if tools is None:
+            return False
+        for tool in tools:
+            if "type" in tool and tool["type"].startswith("computer_"):
+                return True
+        return False
+
     def translate_system_message(
         self, messages: List[AllMessageValues]
     ) -> List[AnthropicSystemMessageContent]:
@@ -276,24 +459,6 @@
         ## Handle Tool Calling
         if "tools" in optional_params:
             _is_function_call = True
-            anthropic_tools = []
-            for tool in optional_params["tools"]:
-                if "input_schema" in tool:  # assume in anthropic format
-                    anthropic_tools.append(tool)
-                else:  # assume openai tool call
-                    new_tool = tool["function"]
-                    parameters = new_tool.pop(
-                        "parameters",
-                        {
-                            "type": "object",
-                            "properties": {},
-                        },
-                    )
-                    new_tool["input_schema"] = parameters  # rename key
-                    if "cache_control" in tool:
-                        new_tool["cache_control"] = tool["cache_control"]
-                    anthropic_tools.append(new_tool)
-            optional_params["tools"] = anthropic_tools

         data = {
             "messages": anthropic_messages,

@@ -6,9 +6,12 @@ from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice

 import litellm
 from litellm.types.llms.anthropic import (
+    AllAnthropicToolsValues,
     AnthopicMessagesAssistantMessageParam,
     AnthropicChatCompletionUsageBlock,
+    AnthropicComputerTool,
     AnthropicFinishReason,
+    AnthropicHostedTools,
     AnthropicMessagesRequest,
     AnthropicMessagesTool,
     AnthropicMessagesToolChoice,
@@ -215,16 +218,22 @@
         )

     def translate_anthropic_tools_to_openai(
-        self, tools: List[AnthropicMessagesTool]
+        self, tools: List[AllAnthropicToolsValues]
     ) -> List[ChatCompletionToolParam]:
         new_tools: List[ChatCompletionToolParam] = []
+        mapped_tool_params = ["name", "input_schema", "description"]
         for tool in tools:
             function_chunk = ChatCompletionToolParamFunctionChunk(
                 name=tool["name"],
-                parameters=tool["input_schema"],
             )
+            if "input_schema" in tool:
+                function_chunk["parameters"] = tool["input_schema"]  # type: ignore
             if "description" in tool:
-                function_chunk["description"] = tool["description"]
+                function_chunk["description"] = tool["description"]  # type: ignore
+
+            for k, v in tool.items():
+                if k not in mapped_tool_params:  # pass additional computer kwargs
+                    function_chunk.setdefault("parameters", {}).update({k: v})
             new_tools.append(
                 ChatCompletionToolParam(type="function", function=function_chunk)
             )

@@ -164,7 +164,12 @@ def _build_vertex_schema(parameters: dict):
     # 4. Suppress unnecessary title generation:
     #    * https://github.com/pydantic/pydantic/issues/1051
     #    * http://cl/586221780
-    strip_titles(parameters)
+    strip_field(parameters, field_name="title")
+
+    strip_field(
+        parameters, field_name="$schema"
+    )  # 5. Remove $schema - json schema value, not supported by OpenAPI - causes vertex errors.
+
     return parameters
@@ -245,14 +250,14 @@ def add_object_type(schema):
         add_object_type(items)


-def strip_titles(schema):
-    schema.pop("title", None)
+def strip_field(schema, field_name: str):
+    schema.pop(field_name, None)

     properties = schema.get("properties", None)
     if properties is not None:
         for name, value in properties.items():
-            strip_titles(value)
+            strip_field(value, field_name)

     items = schema.get("items", None)
     if items is not None:
-        strip_titles(items)
+        strip_field(items, field_name)

@@ -400,14 +400,26 @@ class VertexGeminiConfig:
         value = _remove_additional_properties(value)
         # remove 'strict' from tools
         value = _remove_strict_from_schema(value)
+
         for tool in value:
             openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
                 None
             )
             if "function" in tool:  # tools list
-                openai_function_object = ChatCompletionToolParamFunctionChunk(  # type: ignore
+                _openai_function_object = ChatCompletionToolParamFunctionChunk(  # type: ignore
                     **tool["function"]
                 )
+
+                if (
+                    "parameters" in _openai_function_object
+                    and _openai_function_object["parameters"] is not None
+                ):  # OPENAI accepts JSON Schema, Google accepts OpenAPI schema.
+                    _openai_function_object["parameters"] = _build_vertex_schema(
+                        _openai_function_object["parameters"]
+                    )
+
+                openai_function_object = _openai_function_object
+
             elif "name" in tool:  # functions list
                 openai_function_object = ChatCompletionToolParamFunctionChunk(**tool)  # type: ignore

@@ -15,10 +15,6 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.anthropic import (
-    AnthropicMessagesTool,
-    AnthropicMessagesToolChoice,
-)
 from litellm.types.llms.openai import (
     ChatCompletionToolParam,
     ChatCompletionToolParamFunctionChunk,
@@ -26,6 +22,7 @@ from litellm.types.llms.openai import (
 from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

+from ....anthropic.chat.transformation import AnthropicConfig
 from ....prompt_templates.factory import (
     construct_tool_use_system_prompt,
     contains_tag,
@@ -50,7 +47,7 @@
         )  # Call the base class constructor with the parameters it needs


-class VertexAIAnthropicConfig:
+class VertexAIAnthropicConfig(AnthropicConfig):
     """
     Reference:https://docs.anthropic.com/claude/reference/messages_post
@@ -72,112 +69,6 @@
     Note: Please make sure to modify the default parameters as required for your use case.
     """

-    max_tokens: Optional[int] = (
-        4096  # anthropic max - setting this doesn't impact response, but is required by anthropic.
-    )
-    system: Optional[str] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    top_k: Optional[int] = None
-    stop_sequences: Optional[List[str]] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-        anthropic_version: Optional[str] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key == "max_tokens" and value is None:
-                value = self.max_tokens
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return [
-            "max_tokens",
-            "max_completion_tokens",
-            "tools",
-            "tool_choice",
-            "stream",
-            "stop",
-            "temperature",
-            "top_p",
-            "response_format",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            if param == "max_tokens" or param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            if param == "tools":
-                optional_params["tools"] = value
-            if param == "tool_choice":
-                _tool_choice: Optional[AnthropicMessagesToolChoice] = None
-                if value == "auto":
-                    _tool_choice = {"type": "auto"}
-                elif value == "required":
-                    _tool_choice = {"type": "any"}
-                elif isinstance(value, dict):
-                    _tool_choice = {"type": "tool", "name": value["function"]["name"]}
-
-                if _tool_choice is not None:
-                    optional_params["tool_choice"] = _tool_choice
-            if param == "stream":
-                optional_params["stream"] = value
-            if param == "stop":
-                optional_params["stop_sequences"] = value
-            if param == "temperature":
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if param == "response_format" and isinstance(value, dict):
-                json_schema: Optional[dict] = None
-                if "response_schema" in value:
-                    json_schema = value["response_schema"]
-                elif "json_schema" in value:
-                    json_schema = value["json_schema"]["schema"]
-                """
-                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
-                - You usually want to provide a single tool
-                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
-                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
-                """
-                _tool_choice = None
-                _tool_choice = {"name": "json_tool_call", "type": "tool"}
-
-                _tool = AnthropicMessagesTool(
-                    name="json_tool_call",
-                    input_schema={
-                        "type": "object",
-                        "properties": {"values": json_schema},  # type: ignore
-                    },
-                )
-
-                optional_params["tools"] = [_tool]
-                optional_params["tool_choice"] = _tool_choice
-                optional_params["json_mode"] = True
-
-        return optional_params
-
     @classmethod
     def is_supported_model(
         cls, model: str, custom_llm_provider: Optional[str] = None

@@ -3377,6 +3377,9 @@ def embedding(  # noqa: PLR0915
             "azure_ad_token", None
         ) or get_secret_str("AZURE_AD_TOKEN")

+        if extra_headers is not None:
+            optional_params["extra_headers"] = extra_headers
+
         api_key = (
             api_key
             or litellm.api_key
@@ -4458,7 +4461,10 @@ def image_generation(  # noqa: PLR0915
     metadata = kwargs.get("metadata", {})
     litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
     client = kwargs.get("client", None)
+    extra_headers = kwargs.get("extra_headers", None)
+    headers: dict = kwargs.get("headers", None) or {}
+    if extra_headers is not None:
+        headers.update(extra_headers)
     model_response: ImageResponse = litellm.utils.ImageResponse()
     if model is not None or custom_llm_provider is not None:
         model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
@@ -4589,6 +4595,14 @@
             "azure_ad_token", None
         ) or get_secret_str("AZURE_AD_TOKEN")

+        default_headers = {
+            "Content-Type": "application/json;",
+            "api-key": api_key,
+        }
+        for k, v in default_headers.items():
+            if k not in headers:
+                headers[k] = v
+
         model_response = azure_chat_completions.image_generation(
             model=model,
             prompt=prompt,
@@ -4601,6 +4615,7 @@
             api_version=api_version,
             aimg_generation=aimg_generation,
             client=client,
+            headers=headers,
         )
     elif custom_llm_provider == "openai":
         model_response = openai_chat_completions.image_generation(
@@ -4797,11 +4812,7 @@
     """
     atranscription = kwargs.get("atranscription", False)
    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
-    kwargs.get("litellm_call_id", None)
-    kwargs.get("logger_fn", None)
-    kwargs.get("proxy_server_request", None)
-    kwargs.get("model_info", None)
-    kwargs.get("metadata", {})
+    extra_headers = kwargs.get("extra_headers", None)
     kwargs.pop("tags", [])

     drop_params = kwargs.get("drop_params", None)
@@ -4857,6 +4868,8 @@
             or get_secret_str("AZURE_API_KEY")
         )

+        optional_params["extra_headers"] = extra_headers
+
         response = azure_audio_transcriptions.audio_transcriptions(
             model=model,
             audio_file=file,
@@ -4975,6 +4988,7 @@
     user = kwargs.get("user", None)
     litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
     proxy_server_request = kwargs.get("proxy_server_request", None)
+    extra_headers = kwargs.get("extra_headers", None)
     model_info = kwargs.get("model_info", None)
     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
     kwargs.pop("tags", [])
@@ -5087,7 +5101,8 @@
             "AZURE_AD_TOKEN"
         )

-        headers = headers or litellm.headers
+        if extra_headers:
+            optional_params["extra_headers"] = extra_headers

         response = azure_chat_completions.audio_speech(
             model=model,

@@ -774,6 +774,20 @@
         "supports_vision": true,
         "supports_prompt_caching": true
     },
+    "azure/gpt-4o-mini-2024-07-18": {
+        "max_tokens": 16384,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 16384,
+        "input_cost_per_token": 0.000000165,
+        "output_cost_per_token": 0.00000066,
+        "cache_read_input_token_cost": 0.000000075,
+        "litellm_provider": "azure",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_parallel_function_calling": true,
+        "supports_vision": true,
+        "supports_prompt_caching": true
+    },
     "azure/gpt-4-turbo-2024-04-09": {
         "max_tokens": 4096,
         "max_input_tokens": 128000,

@@ -81,7 +81,10 @@ from litellm.router_utils.fallback_event_handlers import (
     run_async_fallback,
     run_sync_fallback,
 )
-from litellm.router_utils.handle_error import send_llm_exception_alert
+from litellm.router_utils.handle_error import (
+    async_raise_no_deployment_exception,
+    send_llm_exception_alert,
+)
 from litellm.router_utils.router_callbacks.track_deployment_metrics import (
     increment_deployment_failures_for_current_minute,
     increment_deployment_successes_for_current_minute,
@@ -5183,21 +5186,12 @@
             )

             if len(healthy_deployments) == 0:
-                if _allowed_model_region is None:
-                    _allowed_model_region = "n/a"
-                model_ids = self.get_model_ids(model_name=model)
-                _cooldown_time = self.cooldown_cache.get_min_cooldown(
-                    model_ids=model_ids, parent_otel_span=parent_otel_span
-                )
-                _cooldown_list = _get_cooldown_deployments(
-                    litellm_router_instance=self, parent_otel_span=parent_otel_span
-                )
-                raise RouterRateLimitError(
+                exception = await async_raise_no_deployment_exception(
+                    litellm_router_instance=self,
                     model=model,
-                    cooldown_time=_cooldown_time,
-                    enable_pre_call_checks=self.enable_pre_call_checks,
-                    cooldown_list=_cooldown_list,
+                    parent_otel_span=parent_otel_span,
                 )
+                raise exception
             start_time = time.time()
             if (
                 self.routing_strategy == "usage-based-routing-v2"
@@ -5255,22 +5249,12 @@
         else:
             deployment = None

         if deployment is None:
-            verbose_router_logger.info(
-                f"get_available_deployment for model: {model}, No deployment available"
-            )
-            model_ids = self.get_model_ids(model_name=model)
-            _cooldown_time = self.cooldown_cache.get_min_cooldown(
-                model_ids=model_ids, parent_otel_span=parent_otel_span
-            )
-            _cooldown_list = await _async_get_cooldown_deployments(
-                litellm_router_instance=self, parent_otel_span=parent_otel_span
-            )
-            raise RouterRateLimitError(
+            exception = await async_raise_no_deployment_exception(
+                litellm_router_instance=self,
                 model=model,
-                cooldown_time=_cooldown_time,
-                enable_pre_call_checks=self.enable_pre_call_checks,
-                cooldown_list=_cooldown_list,
+                parent_otel_span=parent_otel_span,
             )
+            raise exception
         verbose_router_logger.info(
             f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
         )

@@ -17,13 +17,6 @@ if TYPE_CHECKING:
 else:
     Span = Any

-if TYPE_CHECKING:
-    from opentelemetry.trace import Span as _Span
-
-    Span = _Span
-else:
-    Span = Any
-

 class CooldownCacheValue(TypedDict):
     exception_received: str
@@ -117,7 +110,6 @@
         if results is None:
            return active_cooldowns

         # Process the results
         for model_id, result in zip(model_ids, results):
             if result and isinstance(result, dict):

@@ -1,15 +1,22 @@
 import asyncio
 import traceback
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional

+from litellm._logging import verbose_router_logger
+from litellm.router_utils.cooldown_handlers import _async_get_cooldown_deployments
 from litellm.types.integrations.slack_alerting import AlertType
+from litellm.types.router import RouterRateLimitError

 if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
     from litellm.router import Router as _Router

     LitellmRouter = _Router
+    Span = _Span
 else:
     LitellmRouter = Any
+    Span = Any


 async def send_llm_exception_alert(
@@ -55,3 +62,28 @@ async def send_llm_exception_alert(
             alert_type=AlertType.llm_exceptions,
             alerting_metadata={},
         )
+
+
+async def async_raise_no_deployment_exception(
+    litellm_router_instance: LitellmRouter, model: str, parent_otel_span: Optional[Span]
+):
+    """
+    Raises a RouterRateLimitError if no deployment is found for the given model.
+    """
+    verbose_router_logger.info(
+        f"get_available_deployment for model: {model}, No deployment available"
+    )
+    model_ids = litellm_router_instance.get_model_ids(model_name=model)
+    _cooldown_time = litellm_router_instance.cooldown_cache.get_min_cooldown(
+        model_ids=model_ids, parent_otel_span=parent_otel_span
+    )
+    _cooldown_list = await _async_get_cooldown_deployments(
+        litellm_router_instance=litellm_router_instance,
+        parent_otel_span=parent_otel_span,
+    )
+
+    return RouterRateLimitError(
+        model=model,
+        cooldown_time=_cooldown_time,
+        enable_pre_call_checks=litellm_router_instance.enable_pre_call_checks,
+        cooldown_list=_cooldown_list,
+    )

@@ -9,12 +9,35 @@ from .openai import ChatCompletionCachedContent
 class AnthropicMessagesToolChoice(TypedDict, total=False):
     type: Required[Literal["auto", "any", "tool"]]
     name: str
+    disable_parallel_tool_use: bool  # default is false


 class AnthropicMessagesTool(TypedDict, total=False):
     name: Required[str]
     description: str
     input_schema: Required[dict]
+    type: Literal["custom"]
+    cache_control: Optional[Union[dict, ChatCompletionCachedContent]]
+
+
+class AnthropicComputerTool(TypedDict, total=False):
+    display_width_px: Required[int]
+    display_height_px: Required[int]
+    display_number: int
+    cache_control: Optional[Union[dict, ChatCompletionCachedContent]]
+    type: Required[str]
+    name: Required[str]
+
+
+class AnthropicHostedTools(TypedDict, total=False):  # for bash_tool and text_editor
+    type: Required[str]
+    name: Required[str]
+    cache_control: Optional[Union[dict, ChatCompletionCachedContent]]
+
+
+AllAnthropicToolsValues = Union[
+    AnthropicComputerTool, AnthropicHostedTools, AnthropicMessagesTool
+]


 class AnthropicMessagesTextParam(TypedDict, total=False):
@@ -117,7 +140,7 @@ class AnthropicMessageRequestBase(TypedDict, total=False):
     system: Union[str, List]
     temperature: float
     tool_choice: AnthropicMessagesToolChoice
-    tools: List[AnthropicMessagesTool]
+    tools: List[AllAnthropicToolsValues]
     top_k: int
     top_p: float

@@ -440,11 +440,15 @@ class ChatCompletionToolParamFunctionChunk(TypedDict, total=False):
     parameters: dict


-class ChatCompletionToolParam(TypedDict):
-    type: Literal["function"]
+class OpenAIChatCompletionToolParam(TypedDict):
+    type: Union[Literal["function"], str]
     function: ChatCompletionToolParamFunctionChunk


+class ChatCompletionToolParam(OpenAIChatCompletionToolParam, total=False):
+    cache_control: ChatCompletionCachedContent
+
+
 class Function(TypedDict, total=False):
     name: Required[str]
     """The name of the function to call."""

@@ -527,3 +527,98 @@
     result = process_anthropic_headers(input_headers)
     assert result == expected_output, "Unexpected output for non-matching headers"
+
+
+def test_anthropic_computer_tool_use():
+    from litellm import completion
+
+    tools = [
+        {
+            "type": "computer_20241022",
+            "function": {
+                "name": "computer",
+                "parameters": {
+                    "display_height_px": 100,
+                    "display_width_px": 100,
+                    "display_number": 1,
+                },
+            },
+        }
+    ]
+    model = "claude-3-5-sonnet-20241022"
+    messages = [{"role": "user", "content": "Save a picture of a cat to my desktop."}]
+
+    resp = completion(
+        model=model,
+        messages=messages,
+        tools=tools,
+        # headers={"anthropic-beta": "computer-use-2024-10-22"},
+    )
+
+    print(resp)
+
+
+@pytest.mark.parametrize(
+    "computer_tool_used, prompt_caching_set, expected_beta_header",
+    [
+        (True, False, True),
+        (False, True, True),
+        (True, True, True),
+        (False, False, False),
+    ],
+)
+def test_anthropic_beta_header(
+    computer_tool_used, prompt_caching_set, expected_beta_header
+):
+    headers = litellm.AnthropicConfig().get_anthropic_headers(
+        api_key="fake-api-key",
+        computer_tool_used=computer_tool_used,
+        prompt_caching_set=prompt_caching_set,
+    )
+
+    if expected_beta_header:
+        assert "anthropic-beta" in headers
+    else:
+        assert "anthropic-beta" not in headers
+
+
+@pytest.mark.parametrize(
+    "cache_control_location",
+    [
+        "inside_function",
+        "outside_function",
+    ],
+)
+def test_anthropic_tool_helper(cache_control_location):
+    from litellm.llms.anthropic.chat.transformation import AnthropicConfig
+
+    tool = {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "unit": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                },
+                "required": ["location"],
+            },
+        },
+    }
+
+    if cache_control_location == "inside_function":
+        tool["function"]["cache_control"] = {"type": "ephemeral"}
+    else:
+        tool["cache_control"] = {"type": "ephemeral"}
+
+    tool = AnthropicConfig()._map_tool_helper(tool=tool)
+
+    assert tool["cache_control"] == {"type": "ephemeral"}

@@ -96,6 +96,66 @@ def test_process_azure_headers_with_dict_input():
     assert result == expected_output, "Unexpected output for dict input"

+
+from httpx import Client
+from unittest.mock import MagicMock, patch
+from openai import AzureOpenAI
+import litellm
+from litellm import completion
+import os
+
+
+@pytest.mark.parametrize(
+    "input, call_type",
+    [
+        ({"messages": [{"role": "user", "content": "Hello world"}]}, "completion"),
+        ({"input": "Hello world"}, "embedding"),
+        ({"prompt": "Hello world"}, "image_generation"),
+    ],
+)
+def test_azure_extra_headers(input, call_type):
+    from litellm import embedding, image_generation
+
+    http_client = Client()
+
+    messages = [{"role": "user", "content": "Hello world"}]
+    with patch.object(http_client, "send", new=MagicMock()) as mock_client:
+        litellm.client_session = http_client
+        try:
+            if call_type == "completion":
+                func = completion
+            elif call_type == "embedding":
+                func = embedding
+            elif call_type == "image_generation":
+                func = image_generation
+
+            response = func(
+                model="azure/chatgpt-v-2",
+                api_base="https://openai-gpt-4-test-v-1.openai.azure.com",
+                api_version="2023-07-01-preview",
+                api_key="my-azure-api-key",
+                extra_headers={
+                    "Authorization": "my-bad-key",
+                    "Ocp-Apim-Subscription-Key": "hello-world-testing",
+                },
+                **input,
+            )
+            print(response)
+        except Exception as e:
+            print(e)
+
+        mock_client.assert_called()
+
+        print(f"mock_client.call_args: {mock_client.call_args}")
+        request = mock_client.call_args[0][0]
+        print(request.method)  # This will print 'POST'
+        print(request.url)  # This will print the full URL
+        print(request.headers)  # This will print the full URL
+        auth_header = request.headers.get("Authorization")
+        apim_key = request.headers.get("Ocp-Apim-Subscription-Key")
+        print(auth_header)
+        assert auth_header == "my-bad-key"
+        assert apim_key == "hello-world-testing"
+
+
 @pytest.mark.parametrize(
     "api_base, model, expected_endpoint",
     [

@@ -786,19 +786,122 @@
     assert "max_retries" not in optional_params


-@pytest.mark.parametrize(
-    "tools, key",
-    [
-        ([{"googleSearchRetrieval": {}}], "googleSearchRetrieval"),
-        ([{"code_execution": {}}], "code_execution"),
-    ],
-)
-def test_vertex_tool_params(tools, key):
+@pytest.mark.parametrize("provider", ["anthropic", "vertex_ai"])
+def test_anthropic_parallel_tool_calls(provider):
+    optional_params = get_optional_params(
+        model="claude-3-5-sonnet-v250@20241022",
+        custom_llm_provider=provider,
+        parallel_tool_calls=True,
+    )
+    print(f"optional_params: {optional_params}")
+    assert optional_params["tool_choice"]["disable_parallel_tool_use"] is True
+
+
+def test_anthropic_computer_tool_use():
+    tools = [
+        {
+            "type": "computer_20241022",
+            "function": {
+                "name": "computer",
+                "parameters": {
+                    "display_height_px": 100,
+                    "display_width_px": 100,
+                    "display_number": 1,
+                },
+            },
+        }
+    ]
+
     optional_params = get_optional_params(
-        model="gemini-1.5-pro",
+        model="claude-3-5-sonnet-v250@20241022",
+        custom_llm_provider="anthropic",
+        tools=tools,
+    )
+    assert optional_params["tools"][0]["type"] == "computer_20241022"
+    assert optional_params["tools"][0]["display_height_px"] == 100
+    assert optional_params["tools"][0]["display_width_px"] == 100
+    assert optional_params["tools"][0]["display_number"] == 1
+
+
+def test_vertex_schema_field():
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "json",
+                "description": "Respond with a JSON object.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "thinking": {
+                            "type": "string",
+                            "description": "Your internal thoughts on different problem details given the guidance.",
+                        },
+                        "problems": {
+                            "type": "array",
+                            "items": {
+                                "type": "object",
+                                "properties": {
+                                    "icon": {
+                                        "type": "string",
+                                        "enum": [
+                                            "BarChart2",
+                                            "Bell",
+                                        ],
+                                        "description": "The name of a Lucide icon to display",
+                                    },
+                                    "color": {
+                                        "type": "string",
+                                        "description": "A Tailwind color class for the icon, e.g., 'text-red-500'",
+                                    },
+                                    "problem": {
+                                        "type": "string",
+                                        "description": "The title of the problem being addressed, approximately 3-5 words.",
+                                    },
+                                    "description": {
+                                        "type": "string",
+                                        "description": "A brief explanation of the problem, approximately 20 words.",
+                                    },
+                                    "impacts": {
+                                        "type": "array",
+                                        "items": {"type": "string"},
+                                        "description": "A list of potential impacts or consequences of the problem, approximately 3 words each.",
+                                    },
+                                    "automations": {
+                                        "type": "array",
+                                        "items": {"type": "string"},
+                                        "description": "A list of potential automations to address the problem, approximately 3-5 words each.",
+                                    },
+                                },
+                                "required": [
+                                    "icon",
+                                    "color",
+                                    "problem",
+                                    "description",
+                                    "impacts",
+                                    "automations",
+                                ],
+                                "additionalProperties": False,
+                            },
+                            "description": "Please generate problem cards that match this guidance.",
+                        },
+                    },
+                    "required": ["thinking", "problems"],
+                    "additionalProperties": False,
+                    "$schema": "http://json-schema.org/draft-07/schema#",
+                },
+            },
+        }
+    ]
+
+    optional_params = get_optional_params(
+        model="gemini-1.5-flash",
         custom_llm_provider="vertex_ai",
         tools=tools,
     )
     print(optional_params)
-    assert optional_params["tools"][0][key] == {}
+    print(optional_params["tools"][0]["function_declarations"][0])
+    assert (
+        "$schema"
+        not in optional_params["tools"][0]["function_declarations"][0]["parameters"]
+    )