mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
* refactor: move gemini translation logic inside the transformation.py file easier to isolate the gemini translation logic * fix(gemini-transformation): support multiple tool calls in message body Merges https://github.com/BerriAI/litellm/pull/6487/files * test(test_vertex.py): add remaining tests from https://github.com/BerriAI/litellm/pull/6487 * fix(gemini-transformation): return tool calls for multiple tool calls * fix: support passing logprobs param for vertex + gemini * feat(vertex_ai): add logprobs support for gemini calls * fix(anthropic/chat/transformation.py): fix disable parallel tool use flag * fix: fix linting error * fix(_logging.py): log stacktrace information in json logs Closes https://github.com/BerriAI/litellm/issues/6497 * fix(utils.py): fix mem leak for async stream + completion Uses a global executor pool instead of creating a new thread on each request Fixes https://github.com/BerriAI/litellm/issues/6404 * fix(factory.py): handle tool call + content in assistant message for bedrock * fix: fix import * fix(factory.py): maintain support for content as a str in assistant response * fix: fix import * test: cleanup test * fix(vertex_and_google_ai_studio/): return none for content if no str value * test: retry flaky tests * (UI) Fix viewing members, keys in a team + added testing (#6514) * fix listing teams on ui * LiteLLM Minor Fixes & Improvements (10/28/2024) (#6475) * fix(anthropic/chat/transformation.py): support anthropic disable_parallel_tool_use param Fixes https://github.com/BerriAI/litellm/issues/6456 * feat(anthropic/chat/transformation.py): support anthropic computer tool use Closes https://github.com/BerriAI/litellm/issues/6427 * fix(vertex_ai/common_utils.py): parse out '$schema' when calling vertex ai Fixes issue when trying to call vertex from vercel sdk * fix(main.py): add 'extra_headers' support for azure on all translation endpoints Fixes https://github.com/BerriAI/litellm/issues/6465 * fix: fix linting errors * 
fix(transformation.py): handle no beta headers for anthropic * test: cleanup test * fix: fix linting error * fix: fix linting errors * fix: fix linting errors * fix(transformation.py): handle dummy tool call * fix(main.py): fix linting error * fix(azure.py): pass required param * LiteLLM Minor Fixes & Improvements (10/24/2024) (#6441) * fix(azure.py): handle /openai/deployment in azure api base * fix(factory.py): fix faulty anthropic tool result translation check Fixes https://github.com/BerriAI/litellm/issues/6422 * fix(gpt_transformation.py): add support for parallel_tool_calls to azure Fixes https://github.com/BerriAI/litellm/issues/6440 * fix(factory.py): support anthropic prompt caching for tool results * fix(vertex_ai/common_utils): don't pop non-null required field Fixes https://github.com/BerriAI/litellm/issues/6426 * feat(vertex_ai.py): support code_execution tool call for vertex ai + gemini Closes https://github.com/BerriAI/litellm/issues/6434 * build(model_prices_and_context_window.json): Add 'supports_assistant_prefill' for bedrock claude-3-5-sonnet v2 models Closes https://github.com/BerriAI/litellm/issues/6437 * fix(types/utils.py): fix linting * test: update test to include required fields * test: fix test * test: handle flaky test * test: remove e2e test - hitting gemini rate limits * Litellm dev 10 26 2024 (#6472) * docs(exception_mapping.md): add missing exception types Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183 * fix(main.py): register custom model pricing with specific key Ensure custom model pricing is registered to the specific model+provider key combination * test: make testing more robust for custom pricing * fix(redis_cache.py): instrument otel logging for sync redis calls ensures complete coverage for all redis cache calls * (Testing) Add unit testing for DualCache - ensure in memory cache is used when expected (#6471) * test test_dual_cache_get_set * unit testing for dual cache * fix async_set_cache_sadd * 
test_dual_cache_local_only * redis otel tracing + async support for latency routing (#6452) * docs(exception_mapping.md): add missing exception types Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183 * fix(main.py): register custom model pricing with specific key Ensure custom model pricing is registered to the specific model+provider key combination * test: make testing more robust for custom pricing * fix(redis_cache.py): instrument otel logging for sync redis calls ensures complete coverage for all redis cache calls * refactor: pass parent_otel_span for redis caching calls in router allows for more observability into what calls are causing latency issues * test: update tests with new params * refactor: ensure e2e otel tracing for router * refactor(router.py): add more otel tracing acrosss router catch all latency issues for router requests * fix: fix linting error * fix(router.py): fix linting error * fix: fix test * test: fix tests * fix(dual_cache.py): pass ttl to redis cache * fix: fix param * fix(dual_cache.py): set default value for parent_otel_span * fix(transformation.py): support 'response_format' for anthropic calls * fix(transformation.py): check for cache_control inside 'function' block * fix: fix linting error * fix: fix linting errors --------- Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> --------- Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com> * ui new build * Add retry strat (#6520) Signed-off-by: dbczumar <corey.zumar@databricks.com> * (fix) slack alerting - don't spam the failed cost tracking alert for the same model (#6543) * fix use failing_model as cache key for failed_tracking_alert * fix use standard logging payload for getting response cost * fix kwargs.get("response_cost") * fix getting response cost * (feat) add XAI ChatCompletion Support (#6373) * init commit for XAI * add full logic for xai chat completion * test_completion_xai * docs xAI * add xai/grok-beta * 
test_xai_chat_config_get_openai_compatible_provider_info * test_xai_chat_config_map_openai_params * add xai streaming test --------- Signed-off-by: dbczumar <corey.zumar@databricks.com> Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
469 lines
18 KiB
Python
469 lines
18 KiB
Python
import types
|
||
from typing import List, Literal, Optional, Tuple, Union
|
||
|
||
import litellm
|
||
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
||
from litellm.types.llms.anthropic import (
|
||
AllAnthropicToolsValues,
|
||
AnthropicComputerTool,
|
||
AnthropicHostedTools,
|
||
AnthropicMessageRequestBase,
|
||
AnthropicMessagesRequest,
|
||
AnthropicMessagesTool,
|
||
AnthropicMessagesToolChoice,
|
||
AnthropicSystemMessageContent,
|
||
)
|
||
from litellm.types.llms.openai import (
|
||
AllMessageValues,
|
||
ChatCompletionCachedContent,
|
||
ChatCompletionSystemMessage,
|
||
ChatCompletionToolParam,
|
||
ChatCompletionToolParamFunctionChunk,
|
||
)
|
||
from litellm.utils import add_dummy_tool, has_tool_call_blocks
|
||
|
||
from ..common_utils import AnthropicError
|
||
|
||
|
||
class AnthropicConfig:
    """
    Reference: https://docs.anthropic.com/claude/reference/messages_post

    Class-level attributes act as the default request parameters; only the
    ones that are not None are reported by `get_config`.

    to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
    """

    # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
    max_tokens: Optional[int] = 4096
    stop_sequences: Optional[list] = None
    # Sampling controls are fractional values, hence float rather than int.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    metadata: Optional[dict] = None
    system: Optional[str] = None
|
||
|
||
def __init__(
|
||
self,
|
||
max_tokens: Optional[
|
||
int
|
||
] = 4096, # You can pass in a value yourself or use the default value 4096
|
||
stop_sequences: Optional[list] = None,
|
||
temperature: Optional[int] = None,
|
||
top_p: Optional[int] = None,
|
||
top_k: Optional[int] = None,
|
||
metadata: Optional[dict] = None,
|
||
system: Optional[str] = None,
|
||
) -> None:
|
||
locals_ = locals()
|
||
for key, value in locals_.items():
|
||
if key != "self" and value is not None:
|
||
setattr(self.__class__, key, value)
|
||
|
||
@classmethod
def get_config(cls):
    """
    Collect the class's configured values into a dict.

    Skips dunder names, plain/builtin functions, classmethods,
    staticmethods, and anything whose value is None.
    """
    skipped_kinds = (
        types.FunctionType,
        types.BuiltinFunctionType,
        classmethod,
        staticmethod,
    )
    config = {}
    for attr_name, attr_value in cls.__dict__.items():
        if attr_name.startswith("__"):
            continue
        if isinstance(attr_value, skipped_kinds) or attr_value is None:
            continue
        config[attr_name] = attr_value
    return config
|
||
|
||
def get_supported_openai_params(self):
    """
    Return the names of the OpenAI-style parameters this config accepts.

    A fresh list is built on every call.
    """
    generation_params = ["stream", "stop", "temperature", "top_p"]
    token_limit_params = ["max_tokens", "max_completion_tokens"]
    tool_params = ["tools", "tool_choice"]
    other_params = ["extra_headers", "parallel_tool_calls", "response_format"]
    return generation_params + token_limit_params + tool_params + other_params
|
||
|
||
def get_cache_control_headers(self) -> dict:
    """
    Return the fixed header pair used for prompt caching requests:
    the API version and the prompt-caching beta flag.
    """
    cache_headers: dict = {}
    cache_headers["anthropic-version"] = "2023-06-01"
    cache_headers["anthropic-beta"] = "prompt-caching-2024-07-31"
    return cache_headers
|
||
|
||
def get_anthropic_headers(
    self,
    api_key: str,
    anthropic_version: Optional[str] = None,
    computer_tool_used: bool = False,
    prompt_caching_set: bool = False,
) -> dict:
    """
    Build the HTTP headers for an Anthropic request.

    Args:
        api_key: value sent as the `x-api-key` header.
        anthropic_version: value for `anthropic-version`; falls back to
            "2023-06-01" when None or empty.
        computer_tool_used: when True, the computer-use beta flag is added.
        prompt_caching_set: when True, the prompt-caching beta flag is added.

    Returns:
        A dict of headers. `anthropic-beta` is only present when at least
        one beta flag applies; multiple flags are comma-joined.
    """
    betas = []
    if prompt_caching_set:
        betas.append("prompt-caching-2024-07-31")
    if computer_tool_used:
        betas.append("computer-use-2024-10-22")

    headers = {
        "anthropic-version": anthropic_version or "2023-06-01",
        "x-api-key": api_key,
        "accept": "application/json",
        "content-type": "application/json",
    }
    if betas:
        headers["anthropic-beta"] = ",".join(betas)
    return headers
|
||
|
||
def _map_tool_choice(
|
||
self, tool_choice: Optional[str], parallel_tool_use: Optional[bool]
|
||
) -> Optional[AnthropicMessagesToolChoice]:
|
||
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
|
||
if tool_choice == "auto":
|
||
_tool_choice = AnthropicMessagesToolChoice(
|
||
type="auto",
|
||
)
|
||
elif tool_choice == "required":
|
||
_tool_choice = AnthropicMessagesToolChoice(type="any")
|
||
elif isinstance(tool_choice, dict):
|
||
_tool_name = tool_choice.get("function", {}).get("name")
|
||
_tool_choice = AnthropicMessagesToolChoice(type="tool")
|
||
if _tool_name is not None:
|
||
_tool_choice["name"] = _tool_name
|
||
|
||
if parallel_tool_use is not None:
|
||
# Anthropic uses 'disable_parallel_tool_use' flag to determine if parallel tool use is allowed
|
||
# this is the inverse of the openai flag.
|
||
if _tool_choice is not None:
|
||
_tool_choice["disable_parallel_tool_use"] = not parallel_tool_use
|
||
else: # use anthropic defaults and make sure to send the disable_parallel_tool_use flag
|
||
_tool_choice = AnthropicMessagesToolChoice(
|
||
type="auto",
|
||
disable_parallel_tool_use=not parallel_tool_use,
|
||
)
|
||
return _tool_choice
|
||
|
||
def _map_tool_helper(
|
||
self, tool: ChatCompletionToolParam
|
||
) -> AllAnthropicToolsValues:
|
||
returned_tool: Optional[AllAnthropicToolsValues] = None
|
||
|
||
if tool["type"] == "function" or tool["type"] == "custom":
|
||
_tool = AnthropicMessagesTool(
|
||
name=tool["function"]["name"],
|
||
input_schema=tool["function"].get(
|
||
"parameters",
|
||
{
|
||
"type": "object",
|
||
"properties": {},
|
||
},
|
||
),
|
||
)
|
||
|
||
_description = tool["function"].get("description")
|
||
if _description is not None:
|
||
_tool["description"] = _description
|
||
|
||
returned_tool = _tool
|
||
|
||
elif tool["type"].startswith("computer_"):
|
||
## check if all required 'display_' params are given
|
||
if "parameters" not in tool["function"]:
|
||
raise ValueError("Missing required parameter: parameters")
|
||
|
||
_display_width_px: Optional[int] = tool["function"]["parameters"].get(
|
||
"display_width_px"
|
||
)
|
||
_display_height_px: Optional[int] = tool["function"]["parameters"].get(
|
||
"display_height_px"
|
||
)
|
||
if _display_width_px is None or _display_height_px is None:
|
||
raise ValueError(
|
||
"Missing required parameter: display_width_px or display_height_px"
|
||
)
|
||
|
||
_computer_tool = AnthropicComputerTool(
|
||
type=tool["type"],
|
||
name=tool["function"].get("name", "computer"),
|
||
display_width_px=_display_width_px,
|
||
display_height_px=_display_height_px,
|
||
)
|
||
|
||
_display_number = tool["function"]["parameters"].get("display_number")
|
||
if _display_number is not None:
|
||
_computer_tool["display_number"] = _display_number
|
||
|
||
returned_tool = _computer_tool
|
||
elif tool["type"].startswith("bash_") or tool["type"].startswith(
|
||
"text_editor_"
|
||
):
|
||
function_name = tool["function"].get("name")
|
||
if function_name is None:
|
||
raise ValueError("Missing required parameter: name")
|
||
|
||
returned_tool = AnthropicHostedTools(
|
||
type=tool["type"],
|
||
name=function_name,
|
||
)
|
||
if returned_tool is None:
|
||
raise ValueError(f"Unsupported tool type: {tool['type']}")
|
||
|
||
## check if cache_control is set in the tool
|
||
_cache_control = tool.get("cache_control", None)
|
||
_cache_control_function = tool.get("function", {}).get("cache_control", None)
|
||
if _cache_control is not None:
|
||
returned_tool["cache_control"] = _cache_control
|
||
elif _cache_control_function is not None and isinstance(
|
||
_cache_control_function, dict
|
||
):
|
||
returned_tool["cache_control"] = ChatCompletionCachedContent(
|
||
**_cache_control_function # type: ignore
|
||
)
|
||
|
||
return returned_tool
|
||
|
||
def _map_tools(self, tools: List) -> List[AllAnthropicToolsValues]:
|
||
anthropic_tools = []
|
||
for tool in tools:
|
||
if "input_schema" in tool: # assume in anthropic format
|
||
anthropic_tools.append(tool)
|
||
else: # assume openai tool call
|
||
new_tool = self._map_tool_helper(tool)
|
||
|
||
anthropic_tools.append(new_tool)
|
||
return anthropic_tools
|
||
|
||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    messages: Optional[List[AllMessageValues]] = None,
):
    """
    Translate OpenAI-style request parameters into Anthropic parameters.

    Args:
        non_default_params: the OpenAI-style parameters supplied by the caller.
        optional_params: dict that receives the translated parameters; it is
            mutated in place and also returned.
        messages: the conversation, used only to check whether tool call
            blocks are present without any tools being sent.

    Returns:
        `optional_params`, updated.

    Raises:
        litellm.UnsupportedParamsError: when the messages contain tool call
            blocks, no tools end up in the request, and
            `litellm.modify_params` is not enabled.
    """
    for param, value in non_default_params.items():
        if param == "max_tokens":
            optional_params["max_tokens"] = value
        if param == "max_completion_tokens":
            optional_params["max_tokens"] = value
        if param == "tools":
            optional_params["tools"] = self._map_tools(value)
        if param == "tool_choice" or param == "parallel_tool_calls":
            # Both params feed one Anthropic field, so always map them together.
            _tool_choice: Optional[AnthropicMessagesToolChoice] = (
                self._map_tool_choice(
                    tool_choice=non_default_params.get("tool_choice"),
                    parallel_tool_use=non_default_params.get("parallel_tool_calls"),
                )
            )

            if _tool_choice is not None:
                optional_params["tool_choice"] = _tool_choice
        if param == "stream" and value is True:
            optional_params["stream"] = value
        if param == "stop":
            if isinstance(value, str):
                if (
                    value == "\n"
                ) and litellm.drop_params is True:  # anthropic doesn't allow whitespace characters as stop-sequences
                    continue
                value = [value]
            elif isinstance(value, list):
                # anthropic doesn't allow whitespace characters as stop-sequences
                kept_sequences = [
                    v
                    for v in value
                    if not (v == "\n" and litellm.drop_params is True)
                ]
                if not kept_sequences:
                    continue
                value = kept_sequences
            optional_params["stop_sequences"] = value
        if param == "temperature":
            optional_params["temperature"] = value
        if param == "top_p":
            optional_params["top_p"] = value
        if param == "response_format" and isinstance(value, dict):
            json_schema: Optional[dict] = None
            if "response_schema" in value:
                json_schema = value["response_schema"]
            elif "json_schema" in value:
                json_schema = value["json_schema"]["schema"]
            # When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
            # - You usually want to provide a single tool
            # - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
            # - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
            json_tool_choice = {"name": "json_tool_call", "type": "tool"}

            _tool = AnthropicMessagesTool(
                name="json_tool_call",
                input_schema={
                    "type": "object",
                    "properties": {"values": json_schema},  # type: ignore
                },
            )

            optional_params["tools"] = [_tool]
            optional_params["tool_choice"] = json_tool_choice
            optional_params["json_mode"] = True

    ## VALIDATE REQUEST
    # Anthropic doesn't support tool calling without `tools=` param specified.
    # Check the translated params (not the raw input): the JSON-mode branch
    # above injects a tool of its own, and replacing it with the dummy tool
    # would leave `tool_choice` pointing at a tool that is no longer sent.
    if (
        "tools" not in optional_params
        and messages is not None
        and has_tool_call_blocks(messages)
    ):
        if litellm.modify_params:
            optional_params["tools"] = self._map_tools(
                add_dummy_tool(custom_llm_provider="anthropic")
            )
        else:
            raise litellm.UnsupportedParamsError(
                message="Anthropic doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request.",
                model="",
                llm_provider="anthropic",
            )

    return optional_params
|
||
|
||
def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
    """
    Report whether any message asks for prompt caching.

    A message counts when it has a non-None "cache_control" value itself,
    or when its content is a list and one of the entries contains
    "cache_control".

    Used to check if anthropic prompt caching headers need to be set.
    """
    for message in messages:
        if message.get("cache_control", None) is not None:
            return True

        content_blocks = message.get("content")
        if not isinstance(content_blocks, list):
            continue
        if any("cache_control" in block for block in content_blocks):
            return True

    return False
|
||
|
||
def is_computer_tool_used(
    self, tools: Optional[List[AllAnthropicToolsValues]]
) -> bool:
    """
    Return True when at least one tool has a "type" starting with
    "computer_"; False for None or when no such tool exists.
    """
    if tools is None:
        return False
    return any(
        "type" in tool and tool["type"].startswith("computer_") for tool in tools
    )
|
||
|
||
def translate_system_message(
    self, messages: List[AllMessageValues]
) -> List[AnthropicSystemMessageContent]:
    """
    Pull the system messages out of `messages` in anthropic format.

    Every system message with string content, or with a non-empty list of
    content parts, is converted and then removed from `messages` in place.
    Returns the converted content entries in their original order.
    """
    translated: List[AnthropicSystemMessageContent] = []
    consumed_positions = []

    for position, message in enumerate(messages):
        if message["role"] != "system":
            continue

        system_block = ChatCompletionSystemMessage(**message)
        was_translated: bool = False

        if isinstance(system_block["content"], str):
            entry = AnthropicSystemMessageContent(
                type="text",
                text=system_block["content"],
            )
            if "cache_control" in system_block:
                entry["cache_control"] = system_block["cache_control"]
            translated.append(entry)
            was_translated = True
        elif isinstance(message["content"], list):
            parts = message["content"]
            for part in parts:
                entry = AnthropicSystemMessageContent(
                    type=part.get("type"),
                    text=part.get("text"),
                )
                if "cache_control" in part:
                    entry["cache_control"] = part["cache_control"]
                translated.append(entry)
            # an empty list produces nothing, so the message is left in place
            was_translated = bool(parts)

        if was_translated:
            consumed_positions.append(position)

    # delete from the back so earlier positions stay valid
    for position in reversed(consumed_positions):
        del messages[position]

    return translated
|
||
|
||
def _transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    headers: dict,
    _is_function_call: bool,
    is_vertex_request: bool,
) -> dict:
    """
    Translate messages to anthropic format and assemble the request body.

    Args:
        model: model name; added to the body unless `is_vertex_request`.
        messages: conversation; system messages are extracted from it in
            place by `translate_system_message`.
        optional_params: request parameters; mutated in place (the system
            prompt and any missing config defaults are added) and spread
            into the returned body.
        headers: not used by this method.
        _is_function_call: not used by this method.
        is_vertex_request: when True, "model" is left out of the body.

    Returns:
        The request body: "messages" plus every entry of `optional_params`,
        and "model" for non-vertex requests.

    Raises:
        AnthropicError: (status 400) when the message translation fails.
    """
    # Separate system prompt from rest of message
    anthropic_system_message_list = self.translate_system_message(messages=messages)
    # Handling anthropic API Prompt Caching
    if len(anthropic_system_message_list) > 0:
        optional_params["system"] = anthropic_system_message_list
    # Format rest of message according to anthropic guidelines
    try:
        anthropic_messages = anthropic_messages_pt(
            model=model,
            messages=messages,
            llm_provider="anthropic",
        )
    except Exception as e:
        # don't use verbose_logger.exception, if exception is raised
        raise AnthropicError(
            status_code=400,
            message="{}\nReceived Messages={}".format(str(e), messages),
        ) from e  # keep the original failure attached as the cause

    ## Load Config
    config = litellm.AnthropicConfig.get_config()
    for k, v in config.items():
        if (
            k not in optional_params
        ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
            optional_params[k] = v

    data = {
        "messages": anthropic_messages,
        **optional_params,
    }
    if not is_vertex_request:
        data["model"] = model
    return data
|