diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index cf593369c4..dad4dbd8d5 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -12,6 +12,7 @@ from typing import ( Sequence, ) import litellm +import litellm.types from litellm.types.completion import ( ChatCompletionUserMessageParam, ChatCompletionSystemMessageParam, @@ -20,9 +21,12 @@ from litellm.types.completion import ( ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam, ) +import litellm.types.llms from litellm.types.llms.anthropic import * import uuid +import litellm.types.llms.vertex_ai + def default_pt(messages): return " ".join(message["content"] for message in messages) @@ -841,6 +845,175 @@ def anthropic_messages_pt_xml(messages: list): # ------------------------------------------------------------------------------ +def infer_protocol_value( + value: Any, +) -> Literal[ + "string_value", + "number_value", + "bool_value", + "struct_value", + "list_value", + "null_value", + "unknown", +]: + if value is None: + return "null_value" + if isinstance(value, int) or isinstance(value, float): + return "number_value" + if isinstance(value, str): + return "string_value" + if isinstance(value, bool): + return "bool_value" + if isinstance(value, dict): + return "struct_value" + if isinstance(value, list): + return "list_value" + + return "unknown" + + +def convert_to_gemini_tool_call_invoke( + tool_calls: list, +) -> List[litellm.types.llms.vertex_ai.PartType]: + """ + OpenAI tool invokes: + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\n\"location\": \"Boston, MA\"\n}" + } + } + ] + }, + """ + """ + Gemini tool call invokes: - https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#submit-api-output + content { + role: "model" + parts [ + { + function_call { + name: "get_current_weather" + args { + fields { + key: "unit" + value { + string_value: "fahrenheit" + } + } + fields { + key: "predicted_temperature" + value { + number_value: 45 + } + } + fields { + key: "location" + value { + string_value: "Boston, MA" + } + } + } + }, + { + function_call { + name: "get_current_weather" + args { + fields { + key: "location" + value { + string_value: "San Francisco" + } + } + } + } + } + ] + } + """ + + """ + - json.load the arguments + - iterate through arguments -> create a FunctionCallArgs for each field + """ + try: + _parts_list: List[litellm.types.llms.vertex_ai.PartType] = [] + for tool in tool_calls: + if "function" in tool: + name = tool["function"].get("name", "") + arguments = tool["function"].get("arguments", "") + arguments_dict = json.loads(arguments) + for k, v in arguments_dict.items(): + inferred_protocol_value = infer_protocol_value(value=v) + _field = litellm.types.llms.vertex_ai.Field( + key=k, value={inferred_protocol_value: v} + ) + _fields = litellm.types.llms.vertex_ai.FunctionCallArgs( + fields=_field + ) + function_call = litellm.types.llms.vertex_ai.FunctionCall( + name=name, + args=_fields, + ) + _parts_list.append( + litellm.types.llms.vertex_ai.PartType(function_call=function_call) + ) + return _parts_list + except Exception as e: + raise Exception( + "Unable to convert openai tool calls={} to gemini tool calls. 
Received error={}".format(
+                tool_calls, str(e)
+            )
+        )
+
+
+def convert_to_gemini_tool_call_result(
+    message: dict,
+) -> litellm.types.llms.vertex_ai.PartType:
+    """
+    OpenAI message with a tool result looks like:
+    {
+        "tool_call_id": "tool_1",
+        "role": "tool",
+        "name": "get_current_weather",
+        "content": "function result goes here",
+    },
+
+    OpenAI message with a function call result looks like:
+    {
+        "role": "function",
+        "name": "get_current_weather",
+        "content": "function result goes here",
+    }
+    """
+    content = message.get("content", "")
+    name = message.get("name", "")
+
+    # We can't determine from openai message format whether it's a successful or
+    # error call result so default to the successful result template
+    inferred_content_value = infer_protocol_value(value=content)
+
+    _field = litellm.types.llms.vertex_ai.Field(
+        key="content", value={inferred_content_value: content}
+    )
+
+    _function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
+
+    _function_response = litellm.types.llms.vertex_ai.FunctionResponse(
+        name=name, response=_function_call_args
+    )
+
+    _part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)
+
+    return _part
+
+
 def convert_to_anthropic_tool_result(message: dict) -> dict:
     """
     OpenAI message with a tool result looks like:
@@ -1513,7 +1686,7 @@ def prompt_factory(
     elif custom_llm_provider == "clarifai":
         if "claude" in model:
             return anthropic_pt(messages=messages)
-    
+
     elif custom_llm_provider == "perplexity":
         for message in messages:
             message.pop("name", None)
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 84fec734fd..58c512d56e 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -3,10 +3,15 @@ import json
 from enum import Enum
 import requests  # type: ignore
 import time
-from typing import Callable, Optional, Union, List
+from typing import Callable, Optional, Union, List, Literal
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
 import httpx, inspect  # type: ignore
+from litellm.types.llms.vertex_ai import *
+from litellm.llms.prompt_templates.factory import (
+    convert_to_gemini_tool_call_result,
+    convert_to_gemini_tool_call_invoke,
+)


 class VertexAIError(Exception):
@@ -283,6 +288,129 @@ def _load_image_from_url(image_url: str):
     return Image.from_bytes(data=image_bytes)


+def _convert_gemini_role(role: str) -> Literal["user", "model"]:
+    if role == "user":
+        return "user"
+    else:
+        return "model"
+
+
+def _process_gemini_image(image_url: str):
+    try:
+        import vertexai
+    except:
+        raise VertexAIError(
+            status_code=400,
+            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
+        )
+    from vertexai.preview.generative_models import Part
+
+    if "gs://" in image_url:
+        # Case 1: Images with Cloud Storage URIs
+        # The supported MIME types for images include image/png and image/jpeg.
+ part_mime = "image/png" if "png" in image_url else "image/jpeg" + google_clooud_part = Part.from_uri(image_url, mime_type=part_mime) + return google_clooud_part + elif "https:/" in image_url: + # Case 2: Images with direct links + image = _load_image_from_url(image_url) + return image + elif ".mp4" in image_url and "gs://" in image_url: + # Case 3: Videos with Cloud Storage URIs + part_mime = "video/mp4" + google_clooud_part = Part.from_uri(image_url, mime_type=part_mime) + return google_clooud_part + elif "base64" in image_url: + # Case 4: Images with base64 encoding + import base64, re + + # base 64 is passed as data:image/jpeg;base64, + image_metadata, img_without_base_64 = image_url.split(",") + + # read mime_type from img_without_base_64=data:image/jpeg;base64 + # Extract MIME type using regular expression + mime_type_match = re.match(r"data:(.*?);base64", image_metadata) + + if mime_type_match: + mime_type = mime_type_match.group(1) + else: + mime_type = "image/jpeg" + decoded_img = base64.b64decode(img_without_base_64) + processed_image = Part.from_data(data=decoded_img, mime_type=mime_type) + return processed_image + + +def _gemini_convert_messages_text(messages: list) -> List[ContentType]: + """ + Converts given messages from OpenAI format to Gemini format + + - Parts must be iterable + - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles) + - Please ensure that function response turn comes immediately after a function call turn + """ + user_message_types = {"user", "system"} + contents: List[ContentType] = [] + + msg_i = 0 + while msg_i < len(messages): + user_content: List[PartType] = [] + init_msg_i = msg_i + ## MERGE CONSECUTIVE USER CONTENT ## + while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types: + if isinstance(messages[msg_i]["content"], list): + _parts: List[PartType] = [] + for element in messages[msg_i]["content"]: + if isinstance(element, dict): + if element["type"] == "text": + _part = PartType(text=element["text"]) + _parts.append(_part) + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + _part = _process_gemini_image(image_url=image_url) + _parts.append(_part) # type: ignore + user_content.extend(_parts) + else: + _part = PartType(text=messages[msg_i]["content"]) + user_content.append(_part) + + msg_i += 1 + + if user_content: + contents.append(ContentType(role="user", parts=user_content)) + assistant_content = [] + ## MERGE CONSECUTIVE ASSISTANT CONTENT ## + while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": + assistant_text = ( + messages[msg_i].get("content") or "" + ) # either string or none + if assistant_text: + assistant_content.append(PartType(text=assistant_text)) + if messages[msg_i].get( + "tool_calls", [] + ): # support assistant tool invoke convertion + assistant_content.extend( + convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"]) + ) + msg_i += 1 + + if assistant_content: + contents.append(ContentType(role="model", parts=assistant_content)) + + ## APPEND TOOL CALL MESSAGES ## + if messages[msg_i]["role"] == "tool": + _part = convert_to_gemini_tool_call_result(messages[msg_i]) + contents.append(ContentType(parts=[_part])) # type: ignore + msg_i += 1 + if msg_i == init_msg_i: # prevent infinite loops + raise Exception( + "Invalid Message passed in - {}. 
File an issue https://github.com/BerriAI/litellm/issues".format( + messages[msg_i] + ) + ) + + return contents + + def _gemini_vision_convert_messages(messages: list): """ Converts given messages for GPT-4 Vision to Gemini format. @@ -574,11 +702,9 @@ def completion( print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call") print_verbose(f"\nProcessing input messages = {messages}") tools = optional_params.pop("tools", None) - prompt, images = _gemini_vision_convert_messages(messages=messages) - content = [prompt] + images + content = _gemini_convert_messages_text(messages=messages) stream = optional_params.pop("stream", False) if stream == True: - request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" logging_obj.pre_call( input=prompt, @@ -590,7 +716,7 @@ def completion( ) model_response = llm_model.generate_content( - contents=content, + contents={"content": content}, generation_config=optional_params, safety_settings=safety_settings, stream=True, diff --git a/litellm/tests/log.txt b/litellm/tests/log.txt index fd9557c9b5..6966d252de 100644 --- a/litellm/tests/log.txt +++ b/litellm/tests/log.txt @@ -1,15 +1,21 @@ ============================= test session starts ============================== -platform darwin -- Python 3.11.9, pytest-7.3.1, pluggy-1.3.0 -rootdir: /Users/krrishdholakia/Documents/litellm/litellm/tests -plugins: timeout-2.2.0, asyncio-0.23.2, anyio-3.7.1, xdist-3.3.1 +platform darwin -- Python 3.11.4, pytest-8.2.0, pluggy-1.5.0 +rootdir: /Users/krrishdholakia/Documents/litellm +configfile: pyproject.toml +plugins: asyncio-0.23.6, mock-3.14.0, anyio-4.2.0 asyncio: mode=Mode.STRICT collected 1 item -test_router_timeout.py . [100%] +test_amazing_vertex_completion.py Success: model=gemini-1.5-pro-preview-0514 in model_cost_map +prompt_tokens=88; completion_tokens=31 +Returned custom cost for model=gemini-1.5-pro-preview-0514 - prompt_tokens_cost_usd_dollar: 5.5e-05, completion_tokens_cost_usd_dollar: 5.8125e-05 +final cost: 0.000113125; prompt_tokens_cost_usd_dollar: 5.5e-05; completion_tokens_cost_usd_dollar: 5.8125e-05 +success callbacks: [] +. [100%] =============================== warnings summary =============================== -../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 25 warnings - /opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 25 warnings + /Users/krrishdholakia/Documents/litellm/litellm/proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning) ../proxy/_types.py:255 @@ -68,9 +74,5 @@ test_router_timeout.py . [100%] /Users/krrishdholakia/Documents/litellm/litellm/utils.py:60: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice. 
with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") as f: -test_router_timeout.py::test_router_timeouts_bedrock - /opt/homebrew/lib/python3.11/site-packages/httpx/_content.py:204: DeprecationWarning: Use 'content=<...>' to upload raw bytes/text content. - warnings.warn(message, DeprecationWarning) - -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html -======================== 1 passed, 40 warnings in 0.99s ======================== +======================== 1 passed, 39 warnings in 2.92s ======================== diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 06215d0d1a..518b5537fb 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -16,6 +16,7 @@ from litellm.tests.test_streaming import streaming_format_tests import json import os import tempfile +from litellm.llms.vertex_ai import _gemini_convert_messages_text litellm.num_retries = 3 litellm.cache = None @@ -98,7 +99,7 @@ def load_vertex_ai_credentials(): @pytest.mark.asyncio -async def get_response(): +async def test_get_response(): load_vertex_ai_credentials() prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n' try: @@ -589,35 +590,73 @@ def test_gemini_pro_vision_base64(): pytest.fail(f"An exception occurred - {str(e)}") -@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.parametrize("sync_mode", [True]) @pytest.mark.asyncio async def test_gemini_pro_function_calling(sync_mode): try: load_vertex_ai_credentials() - data = { - "model": "vertex_ai/gemini-pro", - "messages": [ - { - "role": "user", - "content": "Call the submit_cities function with San Francisco and New York", - } - ], - "tools": [ - { - "type": "function", - "function": { - "name": "submit_cities", - "description": "Submits a list of cities", - "parameters": { - "type": "object", - "properties": { - "cities": {"type": "array", "items": {"type": "string"}} - }, - "required": ["cities"], + litellm.set_verbose = True + + messages = [ + { + "role": "system", + "content": "Your name is Litellm Bot, you are a helpful assistant", + }, + # User asks for their name and weather in San Francisco + { + "role": "user", + "content": "Hello, what is your name and can you tell me the weather?", + }, + # Assistant replies with a tool call + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "index": 0, + "function": { + "name": "get_weather", + "arguments": '{"location":"San Francisco, CA"}', }, + } + ], + }, + # The result of the tool call is added to the history + { + "role": "tool", + "tool_call_id": "call_123", + "name": "get_weather", + "content": "27 degrees celsius and clear in San Francisco, CA", + }, + # Now the assistant can reply with the result of the tool call. + ] + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + } + }, + "required": ["location"], }, - } - ], + }, + } + ] + + data = { + "model": "vertex_ai/gemini-1.5-pro-preview-0514", + "messages": messages, + "tools": tools, } if sync_mode: response = litellm.completion(**data) @@ -712,7 +751,7 @@ async def test_gemini_pro_async_function_calling(): "type": "function", "function": { "name": "get_current_weather", - "description": "Get the current weather in a given location", + "description": "Get the current weather in a given location. Response with a prediction on temperature as well.", "parameters": { "type": "object", "properties": { @@ -724,8 +763,9 @@ async def test_gemini_pro_async_function_calling(): "type": "string", "enum": ["celsius", "fahrenheit"], }, + "predicted_temperature": {"type": "integer"}, }, - "required": ["location"], + "required": ["location", "predicted_temperature"], }, }, } @@ -733,7 +773,7 @@ async def test_gemini_pro_async_function_calling(): messages = [ { "role": "user", - "content": "What's the weather like in Boston today in fahrenheit?", + "content": "What's the weather like in Boston today in fahrenheit, give me a prediction?", } ] completion = await litellm.acompletion( @@ -742,8 +782,10 @@ async def test_gemini_pro_async_function_calling(): print(f"completion: {completion}") assert completion.choices[0].message.content is None assert len(completion.choices[0].message.tool_calls) == 1 - except litellm.APIError as e: - pass + + raise Exception + # except litellm.APIError as e: + # pass except litellm.RateLimitError as e: pass except Exception as e: @@ -893,3 +935,46 @@ async def test_vertexai_aembedding(): # traceback.print_exc() # raise e # test_gemini_pro_vision_async() + + +def test_prompt_factory(): + messages = [ + { + "role": "system", + "content": "Your name is Litellm Bot, you are a helpful assistant", + }, + # User asks for their name and weather in San Francisco + { + "role": "user", + "content": "Hello, what is your name and can you tell me the weather?", + }, + # Assistant replies with a tool call + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "index": 0, + "function": { + "name": "get_weather", + "arguments": '{"location":"San Francisco, CA"}', + }, + } + ], + }, + # The result of the tool call is added to the history + { + "role": "tool", + "tool_call_id": "call_123", + "name": "get_weather", + "content": "27 degrees celsius and clear in San Francisco, CA", + }, + # Now the assistant can reply with the result of the tool call. 
+    ]
+
+    translated_messages = _gemini_convert_messages_text(messages=messages)
+
+    print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")
+    raise Exception
diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py
new file mode 100644
index 0000000000..cced185de0
--- /dev/null
+++ b/litellm/types/llms/vertex_ai.py
@@ -0,0 +1,53 @@
+from typing import TypedDict, Any, Union, Optional, List, Literal
+import json
+from typing_extensions import (
+    Self,
+    Protocol,
+    TypeGuard,
+    override,
+    get_origin,
+    runtime_checkable,
+    Required,
+)
+
+
+class Field(TypedDict):
+    key: str
+    value: dict[str, Any]
+
+
+class FunctionCallArgs(TypedDict):
+    fields: Field
+
+
+class FunctionResponse(TypedDict):
+    name: str
+    response: FunctionCallArgs
+
+
+class FunctionCall(TypedDict):
+    name: str
+    args: FunctionCallArgs
+
+
+class FileDataType(TypedDict):
+    mime_type: str
+    file_uri: str  # the cloud storage uri of storing this file
+
+
+class BlobType(TypedDict):
+    mime_type: Required[str]
+    data: Required[bytes]
+
+
+class PartType(TypedDict, total=False):
+    text: str
+    inline_data: BlobType
+    file_data: FileDataType
+    function_call: FunctionCall
+    function_response: FunctionResponse
+
+
+class ContentType(TypedDict, total=False):
+    role: Literal["user", "model"]
+    parts: Required[List[PartType]]
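Example usage (illustrative only, not part of this patch): a minimal sketch of how the converters added in litellm/llms/prompt_templates/factory.py translate an OpenAI-style assistant tool call and its follow-up "tool" message into Gemini parts. It assumes this branch of litellm is importable locally; the message shapes mirror the ones used in test_gemini_pro_function_calling and test_prompt_factory above.

    from litellm.llms.prompt_templates.factory import (
        convert_to_gemini_tool_call_invoke,
        convert_to_gemini_tool_call_result,
    )

    # OpenAI-style assistant tool call, as in the test messages above
    tool_calls = [
        {
            "id": "call_123",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"location": "San Francisco, CA"}',
            },
        }
    ]

    # Each tool call becomes a PartType dict carrying a Gemini function_call
    invoke_parts = convert_to_gemini_tool_call_invoke(tool_calls)
    print(invoke_parts)

    # The follow-up "tool" message becomes a PartType carrying a function_response
    tool_result = {
        "role": "tool",
        "tool_call_id": "call_123",
        "name": "get_weather",
        "content": "27 degrees celsius and clear in San Francisco, CA",
    }
    print(convert_to_gemini_tool_call_result(tool_result))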