forked from phoenix/litellm-mirror
fix(vertex_httpx.py): add function calling support to httpx route
parent 995631bd39
commit c426d75e91
6 changed files with 345 additions and 20 deletions
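
In practice, the new route is exercised as in the test added below. A hedged usage sketch (the model name, tool schema, and `tool_choice` value are taken from that test; Vertex AI credentials and project configuration are assumed to be set up separately):

# Hedged sketch of the feature this commit adds: function calling through the
# httpx-based `vertex_ai_beta` route. Adapted from the test added in this diff.
import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    }
                },
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-pro-preview-0514",  # model name taken from the test below
    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
    tools=tools,
    tool_choice="required",
)
print(response.choices[0].message.tool_calls)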

@@ -767,6 +767,7 @@ from .llms.nlp_cloud import NLPCloudConfig
from .llms.aleph_alpha import AlephAlphaConfig
from .llms.petals import PetalsConfig
from .llms.vertex_ai import VertexAIConfig
from .llms.vertex_httpx import VertexGeminiConfig
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig

@@ -19,6 +19,10 @@ from litellm.types.llms.vertex_ai import (
    PartType,
    RequestBody,
    GenerateContentResponseBody,
    FunctionCallingConfig,
    FunctionDeclaration,
    Tools,
    ToolConfig,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.utils import GenericStreamingChunk

@@ -26,18 +30,203 @@ from litellm.types.llms.openai import (
    ChatCompletionUsageBlock,
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
    ChatCompletionResponseMessage,
)


class VertexGeminiConfig:
    def __init__(self) -> None:
        pass

    """
    Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference

    def supports_system_message(self) -> bool:
    The class `VertexGeminiConfig` provides configuration for Vertex AI's Gemini API interface. Below are the parameters:

    - `temperature` (float): This controls the degree of randomness in token selection.

    - `max_output_tokens` (integer): This sets the limit on the maximum number of tokens in the text output. The default value is 256.

    - `top_p` (float): Tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.

    - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.

    - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.

    - `candidate_count` (int): Number of generated responses to return.

    - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.

    - `frequency_penalty` (float): This parameter is used to penalize the model for repeating the same output. The default value is 0.0.

    - `presence_penalty` (float): This parameter is used to penalize the model for generating output that repeats the input. The default value is 0.0.

    Note: Please make sure to modify the default parameters as required for your use case.
    """

    temperature: Optional[float] = None
    max_output_tokens: Optional[int] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    response_mime_type: Optional[str] = None
    candidate_count: Optional[int] = None
    stop_sequences: Optional[list] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None

    def __init__(
        self,
        temperature: Optional[float] = None,
        max_output_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        response_mime_type: Optional[str] = None,
        candidate_count: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "temperature",
            "top_p",
            "max_tokens",
            "stream",
            "tools",
            "tool_choice",
            "response_format",
            "n",
            "stop",
        ]

    def map_tool_choice_values(
        self, model: str, tool_choice: Union[str, dict]
    ) -> Optional[ToolConfig]:
        if tool_choice == "none":
            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE"))
        elif tool_choice == "required":
            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY"))
        elif tool_choice == "auto":
            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO"))
        elif isinstance(tool_choice, dict):
            # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
            name = tool_choice.get("function", {}).get("name", "")
            return ToolConfig(
                functionCallingConfig=FunctionCallingConfig(
                    mode="ANY", allowed_function_names=[name]
                )
            )
        else:
            raise litellm.utils.UnsupportedParamsError(
                message="VertexAI doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True`.".format(
                    tool_choice
                ),
                status_code=400,
            )

    def map_openai_params(
        self,
        model: str,
        non_default_params: dict,
        optional_params: dict,
    ):
        for param, value in non_default_params.items():
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if (
                param == "stream" and value is True
            ):  # sending stream = False, can cause it to get passed unchecked and raise issues
                optional_params["stream"] = value
            if param == "n":
                optional_params["candidate_count"] = value
            if param == "stop":
                if isinstance(value, str):
                    optional_params["stop_sequences"] = [value]
                elif isinstance(value, list):
                    optional_params["stop_sequences"] = value
            if param == "max_tokens":
                optional_params["max_output_tokens"] = value
            if param == "response_format" and value["type"] == "json_object":  # type: ignore
                optional_params["response_mime_type"] = "application/json"
            if param == "frequency_penalty":
                optional_params["frequency_penalty"] = value
            if param == "presence_penalty":
                optional_params["presence_penalty"] = value
            if param == "tools" and isinstance(value, list):
                gtool_func_declarations = []
                for tool in value:
                    gtool_func_declaration = FunctionDeclaration(
                        name=tool["function"]["name"],
                        description=tool["function"].get("description", ""),
                        parameters=tool["function"].get("parameters", {}),
                    )
                    gtool_func_declarations.append(gtool_func_declaration)
                optional_params["tools"] = [
                    Tools(function_declarations=gtool_func_declarations)
                ]
            if param == "tool_choice" and (
                isinstance(value, str) or isinstance(value, dict)
            ):
                _tool_choice_value = self.map_tool_choice_values(
                    model=model, tool_choice=value  # type: ignore
                )
                if _tool_choice_value is not None:
                    optional_params["tool_choice"] = _tool_choice_value
        return optional_params

    def get_mapped_special_auth_params(self) -> dict:
        """
        Not all gemini models support system instructions
        Common auth params across bedrock/vertex_ai/azure/watsonx
        """
        return True
        return {"project": "vertex_project", "region_name": "vertex_location"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        mapped_params = self.get_mapped_special_auth_params()

        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
        """
        return [
            "europe-central2",
            "europe-north1",
            "europe-southwest1",
            "europe-west1",
            "europe-west2",
            "europe-west3",
            "europe-west4",
            "europe-west6",
            "europe-west8",
            "europe-west9",
        ]


async def make_call(
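
For reference, a hedged sketch of what the `map_tool_choice_values` helper above returns for the supported OpenAI-style `tool_choice` values. The printed shapes follow the `ToolConfig` and `FunctionCallingConfig` TypedDicts added later in this diff; the model name is illustrative:

# Hedged sketch of the tool_choice mapping implemented above (values are plain dicts,
# since ToolConfig / FunctionCallingConfig are TypedDicts). Model name is illustrative.
import litellm

config = litellm.VertexGeminiConfig()

print(config.map_tool_choice_values(model="gemini-1.5-pro-preview-0514", tool_choice="auto"))
# -> {"functionCallingConfig": {"mode": "AUTO"}}

print(config.map_tool_choice_values(model="gemini-1.5-pro-preview-0514", tool_choice="required"))
# -> {"functionCallingConfig": {"mode": "ANY"}}

print(
    config.map_tool_choice_values(
        model="gemini-1.5-pro-preview-0514",
        tool_choice={"type": "function", "function": {"name": "get_weather"}},
    )
)
# -> {"functionCallingConfig": {"mode": "ANY", "allowed_function_names": ["get_weather"]}}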

@@ -165,21 +354,37 @@ class VertexLLM(BaseLLM):
        ## GET MODEL ##
        model_response.model = model
        ## GET TEXT ##
        chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
        content_str = ""
        tools: List[ChatCompletionToolCallChunk] = []
        for idx, candidate in enumerate(completion_response["candidates"]):
            if candidate.get("content", None) is None:
            if "content" not in candidate:
                continue

            message = litellm.Message(
                content=candidate["content"]["parts"][0]["text"],
                role="assistant",
                logprobs=None,
                function_call=None,
                tool_calls=None,
            )
            if "text" in candidate["content"]["parts"][0]:
                content_str = candidate["content"]["parts"][0]["text"]

            if "functionCall" in candidate["content"]["parts"][0]:
                _function_chunk = ChatCompletionToolCallFunctionChunk(
                    name=candidate["content"]["parts"][0]["functionCall"]["name"],
                    arguments=json.dumps(
                        candidate["content"]["parts"][0]["functionCall"]["args"]
                    ),
                )
                _tool_response_chunk = ChatCompletionToolCallChunk(
                    id=f"call_{str(uuid.uuid4())}",
                    type="function",
                    function=_function_chunk,
                )
                tools.append(_tool_response_chunk)

            chat_completion_message["content"] = content_str
            chat_completion_message["tool_calls"] = tools

            choice = litellm.Choices(
                finish_reason=candidate.get("finishReason", "stop"),
                index=candidate.get("index", idx),
                message=message,
                message=chat_completion_message,  # type: ignore
                logprobs=None,
                enhancements=None,
            )
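
To make the conversion above concrete, an illustrative sketch (not a captured response) of how a Gemini candidate carrying a `functionCall` part is mapped to an OpenAI-style assistant message with `tool_calls`:

# Illustrative only: a hand-written candidate in the shape the loop above consumes,
# converted the same way (functionCall args are JSON-serialized into a string).
import json
import uuid

candidate = {
    "content": {
        "role": "model",
        "parts": [
            {"functionCall": {"name": "get_weather", "args": {"location": "San Francisco, CA"}}}
        ],
    },
    "finishReason": "STOP",
    "index": 0,
}

part = candidate["content"]["parts"][0]
tool_call = {
    "id": f"call_{uuid.uuid4()}",
    "type": "function",
    "function": {
        "name": part["functionCall"]["name"],
        "arguments": json.dumps(part["functionCall"]["args"]),  # dict -> JSON string
    },
}
message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
print(message)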

@@ -402,8 +607,14 @@ class VertexLLM(BaseLLM):
                messages.pop(idx)
        system_instructions = SystemInstructions(parts=system_content_blocks)
        content = _gemini_convert_messages_with_history(messages=messages)
        tools: Optional[Tools] = optional_params.pop("tools", None)
        tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)

        data = RequestBody(system_instruction=system_instructions, contents=content)
        if tools is not None:
            data["tools"] = tools
        if tool_choice is not None:
            data["toolConfig"] = tool_choice
        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {auth_header}",

@@ -615,9 +615,76 @@ def test_gemini_pro_vision_base64():
        pytest.fail(f"An exception occurred - {str(e)}")


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
@pytest.mark.parametrize("sync_mode", [True, False])  # "vertex_ai",
@pytest.mark.asyncio
async def test_gemini_pro_function_calling(sync_mode):
async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
    try:
        load_vertex_ai_credentials()
        litellm.set_verbose = True

        messages = [
            {
                "role": "system",
                "content": "Your name is Litellm Bot, you are a helpful assistant",
            },
            # User asks for their name and weather in San Francisco
            {
                "role": "user",
                "content": "Hello, what is your name and can you tell me the weather?",
            },
        ]

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather in a given location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            }
                        },
                        "required": ["location"],
                    },
                },
            }
        ]

        data = {
            "model": "{}/gemini-1.5-pro-preview-0514".format(provider),
            "messages": messages,
            "tools": tools,
            "tool_choice": "required",
        }
        if sync_mode:
            response = litellm.completion(**data)
        else:
            response = await litellm.acompletion(**data)

        print(f"response: {response}")

        assert response.choices[0].message.tool_calls[0].function.arguments is not None
        assert isinstance(
            response.choices[0].message.tool_calls[0].function.arguments, str
        )
    except litellm.RateLimitError as e:
        pass
    except Exception as e:
        if "429 Quota exceeded" in str(e):
            pass
        else:
            pytest.fail("An unexpected exception occurred - {}".format(str(e)))


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize("provider", ["vertex_ai"])
@pytest.mark.asyncio
async def test_gemini_pro_function_calling(provider, sync_mode):
    try:
        load_vertex_ai_credentials()
        litellm.set_verbose = True

@@ -679,7 +746,7 @@ async def test_gemini_pro_function_calling(sync_mode):
        ]

        data = {
            "model": "vertex_ai/gemini-1.5-pro-preview-0514",
            "model": "{}/gemini-1.5-pro-preview-0514".format(provider),
            "messages": messages,
            "tools": tools,
        }

@@ -49,6 +49,24 @@ class PartType(TypedDict, total=False):
    function_response: FunctionResponse


class HttpxFunctionCall(TypedDict):
    name: str
    args: dict


class HttpxPartType(TypedDict, total=False):
    text: str
    inline_data: BlobType
    file_data: FileDataType
    functionCall: HttpxFunctionCall
    function_response: FunctionResponse


class HttpxContentType(TypedDict, total=False):
    role: Literal["user", "model"]
    parts: Required[List[HttpxPartType]]


class ContentType(TypedDict, total=False):
    role: Literal["user", "model"]
    parts: Required[List[PartType]]

@@ -128,11 +146,19 @@ class GenerationConfig(TypedDict, total=False):
    response_mime_type: Literal["text/plain", "application/json"]


class Tools(TypedDict):
    function_declarations: List[FunctionDeclaration]


class ToolConfig(TypedDict):
    functionCallingConfig: FunctionCallingConfig


class RequestBody(TypedDict, total=False):
    contents: Required[List[ContentType]]
    system_instruction: SystemInstructions
    tools: FunctionDeclaration
    tool_config: FunctionCallingConfig
    tools: Tools
    toolConfig: ToolConfig
    safety_settings: SafetSettingsConfig
    generation_config: GenerationConfig
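
A hedged sketch of how these TypedDicts compose into the request payload built in the httpx route above. Field names come from this diff; the concrete values are illustrative, and note that the runtime code assigns a list of `Tools` to `data["tools"]` while the TypedDict declares a single `Tools` value, so this sketch follows the TypedDict:

# Hedged sketch: composing a Gemini request payload from the TypedDicts above.
# Values are illustrative only.
from litellm.types.llms.vertex_ai import (
    FunctionCallingConfig,
    FunctionDeclaration,
    RequestBody,
    ToolConfig,
    Tools,
)

contents = [
    {"role": "user", "parts": [{"text": "What is the weather in Boston?"}]}
]
tools = Tools(
    function_declarations=[
        FunctionDeclaration(
            name="get_weather",
            description="Get the current weather in a given location",
            parameters={
                "type": "object",
                "properties": {"location": {"type": "string"}},
            },
        )
    ]
)
data = RequestBody(contents=contents, tools=tools)  # single Tools entry, per the TypedDict
data["toolConfig"] = ToolConfig(
    functionCallingConfig=FunctionCallingConfig(mode="ANY")
)
print(data)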

@@ -176,7 +202,7 @@ class GroundingMetadata(TypedDict, total=False):

class Candidates(TypedDict, total=False):
    index: int
    content: ContentType
    content: HttpxContentType
    finishReason: Literal[
        "FINISH_REASON_UNSPECIFIED",
        "STOP",

@@ -5386,6 +5386,16 @@ def get_optional_params(
        print_verbose(
            f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
        )
    elif custom_llm_provider == "vertex_ai_beta":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.VertexGeminiConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )
    elif (
        custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
    ):
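
Finally, a hedged sketch of the parameter translation this new `vertex_ai_beta` branch performs, based on `VertexGeminiConfig.map_openai_params` above; the expected output in the comment is derived from reading that mapping logic, not from a captured run, and the model name is illustrative:

# Hedged sketch of the OpenAI -> Vertex parameter mapping used by the vertex_ai_beta branch.
import litellm

optional_params = litellm.VertexGeminiConfig().map_openai_params(
    model="gemini-1.5-pro-preview-0514",
    non_default_params={
        "temperature": 0.2,
        "max_tokens": 256,
        "stop": "END",
        "n": 2,
        "response_format": {"type": "json_object"},
    },
    optional_params={},
)
# Expected shape per the mapping above:
# {
#     "temperature": 0.2,
#     "max_output_tokens": 256,
#     "stop_sequences": ["END"],
#     "candidate_count": 2,
#     "response_mime_type": "application/json",
# }
print(optional_params)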

log.txt (new file, 10 lines)
@@ -0,0 +1,10 @@
============================= test session starts ==============================
platform darwin -- Python 3.11.4, pytest-8.2.0, pluggy-1.5.0 -- /Users/krrishdholakia/Documents/litellm/litellm/proxy/myenv/bin/python3.11
cachedir: .pytest_cache
rootdir: /Users/krrishdholakia/Documents/litellm
configfile: pyproject.toml
plugins: logfire-0.35.0, asyncio-0.23.6, mock-3.14.0, anyio-4.2.0
asyncio: mode=Mode.STRICT
collecting ... collected 0 items

============================ no tests ran in 0.00s =============================