From 6b9be5092f329f9d308106df9333f2c11d605d5b Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Tue, 29 Oct 2024 17:20:24 -0700 Subject: [PATCH] LiteLLM Minor Fixes & Improvements (10/28/2024) (#6475) * fix(anthropic/chat/transformation.py): support anthropic disable_parallel_tool_use param Fixes https://github.com/BerriAI/litellm/issues/6456 * feat(anthropic/chat/transformation.py): support anthropic computer tool use Closes https://github.com/BerriAI/litellm/issues/6427 * fix(vertex_ai/common_utils.py): parse out '$schema' when calling vertex ai Fixes issue when trying to call vertex from vercel sdk * fix(main.py): add 'extra_headers' support for azure on all translation endpoints Fixes https://github.com/BerriAI/litellm/issues/6465 * fix: fix linting errors * fix(transformation.py): handle no beta headers for anthropic * test: cleanup test * fix: fix linting error * fix: fix linting errors * fix: fix linting errors * fix(transformation.py): handle dummy tool call * fix(main.py): fix linting error * fix(azure.py): pass required param * LiteLLM Minor Fixes & Improvements (10/24/2024) (#6441) * fix(azure.py): handle /openai/deployment in azure api base * fix(factory.py): fix faulty anthropic tool result translation check Fixes https://github.com/BerriAI/litellm/issues/6422 * fix(gpt_transformation.py): add support for parallel_tool_calls to azure Fixes https://github.com/BerriAI/litellm/issues/6440 * fix(factory.py): support anthropic prompt caching for tool results * fix(vertex_ai/common_utils): don't pop non-null required field Fixes https://github.com/BerriAI/litellm/issues/6426 * feat(vertex_ai.py): support code_execution tool call for vertex ai + gemini Closes https://github.com/BerriAI/litellm/issues/6434 * build(model_prices_and_context_window.json): Add 'supports_assistant_prefill' for bedrock claude-3-5-sonnet v2 models Closes https://github.com/BerriAI/litellm/issues/6437 * fix(types/utils.py): fix linting * test: update test to include required fields * test: fix test * test: handle flaky test * test: remove e2e test - hitting gemini rate limits * Litellm dev 10 26 2024 (#6472) * docs(exception_mapping.md): add missing exception types Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183 * fix(main.py): register custom model pricing with specific key Ensure custom model pricing is registered to the specific model+provider key combination * test: make testing more robust for custom pricing * fix(redis_cache.py): instrument otel logging for sync redis calls ensures complete coverage for all redis cache calls * (Testing) Add unit testing for DualCache - ensure in memory cache is used when expected (#6471) * test test_dual_cache_get_set * unit testing for dual cache * fix async_set_cache_sadd * test_dual_cache_local_only * redis otel tracing + async support for latency routing (#6452) * docs(exception_mapping.md): add missing exception types Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183 * fix(main.py): register custom model pricing with specific key Ensure custom model pricing is registered to the specific model+provider key combination * test: make testing more robust for custom pricing * fix(redis_cache.py): instrument otel logging for sync redis calls ensures complete coverage for all redis cache calls * refactor: pass parent_otel_span for redis caching calls in router allows for more observability into what calls are causing latency issues * test: update tests with new params * refactor: ensure e2e otel tracing for router * 
refactor(router.py): add more otel tracing across router catch all latency issues for router requests * fix: fix linting error * fix(router.py): fix linting error * fix: fix test * test: fix tests * fix(dual_cache.py): pass ttl to redis cache * fix: fix param * fix(dual_cache.py): set default value for parent_otel_span * fix(transformation.py): support 'response_format' for anthropic calls * fix(transformation.py): check for cache_control inside 'function' block * fix: fix linting error * fix: fix linting errors --------- Co-authored-by: Ishaan Jaff --- docs/my-website/docs/providers/anthropic.md | 31 +++ litellm/caching/dual_cache.py | 4 +- litellm/llms/AzureOpenAI/azure.py | 50 ++-- litellm/llms/anthropic/chat/handler.py | 36 ++- litellm/llms/anthropic/chat/transformation.py | 223 +++++++++++++++--- .../transformation.py | 15 +- .../common_utils.py | 15 +- .../vertex_and_google_ai_studio_gemini.py | 14 +- .../anthropic/transformation.py | 113 +-------- litellm/main.py | 29 ++- ...odel_prices_and_context_window_backup.json | 14 ++ litellm/router.py | 40 +--- litellm/router_utils/cooldown_cache.py | 8 - litellm/router_utils/handle_error.py | 34 ++- litellm/types/llms/anthropic.py | 25 +- litellm/types/llms/openai.py | 8 +- .../test_anthropic_completion.py | 95 ++++++++ tests/llm_translation/test_azure_openai.py | 60 +++++ tests/llm_translation/test_optional_params.py | 123 +++++++++- 19 files changed, 684 insertions(+), 253 deletions(-) diff --git a/docs/my-website/docs/providers/anthropic.md b/docs/my-website/docs/providers/anthropic.md index 89703512f..0c7b2a442 100644 --- a/docs/my-website/docs/providers/anthropic.md +++ b/docs/my-website/docs/providers/anthropic.md @@ -721,6 +721,37 @@ except Exception as e: s/o @[Shekhar Patnaik](https://www.linkedin.com/in/patnaikshekhar) for requesting this!
+### Computer Tools + +```python +from litellm import completion + +tools = [ + { + "type": "computer_20241022", + "function": { + "name": "computer", + "parameters": { + "display_height_px": 100, + "display_width_px": 100, + "display_number": 1, + }, + }, + } +] +model = "claude-3-5-sonnet-20241022" +messages = [{"role": "user", "content": "Save a picture of a cat to my desktop."}] + +resp = completion( + model=model, + messages=messages, + tools=tools, + # headers={"anthropic-beta": "computer-use-2024-10-22"}, +) + +print(resp) +``` + ## Usage - Vision ```python diff --git a/litellm/caching/dual_cache.py b/litellm/caching/dual_cache.py index 1bf16bb65..ef168f65f 100644 --- a/litellm/caching/dual_cache.py +++ b/litellm/caching/dual_cache.py @@ -152,7 +152,7 @@ class DualCache(BaseCache): def batch_get_cache( self, keys: list, - parent_otel_span: Optional[Span], + parent_otel_span: Optional[Span] = None, local_only: bool = False, **kwargs, ): @@ -343,7 +343,7 @@ class DualCache(BaseCache): self, key, value: float, - parent_otel_span: Optional[Span], + parent_otel_span: Optional[Span] = None, local_only: bool = False, **kwargs, ) -> float: diff --git a/litellm/llms/AzureOpenAI/azure.py b/litellm/llms/AzureOpenAI/azure.py index 74823596b..39dea14e2 100644 --- a/litellm/llms/AzureOpenAI/azure.py +++ b/litellm/llms/AzureOpenAI/azure.py @@ -961,6 +961,7 @@ class AzureChatCompletion(BaseLLM): api_version: str, api_key: str, data: dict, + headers: dict, ) -> httpx.Response: """ Implemented for azure dall-e-2 image gen calls @@ -1002,10 +1003,7 @@ class AzureChatCompletion(BaseLLM): response = await async_handler.post( url=api_base, data=json.dumps(data), - headers={ - "Content-Type": "application/json", - "api-key": api_key, - }, + headers=headers, ) if "operation-location" in response.headers: operation_location_url = response.headers["operation-location"] @@ -1013,9 +1011,7 @@ class AzureChatCompletion(BaseLLM): raise AzureOpenAIError(status_code=500, message=response.text) response = await async_handler.get( url=operation_location_url, - headers={ - "api-key": api_key, - }, + headers=headers, ) await response.aread() @@ -1036,9 +1032,7 @@ class AzureChatCompletion(BaseLLM): await asyncio.sleep(int(response.headers.get("retry-after") or 10)) response = await async_handler.get( url=operation_location_url, - headers={ - "api-key": api_key, - }, + headers=headers, ) await response.aread() @@ -1056,10 +1050,7 @@ class AzureChatCompletion(BaseLLM): return await async_handler.post( url=api_base, json=data, - headers={ - "Content-Type": "application/json;", - "api-key": api_key, - }, + headers=headers, ) def make_sync_azure_httpx_request( @@ -1070,6 +1061,7 @@ class AzureChatCompletion(BaseLLM): api_version: str, api_key: str, data: dict, + headers: dict, ) -> httpx.Response: """ Implemented for azure dall-e-2 image gen calls @@ -1085,7 +1077,7 @@ class AzureChatCompletion(BaseLLM): else: _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) - sync_handler = HTTPHandler(**_params) # type: ignore + sync_handler = HTTPHandler(**_params, client=litellm.client_session) # type: ignore else: sync_handler = client # type: ignore @@ -1111,10 +1103,7 @@ class AzureChatCompletion(BaseLLM): response = sync_handler.post( url=api_base, data=json.dumps(data), - headers={ - "Content-Type": "application/json", - "api-key": api_key, - }, + headers=headers, ) if "operation-location" in response.headers: operation_location_url = response.headers["operation-location"] @@ -1122,9 +1111,7 @@ class 
AzureChatCompletion(BaseLLM): raise AzureOpenAIError(status_code=500, message=response.text) response = sync_handler.get( url=operation_location_url, - headers={ - "api-key": api_key, - }, + headers=headers, ) response.read() @@ -1144,9 +1131,7 @@ class AzureChatCompletion(BaseLLM): time.sleep(int(response.headers.get("retry-after") or 10)) response = sync_handler.get( url=operation_location_url, - headers={ - "api-key": api_key, - }, + headers=headers, ) response.read() @@ -1164,10 +1149,7 @@ class AzureChatCompletion(BaseLLM): return sync_handler.post( url=api_base, json=data, - headers={ - "Content-Type": "application/json;", - "api-key": api_key, - }, + headers=headers, ) def create_azure_base_url( @@ -1200,6 +1182,7 @@ class AzureChatCompletion(BaseLLM): api_key: str, input: list, logging_obj: LiteLLMLoggingObj, + headers: dict, client=None, timeout=None, ) -> litellm.ImageResponse: @@ -1223,7 +1206,7 @@ class AzureChatCompletion(BaseLLM): additional_args={ "complete_input_dict": data, "api_base": img_gen_api_base, - "headers": {"api_key": api_key}, + "headers": headers, }, ) httpx_response: httpx.Response = await self.make_async_azure_httpx_request( @@ -1233,6 +1216,7 @@ class AzureChatCompletion(BaseLLM): api_version=api_version, api_key=api_key, data=data, + headers=headers, ) response = httpx_response.json() @@ -1265,6 +1249,7 @@ class AzureChatCompletion(BaseLLM): timeout: float, optional_params: dict, logging_obj: LiteLLMLoggingObj, + headers: dict, model: Optional[str] = None, api_key: Optional[str] = None, api_base: Optional[str] = None, @@ -1315,7 +1300,7 @@ class AzureChatCompletion(BaseLLM): azure_client_params["azure_ad_token"] = azure_ad_token if aimg_generation is True: - return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore + return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore img_gen_api_base = self.create_azure_base_url( azure_client_params=azure_client_params, model=data.get("model", "") @@ -1328,7 +1313,7 @@ class AzureChatCompletion(BaseLLM): additional_args={ "complete_input_dict": data, "api_base": img_gen_api_base, - "headers": {"api_key": api_key}, + "headers": headers, }, ) httpx_response: httpx.Response = self.make_sync_azure_httpx_request( @@ -1338,6 +1323,7 @@ class AzureChatCompletion(BaseLLM): api_version=api_version or "", api_key=api_key or "", data=data, + headers=headers, ) response = httpx_response.json() diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 03068537b..a30cd6570 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -29,6 +29,7 @@ from litellm.llms.custom_httpx.http_handler import ( get_async_httpx_client, ) from litellm.types.llms.anthropic import ( + AllAnthropicToolsValues, AnthropicChatCompletionUsageBlock, ContentBlockDelta, ContentBlockStart, @@ -53,9 +54,14 @@ from .transformation import AnthropicConfig # makes headers for API call def validate_environment( - api_key, user_headers, model, messages: List[AllMessageValues] + api_key, + user_headers, + model, + messages: List[AllMessageValues], + tools: Optional[List[AllAnthropicToolsValues]], + anthropic_version: Optional[str] = None, ): - cache_headers = 
{} + if api_key is None: raise litellm.AuthenticationError( message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars", @@ -63,17 +69,15 @@ def validate_environment( model=model, ) - if AnthropicConfig().is_cache_control_set(messages=messages): - cache_headers = AnthropicConfig().get_cache_control_headers() + prompt_caching_set = AnthropicConfig().is_cache_control_set(messages=messages) + computer_tool_used = AnthropicConfig().is_computer_tool_used(tools=tools) - headers = { - "accept": "application/json", - "anthropic-version": "2023-06-01", - "content-type": "application/json", - "x-api-key": api_key, - } - - headers.update(cache_headers) + headers = AnthropicConfig().get_anthropic_headers( + anthropic_version=anthropic_version, + computer_tool_used=computer_tool_used, + prompt_caching_set=prompt_caching_set, + api_key=api_key, + ) if user_headers is not None and isinstance(user_headers, dict): headers = {**headers, **user_headers} @@ -441,7 +445,13 @@ class AnthropicChatCompletion(BaseLLM): headers={}, client=None, ): - headers = validate_environment(api_key, headers, model, messages=messages) + headers = validate_environment( + api_key, + headers, + model, + messages=messages, + tools=optional_params.get("tools"), + ) _is_function_call = False messages = copy.deepcopy(messages) optional_params = copy.deepcopy(optional_params) diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index b5d9a1aa6..4953a393a 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -4,6 +4,9 @@ from typing import List, Literal, Optional, Tuple, Union import litellm from litellm.llms.prompt_templates.factory import anthropic_messages_pt from litellm.types.llms.anthropic import ( + AllAnthropicToolsValues, + AnthropicComputerTool, + AnthropicHostedTools, AnthropicMessageRequestBase, AnthropicMessagesRequest, AnthropicMessagesTool, @@ -12,6 +15,7 @@ from litellm.types.llms.anthropic import ( ) from litellm.types.llms.openai import ( AllMessageValues, + ChatCompletionCachedContent, ChatCompletionSystemMessage, ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk, @@ -84,6 +88,8 @@ class AnthropicConfig: "tools", "tool_choice", "extra_headers", + "parallel_tool_calls", + "response_format", ] def get_cache_control_headers(self) -> dict: @@ -92,6 +98,146 @@ class AnthropicConfig: "anthropic-beta": "prompt-caching-2024-07-31", } + def get_anthropic_headers( + self, + api_key: str, + anthropic_version: Optional[str] = None, + computer_tool_used: bool = False, + prompt_caching_set: bool = False, + ) -> dict: + import json + + betas = [] + if prompt_caching_set: + betas.append("prompt-caching-2024-07-31") + if computer_tool_used: + betas.append("computer-use-2024-10-22") + headers = { + "anthropic-version": anthropic_version or "2023-06-01", + "x-api-key": api_key, + "accept": "application/json", + "content-type": "application/json", + } + if len(betas) > 0: + headers["anthropic-beta"] = ",".join(betas) + return headers + + def _map_tool_choice( + self, tool_choice: Optional[str], disable_parallel_tool_use: Optional[bool] + ) -> Optional[AnthropicMessagesToolChoice]: + _tool_choice: Optional[AnthropicMessagesToolChoice] = None + if tool_choice == "auto": + _tool_choice = AnthropicMessagesToolChoice( + type="auto", + ) + elif tool_choice == "required": + 
_tool_choice = AnthropicMessagesToolChoice(type="any") + elif isinstance(tool_choice, dict): + _tool_name = tool_choice.get("function", {}).get("name") + _tool_choice = AnthropicMessagesToolChoice(type="tool") + if _tool_name is not None: + _tool_choice["name"] = _tool_name + + if disable_parallel_tool_use is not None: + if _tool_choice is not None: + _tool_choice["disable_parallel_tool_use"] = disable_parallel_tool_use + else: # use anthropic defaults and make sure to send the disable_parallel_tool_use flag + _tool_choice = AnthropicMessagesToolChoice( + type="auto", + disable_parallel_tool_use=disable_parallel_tool_use, + ) + return _tool_choice + + def _map_tool_helper( + self, tool: ChatCompletionToolParam + ) -> AllAnthropicToolsValues: + returned_tool: Optional[AllAnthropicToolsValues] = None + + if tool["type"] == "function" or tool["type"] == "custom": + _tool = AnthropicMessagesTool( + name=tool["function"]["name"], + input_schema=tool["function"].get( + "parameters", + { + "type": "object", + "properties": {}, + }, + ), + ) + + _description = tool["function"].get("description") + if _description is not None: + _tool["description"] = _description + + returned_tool = _tool + + elif tool["type"].startswith("computer_"): + ## check if all required 'display_' params are given + if "parameters" not in tool["function"]: + raise ValueError("Missing required parameter: parameters") + + _display_width_px: Optional[int] = tool["function"]["parameters"].get( + "display_width_px" + ) + _display_height_px: Optional[int] = tool["function"]["parameters"].get( + "display_height_px" + ) + if _display_width_px is None or _display_height_px is None: + raise ValueError( + "Missing required parameter: display_width_px or display_height_px" + ) + + _computer_tool = AnthropicComputerTool( + type=tool["type"], + name=tool["function"].get("name", "computer"), + display_width_px=_display_width_px, + display_height_px=_display_height_px, + ) + + _display_number = tool["function"]["parameters"].get("display_number") + if _display_number is not None: + _computer_tool["display_number"] = _display_number + + returned_tool = _computer_tool + elif tool["type"].startswith("bash_") or tool["type"].startswith( + "text_editor_" + ): + function_name = tool["function"].get("name") + if function_name is None: + raise ValueError("Missing required parameter: name") + + returned_tool = AnthropicHostedTools( + type=tool["type"], + name=function_name, + ) + if returned_tool is None: + raise ValueError(f"Unsupported tool type: {tool['type']}") + + ## check if cache_control is set in the tool + _cache_control = tool.get("cache_control", None) + _cache_control_function = tool.get("function", {}).get("cache_control", None) + if _cache_control is not None: + returned_tool["cache_control"] = _cache_control + elif _cache_control_function is not None and isinstance( + _cache_control_function, dict + ): + returned_tool["cache_control"] = ChatCompletionCachedContent( + **_cache_control_function # type: ignore + ) + + return returned_tool + + def _map_tools(self, tools: List) -> List[AllAnthropicToolsValues]: + anthropic_tools = [] + for tool in tools: + if "input_schema" in tool: # assume in anthropic format + anthropic_tools.append(tool) + else: # assume openai tool call + new_tool = self._map_tool_helper(tool) + + anthropic_tools.append(new_tool) + return anthropic_tools + def map_openai_params( self, non_default_params: dict, @@ -104,15 +250,16 @@ class AnthropicConfig: if param == "max_completion_tokens": 
optional_params["max_tokens"] = value if param == "tools": - optional_params["tools"] = value - if param == "tool_choice": - _tool_choice: Optional[AnthropicMessagesToolChoice] = None - if value == "auto": - _tool_choice = {"type": "auto"} - elif value == "required": - _tool_choice = {"type": "any"} - elif isinstance(value, dict): - _tool_choice = {"type": "tool", "name": value["function"]["name"]} + optional_params["tools"] = self._map_tools(value) + if param == "tool_choice" or param == "parallel_tool_calls": + _tool_choice: Optional[AnthropicMessagesToolChoice] = ( + self._map_tool_choice( + tool_choice=non_default_params.get("tool_choice"), + disable_parallel_tool_use=non_default_params.get( + "parallel_tool_calls" + ), + ) + ) if _tool_choice is not None: optional_params["tool_choice"] = _tool_choice @@ -142,6 +289,32 @@ class AnthropicConfig: optional_params["temperature"] = value if param == "top_p": optional_params["top_p"] = value + if param == "response_format" and isinstance(value, dict): + json_schema: Optional[dict] = None + if "response_schema" in value: + json_schema = value["response_schema"] + elif "json_schema" in value: + json_schema = value["json_schema"]["schema"] + """ + When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode + - You usually want to provide a single tool + - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool + - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective. + """ + _tool_choice = None + _tool_choice = {"name": "json_tool_call", "type": "tool"} + + _tool = AnthropicMessagesTool( + name="json_tool_call", + input_schema={ + "type": "object", + "properties": {"values": json_schema}, # type: ignore + }, + ) + + optional_params["tools"] = [_tool] + optional_params["tool_choice"] = _tool_choice + optional_params["json_mode"] = True ## VALIDATE REQUEST """ @@ -153,8 +326,8 @@ class AnthropicConfig: and has_tool_call_blocks(messages) ): if litellm.modify_params: - optional_params["tools"] = add_dummy_tool( - custom_llm_provider="bedrock_converse" + optional_params["tools"] = self._map_tools( + add_dummy_tool(custom_llm_provider="anthropic") ) else: raise litellm.UnsupportedParamsError( @@ -182,6 +355,16 @@ class AnthropicConfig: return False + def is_computer_tool_used( + self, tools: Optional[List[AllAnthropicToolsValues]] + ) -> bool: + if tools is None: + return False + for tool in tools: + if "type" in tool and tool["type"].startswith("computer_"): + return True + return False + def translate_system_message( self, messages: List[AllMessageValues] ) -> List[AnthropicSystemMessageContent]: @@ -276,24 +459,6 @@ class AnthropicConfig: ## Handle Tool Calling if "tools" in optional_params: _is_function_call = True - anthropic_tools = [] - for tool in optional_params["tools"]: - if "input_schema" in tool: # assume in anthropic format - anthropic_tools.append(tool) - else: # assume openai tool call - new_tool = tool["function"] - parameters = new_tool.pop( - "parameters", - { - "type": "object", - "properties": {}, - }, - ) - new_tool["input_schema"] = parameters # rename key - if "cache_control" in tool: - new_tool["cache_control"] = tool["cache_control"] - anthropic_tools.append(new_tool) - optional_params["tools"] = anthropic_tools data = { "messages": anthropic_messages, diff --git a/litellm/llms/anthropic/experimental_pass_through/transformation.py 
b/litellm/llms/anthropic/experimental_pass_through/transformation.py index 0f9a31f83..8d77c40af 100644 --- a/litellm/llms/anthropic/experimental_pass_through/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/transformation.py @@ -6,9 +6,12 @@ from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingCho import litellm from litellm.types.llms.anthropic import ( + AllAnthropicToolsValues, AnthopicMessagesAssistantMessageParam, AnthropicChatCompletionUsageBlock, + AnthropicComputerTool, AnthropicFinishReason, + AnthropicHostedTools, AnthropicMessagesRequest, AnthropicMessagesTool, AnthropicMessagesToolChoice, @@ -215,16 +218,22 @@ class AnthropicExperimentalPassThroughConfig: ) def translate_anthropic_tools_to_openai( - self, tools: List[AnthropicMessagesTool] + self, tools: List[AllAnthropicToolsValues] ) -> List[ChatCompletionToolParam]: new_tools: List[ChatCompletionToolParam] = [] + mapped_tool_params = ["name", "input_schema", "description"] for tool in tools: function_chunk = ChatCompletionToolParamFunctionChunk( name=tool["name"], - parameters=tool["input_schema"], ) + if "input_schema" in tool: + function_chunk["parameters"] = tool["input_schema"] # type: ignore if "description" in tool: - function_chunk["description"] = tool["description"] + function_chunk["description"] = tool["description"] # type: ignore + + for k, v in tool.items(): + if k not in mapped_tool_params: # pass additional computer kwargs + function_chunk.setdefault("parameters", {}).update({k: v}) new_tools.append( ChatCompletionToolParam(type="function", function=function_chunk) ) diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py index 71b2bbc01..0f95b222c 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py @@ -164,7 +164,12 @@ def _build_vertex_schema(parameters: dict): # 4. Suppress unnecessary title generation: # * https://github.com/pydantic/pydantic/issues/1051 # * http://cl/586221780 - strip_titles(parameters) + strip_field(parameters, field_name="title") + + strip_field( + parameters, field_name="$schema" + ) # 5. Remove $schema - json schema value, not supported by OpenAPI - causes vertex errors. 
+ return parameters @@ -245,14 +250,14 @@ def add_object_type(schema): add_object_type(items) -def strip_titles(schema): - schema.pop("title", None) +def strip_field(schema, field_name: str): + schema.pop(field_name, None) properties = schema.get("properties", None) if properties is not None: for name, value in properties.items(): - strip_titles(value) + strip_field(value, field_name) items = schema.get("items", None) if items is not None: - strip_titles(items) + strip_field(items, field_name) diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py index a6e1d782a..914651f3d 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py @@ -400,14 +400,26 @@ class VertexGeminiConfig: value = _remove_additional_properties(value) # remove 'strict' from tools value = _remove_strict_from_schema(value) + for tool in value: openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = ( None ) if "function" in tool: # tools list - openai_function_object = ChatCompletionToolParamFunctionChunk( # type: ignore + _openai_function_object = ChatCompletionToolParamFunctionChunk( # type: ignore **tool["function"] ) + + if ( + "parameters" in _openai_function_object + and _openai_function_object["parameters"] is not None + ): # OPENAI accepts JSON Schema, Google accepts OpenAPI schema. + _openai_function_object["parameters"] = _build_vertex_schema( + _openai_function_object["parameters"] + ) + + openai_function_object = _openai_function_object + elif "name" in tool: # functions list openai_function_object = ChatCompletionToolParamFunctionChunk(**tool) # type: ignore diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/anthropic/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/anthropic/transformation.py index 406314a59..0c3d3965d 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/anthropic/transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/anthropic/transformation.py @@ -15,10 +15,6 @@ import requests # type: ignore import litellm from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.types.llms.anthropic import ( - AnthropicMessagesTool, - AnthropicMessagesToolChoice, -) from litellm.types.llms.openai import ( ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk, @@ -26,6 +22,7 @@ from litellm.types.llms.openai import ( from litellm.types.utils import ResponseFormatChunk from litellm.utils import CustomStreamWrapper, ModelResponse, Usage +from ....anthropic.chat.transformation import AnthropicConfig from ....prompt_templates.factory import ( construct_tool_use_system_prompt, contains_tag, @@ -50,7 +47,7 @@ class VertexAIError(Exception): ) # Call the base class constructor with the parameters it needs -class VertexAIAnthropicConfig: +class VertexAIAnthropicConfig(AnthropicConfig): """ Reference:https://docs.anthropic.com/claude/reference/messages_post @@ -72,112 +69,6 @@ class VertexAIAnthropicConfig: Note: Please make sure to modify the default parameters as required for your use case. 
""" - max_tokens: Optional[int] = ( - 4096 # anthropic max - setting this doesn't impact response, but is required by anthropic. - ) - system: Optional[str] = None - temperature: Optional[float] = None - top_p: Optional[float] = None - top_k: Optional[int] = None - stop_sequences: Optional[List[str]] = None - - def __init__( - self, - max_tokens: Optional[int] = None, - anthropic_version: Optional[str] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key == "max_tokens" and value is None: - value = self.max_tokens - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def get_supported_openai_params(self): - return [ - "max_tokens", - "max_completion_tokens", - "tools", - "tool_choice", - "stream", - "stop", - "temperature", - "top_p", - "response_format", - ] - - def map_openai_params(self, non_default_params: dict, optional_params: dict): - for param, value in non_default_params.items(): - if param == "max_tokens" or param == "max_completion_tokens": - optional_params["max_tokens"] = value - if param == "tools": - optional_params["tools"] = value - if param == "tool_choice": - _tool_choice: Optional[AnthropicMessagesToolChoice] = None - if value == "auto": - _tool_choice = {"type": "auto"} - elif value == "required": - _tool_choice = {"type": "any"} - elif isinstance(value, dict): - _tool_choice = {"type": "tool", "name": value["function"]["name"]} - - if _tool_choice is not None: - optional_params["tool_choice"] = _tool_choice - if param == "stream": - optional_params["stream"] = value - if param == "stop": - optional_params["stop_sequences"] = value - if param == "temperature": - optional_params["temperature"] = value - if param == "top_p": - optional_params["top_p"] = value - if param == "response_format" and isinstance(value, dict): - json_schema: Optional[dict] = None - if "response_schema" in value: - json_schema = value["response_schema"] - elif "json_schema" in value: - json_schema = value["json_schema"]["schema"] - """ - When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode - - You usually want to provide a single tool - - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool - - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective. 
- """ - _tool_choice = None - _tool_choice = {"name": "json_tool_call", "type": "tool"} - - _tool = AnthropicMessagesTool( - name="json_tool_call", - input_schema={ - "type": "object", - "properties": {"values": json_schema}, # type: ignore - }, - ) - - optional_params["tools"] = [_tool] - optional_params["tool_choice"] = _tool_choice - optional_params["json_mode"] = True - - return optional_params - @classmethod def is_supported_model( cls, model: str, custom_llm_provider: Optional[str] = None diff --git a/litellm/main.py b/litellm/main.py index 6829de677..30ff47e88 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3377,6 +3377,9 @@ def embedding( # noqa: PLR0915 "azure_ad_token", None ) or get_secret_str("AZURE_AD_TOKEN") + if extra_headers is not None: + optional_params["extra_headers"] = extra_headers + api_key = ( api_key or litellm.api_key @@ -4458,7 +4461,10 @@ def image_generation( # noqa: PLR0915 metadata = kwargs.get("metadata", {}) litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore client = kwargs.get("client", None) - + extra_headers = kwargs.get("extra_headers", None) + headers: dict = kwargs.get("headers", None) or {} + if extra_headers is not None: + headers.update(extra_headers) model_response: ImageResponse = litellm.utils.ImageResponse() if model is not None or custom_llm_provider is not None: model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore @@ -4589,6 +4595,14 @@ def image_generation( # noqa: PLR0915 "azure_ad_token", None ) or get_secret_str("AZURE_AD_TOKEN") + default_headers = { + "Content-Type": "application/json;", + "api-key": api_key, + } + for k, v in default_headers.items(): + if k not in headers: + headers[k] = v + model_response = azure_chat_completions.image_generation( model=model, prompt=prompt, @@ -4601,6 +4615,7 @@ def image_generation( # noqa: PLR0915 api_version=api_version, aimg_generation=aimg_generation, client=client, + headers=headers, ) elif custom_llm_provider == "openai": model_response = openai_chat_completions.image_generation( @@ -4797,11 +4812,7 @@ def transcription( """ atranscription = kwargs.get("atranscription", False) litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore - kwargs.get("litellm_call_id", None) - kwargs.get("logger_fn", None) - kwargs.get("proxy_server_request", None) - kwargs.get("model_info", None) - kwargs.get("metadata", {}) + extra_headers = kwargs.get("extra_headers", None) kwargs.pop("tags", []) drop_params = kwargs.get("drop_params", None) @@ -4857,6 +4868,8 @@ def transcription( or get_secret_str("AZURE_API_KEY") ) + optional_params["extra_headers"] = extra_headers + response = azure_audio_transcriptions.audio_transcriptions( model=model, audio_file=file, @@ -4975,6 +4988,7 @@ def speech( user = kwargs.get("user", None) litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None) proxy_server_request = kwargs.get("proxy_server_request", None) + extra_headers = kwargs.get("extra_headers", None) model_info = kwargs.get("model_info", None) model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore kwargs.pop("tags", []) @@ -5087,7 +5101,8 @@ def speech( "AZURE_AD_TOKEN" ) - headers = headers or litellm.headers + if extra_headers: + optional_params["extra_headers"] = extra_headers response = 
azure_chat_completions.audio_speech( model=model, diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 8f833c129..9578ed9ea 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -774,6 +774,20 @@ "supports_vision": true, "supports_prompt_caching": true }, + "azure/gpt-4o-mini-2024-07-18": { + "max_tokens": 16384, + "max_input_tokens": 128000, + "max_output_tokens": 16384, + "input_cost_per_token": 0.000000165, + "output_cost_per_token": 0.00000066, + "cache_read_input_token_cost": 0.000000075, + "litellm_provider": "azure", + "mode": "chat", + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_vision": true, + "supports_prompt_caching": true + }, "azure/gpt-4-turbo-2024-04-09": { "max_tokens": 4096, "max_input_tokens": 128000, diff --git a/litellm/router.py b/litellm/router.py index ac26aa61e..cc8ad7434 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -81,7 +81,10 @@ from litellm.router_utils.fallback_event_handlers import ( run_async_fallback, run_sync_fallback, ) -from litellm.router_utils.handle_error import send_llm_exception_alert +from litellm.router_utils.handle_error import ( + async_raise_no_deployment_exception, + send_llm_exception_alert, +) from litellm.router_utils.router_callbacks.track_deployment_metrics import ( increment_deployment_failures_for_current_minute, increment_deployment_successes_for_current_minute, @@ -5183,21 +5186,12 @@ class Router: ) if len(healthy_deployments) == 0: - if _allowed_model_region is None: - _allowed_model_region = "n/a" - model_ids = self.get_model_ids(model_name=model) - _cooldown_time = self.cooldown_cache.get_min_cooldown( - model_ids=model_ids, parent_otel_span=parent_otel_span - ) - _cooldown_list = _get_cooldown_deployments( - litellm_router_instance=self, parent_otel_span=parent_otel_span - ) - raise RouterRateLimitError( + exception = await async_raise_no_deployment_exception( + litellm_router_instance=self, model=model, - cooldown_time=_cooldown_time, - enable_pre_call_checks=self.enable_pre_call_checks, - cooldown_list=_cooldown_list, + parent_otel_span=parent_otel_span, ) + raise exception start_time = time.time() if ( self.routing_strategy == "usage-based-routing-v2" @@ -5255,22 +5249,12 @@ class Router: else: deployment = None if deployment is None: - verbose_router_logger.info( - f"get_available_deployment for model: {model}, No deployment available" - ) - model_ids = self.get_model_ids(model_name=model) - _cooldown_time = self.cooldown_cache.get_min_cooldown( - model_ids=model_ids, parent_otel_span=parent_otel_span - ) - _cooldown_list = await _async_get_cooldown_deployments( - litellm_router_instance=self, parent_otel_span=parent_otel_span - ) - raise RouterRateLimitError( + exception = await async_raise_no_deployment_exception( + litellm_router_instance=self, model=model, - cooldown_time=_cooldown_time, - enable_pre_call_checks=self.enable_pre_call_checks, - cooldown_list=_cooldown_list, + parent_otel_span=parent_otel_span, ) + raise exception verbose_router_logger.info( f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}" ) diff --git a/litellm/router_utils/cooldown_cache.py b/litellm/router_utils/cooldown_cache.py index 44174f3b1..dbe767214 100644 --- a/litellm/router_utils/cooldown_cache.py +++ b/litellm/router_utils/cooldown_cache.py @@ -17,13 +17,6 @@ if 
TYPE_CHECKING: else: Span = Any -if TYPE_CHECKING: - from opentelemetry.trace import Span as _Span - - Span = _Span -else: - Span = Any - class CooldownCacheValue(TypedDict): exception_received: str @@ -117,7 +110,6 @@ class CooldownCache: if results is None: return active_cooldowns - # Process the results for model_id, result in zip(model_ids, results): if result and isinstance(result, dict): diff --git a/litellm/router_utils/handle_error.py b/litellm/router_utils/handle_error.py index 25b511027..321ba5dc5 100644 --- a/litellm/router_utils/handle_error.py +++ b/litellm/router_utils/handle_error.py @@ -1,15 +1,22 @@ import asyncio import traceback -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional +from litellm._logging import verbose_router_logger +from litellm.router_utils.cooldown_handlers import _async_get_cooldown_deployments from litellm.types.integrations.slack_alerting import AlertType +from litellm.types.router import RouterRateLimitError if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + from litellm.router import Router as _Router LitellmRouter = _Router + Span = _Span else: LitellmRouter = Any + Span = Any async def send_llm_exception_alert( @@ -55,3 +62,28 @@ async def send_llm_exception_alert( alert_type=AlertType.llm_exceptions, alerting_metadata={}, ) + + +async def async_raise_no_deployment_exception( + litellm_router_instance: LitellmRouter, model: str, parent_otel_span: Optional[Span] +): + """ + Raises a RouterRateLimitError if no deployment is found for the given model. + """ + verbose_router_logger.info( + f"get_available_deployment for model: {model}, No deployment available" + ) + model_ids = litellm_router_instance.get_model_ids(model_name=model) + _cooldown_time = litellm_router_instance.cooldown_cache.get_min_cooldown( + model_ids=model_ids, parent_otel_span=parent_otel_span + ) + _cooldown_list = await _async_get_cooldown_deployments( + litellm_router_instance=litellm_router_instance, + parent_otel_span=parent_otel_span, + ) + return RouterRateLimitError( + model=model, + cooldown_time=_cooldown_time, + enable_pre_call_checks=litellm_router_instance.enable_pre_call_checks, + cooldown_list=_cooldown_list, + ) diff --git a/litellm/types/llms/anthropic.py b/litellm/types/llms/anthropic.py index 4a1e63f7d..bb65a372d 100644 --- a/litellm/types/llms/anthropic.py +++ b/litellm/types/llms/anthropic.py @@ -9,12 +9,35 @@ from .openai import ChatCompletionCachedContent class AnthropicMessagesToolChoice(TypedDict, total=False): type: Required[Literal["auto", "any", "tool"]] name: str + disable_parallel_tool_use: bool # default is false class AnthropicMessagesTool(TypedDict, total=False): name: Required[str] description: str input_schema: Required[dict] + type: Literal["custom"] + cache_control: Optional[Union[dict, ChatCompletionCachedContent]] + + +class AnthropicComputerTool(TypedDict, total=False): + display_width_px: Required[int] + display_height_px: Required[int] + display_number: int + cache_control: Optional[Union[dict, ChatCompletionCachedContent]] + type: Required[str] + name: Required[str] + + +class AnthropicHostedTools(TypedDict, total=False): # for bash_tool and text_editor + type: Required[str] + name: Required[str] + cache_control: Optional[Union[dict, ChatCompletionCachedContent]] + + +AllAnthropicToolsValues = Union[ + AnthropicComputerTool, AnthropicHostedTools, AnthropicMessagesTool +] class AnthropicMessagesTextParam(TypedDict, total=False): @@ -117,7 +140,7 @@ class 
AnthropicMessageRequestBase(TypedDict, total=False): system: Union[str, List] temperature: float tool_choice: AnthropicMessagesToolChoice - tools: List[AnthropicMessagesTool] + tools: List[AllAnthropicToolsValues] top_k: int top_p: float diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 3897aa2ee..3b95a3282 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -440,11 +440,15 @@ class ChatCompletionToolParamFunctionChunk(TypedDict, total=False): parameters: dict -class ChatCompletionToolParam(TypedDict): - type: Literal["function"] +class OpenAIChatCompletionToolParam(TypedDict): + type: Union[Literal["function"], str] function: ChatCompletionToolParamFunctionChunk +class ChatCompletionToolParam(OpenAIChatCompletionToolParam, total=False): + cache_control: ChatCompletionCachedContent + + class Function(TypedDict, total=False): name: Required[str] """The name of the function to call.""" diff --git a/tests/llm_translation/test_anthropic_completion.py b/tests/llm_translation/test_anthropic_completion.py index 9ce951b6c..6be41f90d 100644 --- a/tests/llm_translation/test_anthropic_completion.py +++ b/tests/llm_translation/test_anthropic_completion.py @@ -527,3 +527,98 @@ def test_process_anthropic_headers_with_no_matching_headers(): result = process_anthropic_headers(input_headers) assert result == expected_output, "Unexpected output for non-matching headers" + + +def test_anthropic_computer_tool_use(): + from litellm import completion + + tools = [ + { + "type": "computer_20241022", + "function": { + "name": "computer", + "parameters": { + "display_height_px": 100, + "display_width_px": 100, + "display_number": 1, + }, + }, + } + ] + model = "claude-3-5-sonnet-20241022" + messages = [{"role": "user", "content": "Save a picture of a cat to my desktop."}] + + resp = completion( + model=model, + messages=messages, + tools=tools, + # headers={"anthropic-beta": "computer-use-2024-10-22"}, + ) + + print(resp) + + +@pytest.mark.parametrize( + "computer_tool_used, prompt_caching_set, expected_beta_header", + [ + (True, False, True), + (False, True, True), + (True, True, True), + (False, False, False), + ], +) +def test_anthropic_beta_header( + computer_tool_used, prompt_caching_set, expected_beta_header +): + headers = litellm.AnthropicConfig().get_anthropic_headers( + api_key="fake-api-key", + computer_tool_used=computer_tool_used, + prompt_caching_set=prompt_caching_set, + ) + + if expected_beta_header: + assert "anthropic-beta" in headers + else: + assert "anthropic-beta" not in headers + + +@pytest.mark.parametrize( + "cache_control_location", + [ + "inside_function", + "outside_function", + ], +) +def test_anthropic_tool_helper(cache_control_location): + from litellm.llms.anthropic.chat.transformation import AnthropicConfig + + tool = { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + }, + } + + if cache_control_location == "inside_function": + tool["function"]["cache_control"] = {"type": "ephemeral"} + else: + tool["cache_control"] = {"type": "ephemeral"} + + tool = AnthropicConfig()._map_tool_helper(tool=tool) + + assert tool["cache_control"] == {"type": "ephemeral"} diff --git a/tests/llm_translation/test_azure_openai.py b/tests/llm_translation/test_azure_openai.py index fdb730dce..837714770 100644 --- a/tests/llm_translation/test_azure_openai.py +++ b/tests/llm_translation/test_azure_openai.py @@ -96,6 +96,66 @@ def test_process_azure_headers_with_dict_input(): assert result == expected_output, "Unexpected output for dict input" +from httpx import Client +from unittest.mock import MagicMock, patch +from openai import AzureOpenAI +import litellm +from litellm import completion +import os + + +@pytest.mark.parametrize( + "input, call_type", + [ + ({"messages": [{"role": "user", "content": "Hello world"}]}, "completion"), + ({"input": "Hello world"}, "embedding"), + ({"prompt": "Hello world"}, "image_generation"), + ], +) +def test_azure_extra_headers(input, call_type): + from litellm import embedding, image_generation + + http_client = Client() + + messages = [{"role": "user", "content": "Hello world"}] + with patch.object(http_client, "send", new=MagicMock()) as mock_client: + litellm.client_session = http_client + try: + if call_type == "completion": + func = completion + elif call_type == "embedding": + func = embedding + elif call_type == "image_generation": + func = image_generation + response = func( + model="azure/chatgpt-v-2", + api_base="https://openai-gpt-4-test-v-1.openai.azure.com", + api_version="2023-07-01-preview", + api_key="my-azure-api-key", + extra_headers={ + "Authorization": "my-bad-key", + "Ocp-Apim-Subscription-Key": "hello-world-testing", + }, + **input, + ) + print(response) + except Exception as e: + print(e) + + mock_client.assert_called() + + print(f"mock_client.call_args: {mock_client.call_args}") + request = mock_client.call_args[0][0] + print(request.method) # This will print 'POST' + print(request.url) # This will print the full URL + print(request.headers) # This will print the full URL + auth_header = request.headers.get("Authorization") + apim_key = request.headers.get("Ocp-Apim-Subscription-Key") + print(auth_header) + assert auth_header == "my-bad-key" + assert apim_key == "hello-world-testing" + + @pytest.mark.parametrize( "api_base, model, expected_endpoint", [ diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py index f3cf8cb58..fdda7b171 100644 --- a/tests/llm_translation/test_optional_params.py +++ b/tests/llm_translation/test_optional_params.py @@ -786,19 +786,122 @@ def test_unmapped_vertex_anthropic_model(): assert "max_retries" not in optional_params -@pytest.mark.parametrize( - "tools, key", - [ - ([{"googleSearchRetrieval": {}}], "googleSearchRetrieval"), - ([{"code_execution": {}}], "code_execution"), - ], -) -def test_vertex_tool_params(tools, key): +@pytest.mark.parametrize("provider", ["anthropic", "vertex_ai"]) +def test_anthropic_parallel_tool_calls(provider): + optional_params = get_optional_params( + model="claude-3-5-sonnet-v250@20241022", + custom_llm_provider=provider, + parallel_tool_calls=True, + ) + print(f"optional_params: {optional_params}") + assert optional_params["tool_choice"]["disable_parallel_tool_use"] is 
True + + +def test_anthropic_computer_tool_use(): + tools = [ + { + "type": "computer_20241022", + "function": { + "name": "computer", + "parameters": { + "display_height_px": 100, + "display_width_px": 100, + "display_number": 1, + }, + }, + } + ] optional_params = get_optional_params( - model="gemini-1.5-pro", + model="claude-3-5-sonnet-v250@20241022", + custom_llm_provider="anthropic", + tools=tools, + ) + assert optional_params["tools"][0]["type"] == "computer_20241022" + assert optional_params["tools"][0]["display_height_px"] == 100 + assert optional_params["tools"][0]["display_width_px"] == 100 + assert optional_params["tools"][0]["display_number"] == 1 + + +def test_vertex_schema_field(): + tools = [ + { + "type": "function", + "function": { + "name": "json", + "description": "Respond with a JSON object.", + "parameters": { + "type": "object", + "properties": { + "thinking": { + "type": "string", + "description": "Your internal thoughts on different problem details given the guidance.", + }, + "problems": { + "type": "array", + "items": { + "type": "object", + "properties": { + "icon": { + "type": "string", + "enum": [ + "BarChart2", + "Bell", + ], + "description": "The name of a Lucide icon to display", + }, + "color": { + "type": "string", + "description": "A Tailwind color class for the icon, e.g., 'text-red-500'", + }, + "problem": { + "type": "string", + "description": "The title of the problem being addressed, approximately 3-5 words.", + }, + "description": { + "type": "string", + "description": "A brief explanation of the problem, approximately 20 words.", + }, + "impacts": { + "type": "array", + "items": {"type": "string"}, + "description": "A list of potential impacts or consequences of the problem, approximately 3 words each.", + }, + "automations": { + "type": "array", + "items": {"type": "string"}, + "description": "A list of potential automations to address the problem, approximately 3-5 words each.", + }, + }, + "required": [ + "icon", + "color", + "problem", + "description", + "impacts", + "automations", + ], + "additionalProperties": False, + }, + "description": "Please generate problem cards that match this guidance.", + }, + }, + "required": ["thinking", "problems"], + "additionalProperties": False, + "$schema": "http://json-schema.org/draft-07/schema#", + }, + }, + } + ] + + optional_params = get_optional_params( + model="gemini-1.5-flash", custom_llm_provider="vertex_ai", tools=tools, ) print(optional_params) - assert optional_params["tools"][0][key] == {} + print(optional_params["tools"][0]["function_declarations"][0]) + assert ( + "$schema" + not in optional_params["tools"][0]["function_declarations"][0]["parameters"] + )
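The `parallel_tool_calls` support added in `AnthropicConfig.map_openai_params` forwards the OpenAI-style flag as-is into Anthropic's `tool_choice.disable_parallel_tool_use` field (as asserted by `test_anthropic_parallel_tool_calls`). A minimal sketch of exercising it through `litellm.completion`; the model name, prompt, and tool definition are illustrative and `ANTHROPIC_API_KEY` is assumed to be set:

```python
from litellm import completion

# Illustrative OpenAI-style function tool - any valid function tool works here.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

resp = completion(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "What is the weather in Boston and in Paris?"}],
    tools=tools,
    # Forwarded as-is into Anthropic's tool_choice.disable_parallel_tool_use
    # by the map_openai_params changes in this patch.
    parallel_tool_calls=True,
)
print(resp)
```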
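The new `response_format` handling for Anthropic wraps a caller-supplied JSON schema into a forced `json_tool_call` tool, per Anthropic's tool-use JSON mode. A hedged sketch of the caller side; the schema, model name, and prompt are illustrative:

```python
from litellm import completion

# Illustrative schema - per the transformation above, the "schema" value is
# wrapped into a "json_tool_call" tool and tool_choice is forced to that tool.
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "capital",
        "schema": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "country": {"type": "string"},
            },
            "required": ["city", "country"],
        },
    },
}

resp = completion(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    response_format=response_format,
)
print(resp)
```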
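The `extra_headers` plumbing added in `main.py` means caller-supplied headers now reach the Azure image-generation (and other translation) HTTP calls, merged with the default `api-key` / `Content-Type` headers. A sketch mirroring `test_azure_extra_headers`; the endpoint, deployment name, keys, and gateway header are placeholders:

```python
import litellm

resp = litellm.image_generation(
    model="azure/dall-e-2",  # placeholder deployment name
    prompt="A cute baby sea otter",
    api_base="https://my-endpoint.openai.azure.com",
    api_version="2023-09-01-preview",
    api_key="my-azure-api-key",  # placeholder
    # Forwarded onto the outgoing Azure request alongside the default headers.
    extra_headers={"Ocp-Apim-Subscription-Key": "my-gateway-key"},
)
print(resp)
```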