fix(vertex_httpx.py): add function calling support to httpx route

This commit is contained in:
Krrish Dholakia 2024-06-12 21:11:00 -07:00
parent afebf867f6
commit e60b0e96e4
6 changed files with 345 additions and 20 deletions

View file

@ -19,6 +19,10 @@ from litellm.types.llms.vertex_ai import (
PartType,
RequestBody,
GenerateContentResponseBody,
FunctionCallingConfig,
FunctionDeclaration,
Tools,
ToolConfig,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.utils import GenericStreamingChunk
@ -26,18 +30,203 @@ from litellm.types.llms.openai import (
ChatCompletionUsageBlock,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionResponseMessage,
)
class VertexGeminiConfig:
def __init__(self) -> None:
pass
"""
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
def supports_system_message(self) -> bool:
The class `VertexAIConfig` provides configuration for the VertexAI's API interface. Below are the parameters:
- `temperature` (float): This controls the degree of randomness in token selection.
- `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256.
- `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
- `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
- `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
- `candidate_count` (int): Number of generated responses to return.
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
- `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0.
- `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0.
Note: Please make sure to modify the default parameters as required for your use case.
"""
temperature: Optional[float] = None
max_output_tokens: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
response_mime_type: Optional[str] = None
candidate_count: Optional[int] = None
stop_sequences: Optional[list] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
def __init__(
self,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
response_mime_type: Optional[str] = None,
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"temperature",
"top_p",
"max_tokens",
"stream",
"tools",
"tool_choice",
"response_format",
"n",
"stop",
]
def map_tool_choice_values(
self, model: str, tool_choice: Union[str, dict]
) -> Optional[ToolConfig]:
if tool_choice == "none":
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE"))
elif tool_choice == "required":
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY"))
elif tool_choice == "auto":
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO"))
elif isinstance(tool_choice, dict):
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
name = tool_choice.get("function", {}).get("name", "")
return ToolConfig(
functionCallingConfig=FunctionCallingConfig(
mode="ANY", allowed_function_names=[name]
)
)
else:
raise litellm.utils.UnsupportedParamsError(
message="VertexAI doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
def map_openai_params(
self,
model: str,
non_default_params: dict,
optional_params: dict,
):
for param, value in non_default_params.items():
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if (
param == "stream" and value is True
): # sending stream = False, can cause it to get passed unchecked and raise issues
optional_params["stream"] = value
if param == "n":
optional_params["candidate_count"] = value
if param == "stop":
if isinstance(value, str):
optional_params["stop_sequences"] = [value]
elif isinstance(value, list):
optional_params["stop_sequences"] = value
if param == "max_tokens":
optional_params["max_output_tokens"] = value
if param == "response_format" and value["type"] == "json_object": # type: ignore
optional_params["response_mime_type"] = "application/json"
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "tools" and isinstance(value, list):
gtool_func_declarations = []
for tool in value:
gtool_func_declaration = FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}),
)
gtool_func_declarations.append(gtool_func_declaration)
optional_params["tools"] = [
Tools(function_declarations=gtool_func_declarations)
]
if param == "tool_choice" and (
isinstance(value, str) or isinstance(value, dict)
):
_tool_choice_value = self.map_tool_choice_values(
model=model, tool_choice=value # type: ignore
)
if _tool_choice_value is not None:
optional_params["tool_choice"] = _tool_choice_value
return optional_params
def get_mapped_special_auth_params(self) -> dict:
"""
Not all gemini models support system instructions
Common auth params across bedrock/vertex_ai/azure/watsonx
"""
return True
return {"project": "vertex_project", "region_name": "vertex_location"}
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
mapped_params = self.get_mapped_special_auth_params()
for param, value in non_default_params.items():
if param in mapped_params:
optional_params[mapped_params[param]] = value
return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
"""
return [
"europe-central2",
"europe-north1",
"europe-southwest1",
"europe-west1",
"europe-west2",
"europe-west3",
"europe-west4",
"europe-west6",
"europe-west8",
"europe-west9",
]
async def make_call(
@ -165,21 +354,37 @@ class VertexLLM(BaseLLM):
## GET MODEL ##
model_response.model = model
## GET TEXT ##
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
for idx, candidate in enumerate(completion_response["candidates"]):
if candidate.get("content", None) is None:
if "content" not in candidate:
continue
message = litellm.Message(
content=candidate["content"]["parts"][0]["text"],
role="assistant",
logprobs=None,
function_call=None,
tool_calls=None,
)
if "text" in candidate["content"]["parts"][0]:
content_str = candidate["content"]["parts"][0]["text"]
if "functionCall" in candidate["content"]["parts"][0]:
_function_chunk = ChatCompletionToolCallFunctionChunk(
name=candidate["content"]["parts"][0]["functionCall"]["name"],
arguments=json.dumps(
candidate["content"]["parts"][0]["functionCall"]["args"]
),
)
_tool_response_chunk = ChatCompletionToolCallChunk(
id=f"call_{str(uuid.uuid4())}",
type="function",
function=_function_chunk,
)
tools.append(_tool_response_chunk)
chat_completion_message["content"] = content_str
chat_completion_message["tool_calls"] = tools
choice = litellm.Choices(
finish_reason=candidate.get("finishReason", "stop"),
index=candidate.get("index", idx),
message=message,
message=chat_completion_message, # type: ignore
logprobs=None,
enhancements=None,
)
@ -402,8 +607,14 @@ class VertexLLM(BaseLLM):
messages.pop(idx)
system_instructions = SystemInstructions(parts=system_content_blocks)
content = _gemini_convert_messages_with_history(messages=messages)
tools: Optional[Tools] = optional_params.pop("tools", None)
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
data = RequestBody(system_instruction=system_instructions, contents=content)
if tools is not None:
data["tools"] = tools
if tool_choice is not None:
data["toolConfig"] = tool_choice
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",