From f79365df6e1796f52929a165f53085f43710e31c Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Fri, 1 Nov 2024 23:14:32 +0400 Subject: [PATCH] LiteLLM Minor Fixes & Improvements (10/30/2024) (#6519) * refactor: move gemini translation logic inside the transformation.py file easier to isolate the gemini translation logic * fix(gemini-transformation): support multiple tool calls in message body Merges https://github.com/BerriAI/litellm/pull/6487/files * test(test_vertex.py): add remaining tests from https://github.com/BerriAI/litellm/pull/6487 * fix(gemini-transformation): return tool calls for multiple tool calls * fix: support passing logprobs param for vertex + gemini * feat(vertex_ai): add logprobs support for gemini calls * fix(anthropic/chat/transformation.py): fix disable parallel tool use flag * fix: fix linting error * fix(_logging.py): log stacktrace information in json logs Closes https://github.com/BerriAI/litellm/issues/6497 * fix(utils.py): fix mem leak for async stream + completion Uses a global executor pool instead of creating a new thread on each request Fixes https://github.com/BerriAI/litellm/issues/6404 * fix(factory.py): handle tool call + content in assistant message for bedrock * fix: fix import * fix(factory.py): maintain support for content as a str in assistant response * fix: fix import * test: cleanup test * fix(vertex_and_google_ai_studio/): return none for content if no str value * test: retry flaky tests * (UI) Fix viewing members, keys in a team + added testing (#6514) * fix listing teams on ui * LiteLLM Minor Fixes & Improvements (10/28/2024) (#6475) * fix(anthropic/chat/transformation.py): support anthropic disable_parallel_tool_use param Fixes https://github.com/BerriAI/litellm/issues/6456 * feat(anthropic/chat/transformation.py): support anthropic computer tool use Closes https://github.com/BerriAI/litellm/issues/6427 * fix(vertex_ai/common_utils.py): parse out '$schema' when calling vertex ai Fixes issue when trying to call vertex from vercel sdk * fix(main.py): add 'extra_headers' support for azure on all translation endpoints Fixes https://github.com/BerriAI/litellm/issues/6465 * fix: fix linting errors * fix(transformation.py): handle no beta headers for anthropic * test: cleanup test * fix: fix linting error * fix: fix linting errors * fix: fix linting errors * fix(transformation.py): handle dummy tool call * fix(main.py): fix linting error * fix(azure.py): pass required param * LiteLLM Minor Fixes & Improvements (10/24/2024) (#6441) * fix(azure.py): handle /openai/deployment in azure api base * fix(factory.py): fix faulty anthropic tool result translation check Fixes https://github.com/BerriAI/litellm/issues/6422 * fix(gpt_transformation.py): add support for parallel_tool_calls to azure Fixes https://github.com/BerriAI/litellm/issues/6440 * fix(factory.py): support anthropic prompt caching for tool results * fix(vertex_ai/common_utils): don't pop non-null required field Fixes https://github.com/BerriAI/litellm/issues/6426 * feat(vertex_ai.py): support code_execution tool call for vertex ai + gemini Closes https://github.com/BerriAI/litellm/issues/6434 * build(model_prices_and_context_window.json): Add 'supports_assistant_prefill' for bedrock claude-3-5-sonnet v2 models Closes https://github.com/BerriAI/litellm/issues/6437 * fix(types/utils.py): fix linting * test: update test to include required fields * test: fix test * test: handle flaky test * test: remove e2e test - hitting gemini rate limits * Litellm dev 10 26 2024 (#6472) * 
docs(exception_mapping.md): add missing exception types Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183 * fix(main.py): register custom model pricing with specific key Ensure custom model pricing is registered to the specific model+provider key combination * test: make testing more robust for custom pricing * fix(redis_cache.py): instrument otel logging for sync redis calls ensures complete coverage for all redis cache calls * (Testing) Add unit testing for DualCache - ensure in memory cache is used when expected (#6471) * test test_dual_cache_get_set * unit testing for dual cache * fix async_set_cache_sadd * test_dual_cache_local_only * redis otel tracing + async support for latency routing (#6452) * docs(exception_mapping.md): add missing exception types Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183 * fix(main.py): register custom model pricing with specific key Ensure custom model pricing is registered to the specific model+provider key combination * test: make testing more robust for custom pricing * fix(redis_cache.py): instrument otel logging for sync redis calls ensures complete coverage for all redis cache calls * refactor: pass parent_otel_span for redis caching calls in router allows for more observability into what calls are causing latency issues * test: update tests with new params * refactor: ensure e2e otel tracing for router * refactor(router.py): add more otel tracing acrosss router catch all latency issues for router requests * fix: fix linting error * fix(router.py): fix linting error * fix: fix test * test: fix tests * fix(dual_cache.py): pass ttl to redis cache * fix: fix param * fix(dual_cache.py): set default value for parent_otel_span * fix(transformation.py): support 'response_format' for anthropic calls * fix(transformation.py): check for cache_control inside 'function' block * fix: fix linting error * fix: fix linting errors --------- Co-authored-by: Ishaan Jaff --------- Co-authored-by: Krish Dholakia * ui new build * Add retry strat (#6520) Signed-off-by: dbczumar * (fix) slack alerting - don't spam the failed cost tracking alert for the same model (#6543) * fix use failing_model as cache key for failed_tracking_alert * fix use standard logging payload for getting response cost * fix kwargs.get("response_cost") * fix getting response cost * (feat) add XAI ChatCompletion Support (#6373) * init commit for XAI * add full logic for xai chat completion * test_completion_xai * docs xAI * add xai/grok-beta * test_xai_chat_config_get_openai_compatible_provider_info * test_xai_chat_config_map_openai_params * add xai streaming test --------- Signed-off-by: dbczumar Co-authored-by: Ishaan Jaff Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com> --- .gitignore | 1 + litellm/_logging.py | 3 + litellm/llms/anthropic/chat/transformation.py | 14 +- litellm/llms/prompt_templates/factory.py | 16 +- .../context_caching/transformation.py | 4 +- .../gemini/transformation.py | 206 +++- .../vertex_and_google_ai_studio_gemini.py | 605 ++++++----- .../vertex_ai_non_gemini.py | 369 +------ litellm/main.py | 1 - litellm/proxy/_new_secret_config.yaml | 5 +- .../pass_through_endpoints/success_handler.py | 4 +- litellm/tests/test_langtrace.py | 33 - litellm/types/llms/openai.py | 4 +- litellm/types/llms/vertex_ai.py | 18 + litellm/utils.py | 15 +- tests/llm_translation/test_optional_params.py | 2 +- tests/llm_translation/test_vertex.py | 941 ++++++++++++++++++ .../test_amazing_vertex_completion.py | 5 +- 
.../local_testing/test_bedrock_completion.py | 5 +- tests/local_testing/test_function_calling.py | 13 +- .../local_testing/test_key_generate_prisma.py | 2 + tests/local_testing/test_mem_leak.py | 243 +++++ tests/local_testing/test_prompt_factory.py | 2 +- tests/local_testing/test_user_api_key_auth.py | 40 + 24 files changed, 1851 insertions(+), 700 deletions(-) delete mode 100644 litellm/tests/test_langtrace.py create mode 100644 tests/local_testing/test_mem_leak.py diff --git a/.gitignore b/.gitignore index a24bd4920..e8e8aed2b 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,4 @@ litellm/tests/log.txt litellm/tests/langfuse.log litellm/tests/langfuse.log litellm/proxy/google-cloud-sdk/* +tests/llm_translation/log.txt diff --git a/litellm/_logging.py b/litellm/_logging.py index ef41a586f..daa1a1dd2 100644 --- a/litellm/_logging.py +++ b/litellm/_logging.py @@ -35,6 +35,9 @@ class JsonFormatter(Formatter): "timestamp": self.formatTime(record), } + if record.exc_info: + json_record["stacktrace"] = self.formatException(record.exc_info) + return json.dumps(json_record) diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 4953a393a..ec3285473 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -123,7 +123,7 @@ class AnthropicConfig: return headers def _map_tool_choice( - self, tool_choice: Optional[str], disable_parallel_tool_use: Optional[bool] + self, tool_choice: Optional[str], parallel_tool_use: Optional[bool] ) -> Optional[AnthropicMessagesToolChoice]: _tool_choice: Optional[AnthropicMessagesToolChoice] = None if tool_choice == "auto": @@ -138,13 +138,15 @@ class AnthropicConfig: if _tool_name is not None: _tool_choice["name"] = _tool_name - if disable_parallel_tool_use is not None: + if parallel_tool_use is not None: + # Anthropic uses 'disable_parallel_tool_use' flag to determine if parallel tool use is allowed + # this is the inverse of the openai flag. 
if _tool_choice is not None: - _tool_choice["disable_parallel_tool_use"] = disable_parallel_tool_use + _tool_choice["disable_parallel_tool_use"] = not parallel_tool_use else: # use anthropic defaults and make sure to send the disable_parallel_tool_use flag _tool_choice = AnthropicMessagesToolChoice( type="auto", - disable_parallel_tool_use=disable_parallel_tool_use, + disable_parallel_tool_use=not parallel_tool_use, ) return _tool_choice @@ -255,9 +257,7 @@ class AnthropicConfig: _tool_choice: Optional[AnthropicMessagesToolChoice] = ( self._map_tool_choice( tool_choice=non_default_params.get("tool_choice"), - disable_parallel_tool_use=non_default_params.get( - "parallel_tool_calls" - ), + parallel_tool_use=non_default_params.get("parallel_tool_calls"), ) ) diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 7a78e8c09..aee304760 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -2552,18 +2552,20 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915 BedrockContentBlock(image=assistants_part) # type: ignore ) assistant_content.extend(assistants_parts) - elif messages[msg_i].get( - "tool_calls", [] - ): # support assistant tool invoke convertion - assistant_content.extend( - _convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"]) - ) - else: + elif messages[msg_i].get("content", None) is not None and isinstance( + messages[msg_i]["content"], str + ): assistant_text = ( messages[msg_i].get("content") or "" ) # either string or none if assistant_text: assistant_content.append(BedrockContentBlock(text=assistant_text)) + if messages[msg_i].get( + "tool_calls", [] + ): # support assistant tool invoke convertion [TODO]: + assistant_content.extend( + _convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"]) + ) msg_i += 1 diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py index 86fb757b7..8caa112ea 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py @@ -11,9 +11,9 @@ from litellm.types.llms.vertex_ai import CachedContentRequestBody, SystemInstruc from litellm.utils import is_cached_message from ..common_utils import VertexAIError, get_supports_system_message -from ..gemini.transformation import _transform_system_message -from ..gemini.vertex_and_google_ai_studio_gemini import ( +from ..gemini.transformation import ( _gemini_convert_messages_with_history, + _transform_system_message, ) diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py index 075c0d169..66ab07674 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py @@ -4,14 +4,34 @@ Transformation logic from OpenAI format to Gemini format. Why separate file? 
Make it easy to see how transformation works """ -from typing import List, Literal, Optional, Tuple, Union +import os +from typing import List, Literal, Optional, Tuple, Union, cast import httpx +from pydantic import BaseModel import litellm +from litellm._logging import verbose_logger from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.llms.prompt_templates.factory import response_schema_prompt -from litellm.types.llms.openai import AllMessageValues +from litellm.llms.prompt_templates.factory import ( + convert_to_anthropic_image_obj, + convert_to_gemini_tool_call_invoke, + convert_to_gemini_tool_call_result, + response_schema_prompt, +) +from litellm.types.files import ( + get_file_mime_type_for_file_type, + get_file_type_from_extension, + is_gemini_1_5_accepted_file_type, + is_video_file_type, +) +from litellm.types.llms.openai import ( + AllMessageValues, + ChatCompletionAssistantMessage, + ChatCompletionImageObject, + ChatCompletionTextObject, +) +from litellm.types.llms.vertex_ai import * from litellm.types.llms.vertex_ai import ( GenerationConfig, PartType, @@ -21,9 +41,185 @@ from litellm.types.llms.vertex_ai import ( ToolConfig, Tools, ) +from litellm.utils import CustomStreamWrapper, ModelResponse, Usage -from ..common_utils import get_supports_response_schema, get_supports_system_message -from ..vertex_ai_non_gemini import _gemini_convert_messages_with_history +from ..common_utils import ( + _check_text_in_content, + get_supports_response_schema, + get_supports_system_message, +) + + +def _process_gemini_image(image_url: str) -> PartType: + try: + # GCS URIs + if "gs://" in image_url: + # Figure out file type + extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png" + extension = extension_with_dot[1:] # Ex: "png" + + file_type = get_file_type_from_extension(extension) + + # Validate the file type is supported by Gemini + if not is_gemini_1_5_accepted_file_type(file_type): + raise Exception(f"File type not supported by gemini - {file_type}") + + mime_type = get_file_mime_type_for_file_type(file_type) + file_data = FileDataType(mime_type=mime_type, file_uri=image_url) + + return PartType(file_data=file_data) + + # Direct links + elif "https:/" in image_url or "base64" in image_url: + image = convert_to_anthropic_image_obj(image_url) + _blob = BlobType(data=image["data"], mime_type=image["media_type"]) + return PartType(inline_data=_blob) + raise Exception("Invalid image received - {}".format(image_url)) + except Exception as e: + raise e + + +def _gemini_convert_messages_with_history( # noqa: PLR0915 + messages: List[AllMessageValues], +) -> List[ContentType]: + """ + Converts given messages from OpenAI format to Gemini format + + - Parts must be iterable + - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles) + - Please ensure that function response turn comes immediately after a function call turn + """ + user_message_types = {"user", "system"} + contents: List[ContentType] = [] + + last_message_with_tool_calls = None + + msg_i = 0 + tool_call_responses = [] + try: + while msg_i < len(messages): + user_content: List[PartType] = [] + init_msg_i = msg_i + ## MERGE CONSECUTIVE USER CONTENT ## + while ( + msg_i < len(messages) and messages[msg_i]["role"] in user_message_types + ): + _message_content = messages[msg_i].get("content") + if _message_content is not None and isinstance(_message_content, list): + _parts: List[PartType] = [] + for element in _message_content: + if ( + 
element["type"] == "text" + and "text" in element + and len(element["text"]) > 0 + ): + element = cast(ChatCompletionTextObject, element) + _part = PartType(text=element["text"]) + _parts.append(_part) + elif element["type"] == "image_url": + element = cast(ChatCompletionImageObject, element) + img_element = element + if isinstance(img_element["image_url"], dict): + image_url = img_element["image_url"]["url"] + else: + image_url = img_element["image_url"] + _part = _process_gemini_image(image_url=image_url) + _parts.append(_part) + user_content.extend(_parts) + elif ( + _message_content is not None + and isinstance(_message_content, str) + and len(_message_content) > 0 + ): + _part = PartType(text=_message_content) + user_content.append(_part) + + msg_i += 1 + + if user_content: + """ + check that user_content has 'text' parameter. + - Known Vertex Error: Unable to submit request because it must have a text parameter. + - Relevant Issue: https://github.com/BerriAI/litellm/issues/5515 + """ + has_text_in_content = _check_text_in_content(user_content) + if has_text_in_content is False: + verbose_logger.warning( + "No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515" + ) + user_content.append( + PartType(text=" ") + ) # add a blank text, to ensure Gemini doesn't fail the request. + contents.append(ContentType(role="user", parts=user_content)) + assistant_content = [] + ## MERGE CONSECUTIVE ASSISTANT CONTENT ## + while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": + if isinstance(messages[msg_i], BaseModel): + msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump() # type: ignore + else: + msg_dict = messages[msg_i] # type: ignore + assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore + _message_content = assistant_msg.get("content", None) + if _message_content is not None and isinstance(_message_content, list): + _parts = [] + for element in _message_content: + if isinstance(element, dict): + if element["type"] == "text": + _part = PartType(text=element["text"]) + _parts.append(_part) + assistant_content.extend(_parts) + elif ( + _message_content is not None + and isinstance(_message_content, str) + and _message_content + ): + assistant_text = _message_content # either string or none + assistant_content.append(PartType(text=assistant_text)) # type: ignore + + ## HANDLE ASSISTANT FUNCTION CALL + if ( + assistant_msg.get("tool_calls", []) is not None + or assistant_msg.get("function_call") is not None + ): # support assistant tool invoke conversion + assistant_content.extend( + convert_to_gemini_tool_call_invoke(assistant_msg) + ) + last_message_with_tool_calls = assistant_msg + + msg_i += 1 + + if assistant_content: + contents.append(ContentType(role="model", parts=assistant_content)) + + ## APPEND TOOL CALL MESSAGES ## + tool_call_message_roles = ["tool", "function"] + if ( + msg_i < len(messages) + and messages[msg_i]["role"] in tool_call_message_roles + ): + _part = convert_to_gemini_tool_call_result( + messages[msg_i], last_message_with_tool_calls # type: ignore + ) + msg_i += 1 + tool_call_responses.append(_part) + if msg_i < len(messages) and ( + messages[msg_i]["role"] not in tool_call_message_roles + ): + if len(tool_call_responses) > 0: + contents.append(ContentType(parts=tool_call_responses)) + tool_call_responses = [] + + if msg_i == init_msg_i: # prevent infinite loops + raise Exception( + "Invalid 
Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format( + messages[msg_i] + ) + ) + if len(tool_call_responses) > 0: + contents.append(ContentType(parts=tool_call_responses)) + return contents + except Exception as e: + raise e def _transform_request_body( diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py index 0c2ccb7d8..39c63dbb3 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py @@ -35,13 +35,6 @@ from litellm.llms.custom_httpx.http_handler import ( HTTPHandler, get_async_httpx_client, ) -from litellm.llms.prompt_templates.factory import ( - convert_url_to_base64, - response_schema_prompt, -) -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import ( - _gemini_convert_messages_with_history, -) from litellm.types.llms.openai import ( ChatCompletionResponseMessage, ChatCompletionToolCallChunk, @@ -57,6 +50,7 @@ from litellm.types.llms.vertex_ai import ( GenerateContentResponseBody, GenerationConfig, HttpxPartType, + LogprobsResult, PartType, RequestBody, SafetSettingsConfig, @@ -64,7 +58,12 @@ from litellm.types.llms.vertex_ai import ( ToolConfig, Tools, ) -from litellm.types.utils import GenericStreamingChunk +from litellm.types.utils import ( + ChatCompletionTokenLogprob, + ChoiceLogprobs, + GenericStreamingChunk, + TopLogprob, +) from litellm.utils import CustomStreamWrapper, ModelResponse, Usage from ....utils import _remove_additional_properties, _remove_strict_from_schema @@ -365,6 +364,7 @@ class VertexGeminiConfig: "presence_penalty", "extra_headers", "seed", + "logprobs", ] def map_tool_choice_values( @@ -454,6 +454,16 @@ class VertexGeminiConfig: _tools["code_execution"] = code_execution return [_tools] + def _map_response_schema(self, value: dict) -> dict: + old_schema = deepcopy(value) + if isinstance(old_schema, list): + for item in old_schema: + if isinstance(item, dict): + item = _build_vertex_schema(parameters=item) + elif isinstance(old_schema, dict): + old_schema = _build_vertex_schema(parameters=old_schema) + return old_schema + def map_openai_params( self, model: str, @@ -461,6 +471,7 @@ class VertexGeminiConfig: optional_params: dict, drop_params: bool, ): + for param, value in non_default_params.items(): if param == "temperature": optional_params["temperature"] = value @@ -499,19 +510,15 @@ class VertexGeminiConfig: if "response_schema" in optional_params and isinstance( optional_params["response_schema"], dict ): - old_schema = deepcopy(optional_params["response_schema"]) - - if isinstance(old_schema, list): - for item in old_schema: - if isinstance(item, dict): - item = _build_vertex_schema(parameters=item) - elif isinstance(old_schema, dict): - old_schema = _build_vertex_schema(parameters=old_schema) - optional_params["response_schema"] = old_schema + optional_params["response_schema"] = self._map_response_schema( + value=optional_params["response_schema"] + ) if param == "frequency_penalty": optional_params["frequency_penalty"] = value if param == "presence_penalty": optional_params["presence_penalty"] = value + if param == "logprobs": + optional_params["responseLogprobs"] = value if (param == "tools" or param == "functions") and isinstance(value, list): optional_params["tools"] = self._map_function(value=value) 
optional_params["litellm_param_is_function_call"] = ( @@ -527,6 +534,7 @@ class VertexGeminiConfig: optional_params["tool_choice"] = _tool_choice_value if param == "seed": optional_params["seed"] = value + return optional_params def get_mapped_special_auth_params(self) -> dict: @@ -584,12 +592,325 @@ class VertexGeminiConfig: ) return exception_string - def get_assistant_content_message(self, parts: List[HttpxPartType]) -> str: - content_str = "" + def get_assistant_content_message( + self, parts: List[HttpxPartType] + ) -> Optional[str]: + _content_str = "" for part in parts: if "text" in part: - content_str += part["text"] - return content_str + _content_str += part["text"] + if _content_str: + return _content_str + return None + + def _transform_parts( + self, + parts: List[HttpxPartType], + index: int, + is_function_call: Optional[bool], + ) -> Tuple[ + Optional[ChatCompletionToolCallFunctionChunk], + Optional[List[ChatCompletionToolCallChunk]], + ]: + function: Optional[ChatCompletionToolCallFunctionChunk] = None + _tools: List[ChatCompletionToolCallChunk] = [] + for part in parts: + if "functionCall" in part: + _function_chunk = ChatCompletionToolCallFunctionChunk( + name=part["functionCall"]["name"], + arguments=json.dumps(part["functionCall"]["args"]), + ) + if is_function_call is True: + function = _function_chunk + else: + _tool_response_chunk = ChatCompletionToolCallChunk( + id=f"call_{str(uuid.uuid4())}", + type="function", + function=_function_chunk, + index=index, + ) + _tools.append(_tool_response_chunk) + if len(_tools) == 0: + tools: Optional[List[ChatCompletionToolCallChunk]] = None + else: + tools = _tools + return function, tools + + def _transform_logprobs( + self, logprobs_result: Optional[LogprobsResult] + ) -> Optional[ChoiceLogprobs]: + if logprobs_result is None: + return None + if "chosenCandidates" not in logprobs_result: + return None + logprobs_list: List[ChatCompletionTokenLogprob] = [] + for index, candidate in enumerate(logprobs_result["chosenCandidates"]): + top_logprobs: List[TopLogprob] = [] + if "topCandidates" in logprobs_result and index < len( + logprobs_result["topCandidates"] + ): + top_candidates_for_index = logprobs_result["topCandidates"][index][ + "candidates" + ] + + for options in top_candidates_for_index: + top_logprobs.append( + TopLogprob( + token=options["token"], logprob=options["logProbability"] + ) + ) + logprobs_list.append( + ChatCompletionTokenLogprob( + token=candidate["token"], + logprob=candidate["logProbability"], + top_logprobs=top_logprobs, + ) + ) + return ChoiceLogprobs(content=logprobs_list) + + def _handle_blocked_response( + self, + model_response: ModelResponse, + completion_response: GenerateContentResponseBody, + ) -> ModelResponse: + # If set, the prompt was blocked and no candidates are returned. 
Rephrase your prompt + model_response.choices[0].finish_reason = "content_filter" + + chat_completion_message: ChatCompletionResponseMessage = { + "role": "assistant", + "content": None, + } + + choice = litellm.Choices( + finish_reason="content_filter", + index=0, + message=chat_completion_message, # type: ignore + logprobs=None, + enhancements=None, + ) + + model_response.choices = [choice] + + ## GET USAGE ## + usage = litellm.Usage( + prompt_tokens=completion_response["usageMetadata"].get( + "promptTokenCount", 0 + ), + completion_tokens=completion_response["usageMetadata"].get( + "candidatesTokenCount", 0 + ), + total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0), + ) + + setattr(model_response, "usage", usage) + + return model_response + + def _handle_content_policy_violation( + self, + model_response: ModelResponse, + completion_response: GenerateContentResponseBody, + ) -> ModelResponse: + ## CONTENT POLICY VIOLATION ERROR + model_response.choices[0].finish_reason = "content_filter" + + _chat_completion_message = { + "role": "assistant", + "content": None, + } + + choice = litellm.Choices( + finish_reason="content_filter", + index=0, + message=_chat_completion_message, + logprobs=None, + enhancements=None, + ) + + model_response.choices = [choice] + + ## GET USAGE ## + usage = litellm.Usage( + prompt_tokens=completion_response["usageMetadata"].get( + "promptTokenCount", 0 + ), + completion_tokens=completion_response["usageMetadata"].get( + "candidatesTokenCount", 0 + ), + total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0), + ) + + setattr(model_response, "usage", usage) + + return model_response + + def _transform_response( + self, + model: str, + response: httpx.Response, + model_response: ModelResponse, + logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, + optional_params: dict, + litellm_params: dict, + api_key: str, + data: Union[dict, str, RequestBody], + messages: List, + print_verbose, + encoding, + ) -> ModelResponse: + + ## LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + + ## RESPONSE OBJECT + try: + completion_response = GenerateContentResponseBody(**response.json()) # type: ignore + except Exception as e: + raise VertexAIError( + message="Received={}, Error converting to valid response block={}. 
File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( + response.text, str(e) + ), + status_code=422, + ) + + ## GET MODEL ## + model_response.model = model + + ## CHECK IF RESPONSE FLAGGED + if ( + "promptFeedback" in completion_response + and "blockReason" in completion_response["promptFeedback"] + ): + return self._handle_blocked_response( + model_response=model_response, + completion_response=completion_response, + ) + + _candidates = completion_response.get("candidates") + if _candidates and len(_candidates) > 0: + content_policy_violations = ( + VertexGeminiConfig().get_flagged_finish_reasons() + ) + if ( + "finishReason" in _candidates[0] + and _candidates[0]["finishReason"] in content_policy_violations.keys() + ): + return self._handle_content_policy_violation( + model_response=model_response, + completion_response=completion_response, + ) + + model_response.choices = [] # type: ignore + + try: + ## CHECK IF GROUNDING METADATA IN REQUEST + grounding_metadata: List[dict] = [] + safety_ratings: List = [] + citation_metadata: List = [] + ## GET TEXT ## + chat_completion_message: ChatCompletionResponseMessage = { + "role": "assistant" + } + chat_completion_logprobs: Optional[ChoiceLogprobs] = None + tools: Optional[List[ChatCompletionToolCallChunk]] = [] + functions: Optional[ChatCompletionToolCallFunctionChunk] = None + if _candidates: + for idx, candidate in enumerate(_candidates): + if "content" not in candidate: + continue + + if "groundingMetadata" in candidate: + grounding_metadata.append(candidate["groundingMetadata"]) # type: ignore + + if "safetyRatings" in candidate: + safety_ratings.append(candidate["safetyRatings"]) + + if "citationMetadata" in candidate: + citation_metadata.append(candidate["citationMetadata"]) + if "parts" in candidate["content"]: + chat_completion_message[ + "content" + ] = VertexGeminiConfig().get_assistant_content_message( + parts=candidate["content"]["parts"] + ) + + functions, tools = self._transform_parts( + parts=candidate["content"]["parts"], + index=candidate.get("index", idx), + is_function_call=litellm_params.get( + "litellm_param_is_function_call" + ), + ) + + if "logprobsResult" in candidate: + chat_completion_logprobs = self._transform_logprobs( + logprobs_result=candidate["logprobsResult"] + ) + + if tools: + chat_completion_message["tool_calls"] = tools + + if functions is not None: + chat_completion_message["function_call"] = functions + choice = litellm.Choices( + finish_reason=candidate.get("finishReason", "stop"), + index=candidate.get("index", idx), + message=chat_completion_message, # type: ignore + logprobs=chat_completion_logprobs, + enhancements=None, + ) + + model_response.choices.append(choice) + + ## GET USAGE ## + usage = litellm.Usage( + prompt_tokens=completion_response["usageMetadata"].get( + "promptTokenCount", 0 + ), + completion_tokens=completion_response["usageMetadata"].get( + "candidatesTokenCount", 0 + ), + total_tokens=completion_response["usageMetadata"].get( + "totalTokenCount", 0 + ), + ) + + setattr(model_response, "usage", usage) + + ## ADD GROUNDING METADATA ## + setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata) + model_response._hidden_params[ + "vertex_ai_grounding_metadata" + ] = ( # older approach - maintaining to prevent regressions + grounding_metadata + ) + + ## ADD SAFETY RATINGS ## + setattr(model_response, "vertex_ai_safety_results", safety_ratings) + model_response._hidden_params["vertex_ai_safety_results"] = ( + safety_ratings # older approach 
- maintaining to prevent regressions + ) + + ## ADD CITATION METADATA ## + setattr(model_response, "vertex_ai_citation_metadata", citation_metadata) + model_response._hidden_params["vertex_ai_citation_metadata"] = ( + citation_metadata # older approach - maintaining to prevent regressions + ) + + except Exception as e: + raise VertexAIError( + message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( + completion_response, str(e) + ), + status_code=422, + ) + + return model_response class GoogleAIStudioGeminiConfig( @@ -675,6 +996,7 @@ class GoogleAIStudioGeminiConfig( "response_format", "n", "stop", + "logprobs", ] def map_openai_params( @@ -771,243 +1093,6 @@ class VertexLLM(VertexBase): def __init__(self) -> None: super().__init__() - def _process_response( # noqa: PLR0915 - self, - model: str, - response: httpx.Response, - model_response: ModelResponse, - logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, - optional_params: dict, - litellm_params: dict, - api_key: str, - data: Union[dict, str, RequestBody], - messages: List, - print_verbose, - encoding, - ) -> ModelResponse: - - ## LOGGING - logging_obj.post_call( - input=messages, - api_key="", - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) - - print_verbose(f"raw model_response: {response.text}") - ## RESPONSE OBJECT - try: - completion_response = GenerateContentResponseBody(**response.json()) # type: ignore - except Exception as e: - raise VertexAIError( - message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( - response.text, str(e) - ), - status_code=422, - ) - - ## GET MODEL ## - model_response.model = model - - ## CHECK IF RESPONSE FLAGGED - if "promptFeedback" in completion_response: - if "blockReason" in completion_response["promptFeedback"]: - # If set, the prompt was blocked and no candidates are returned. 
Rephrase your prompt - model_response.choices[0].finish_reason = "content_filter" - - chat_completion_message: ChatCompletionResponseMessage = { - "role": "assistant", - "content": None, - } - - choice = litellm.Choices( - finish_reason="content_filter", - index=0, - message=chat_completion_message, # type: ignore - logprobs=None, - enhancements=None, - ) - - model_response.choices = [choice] - - ## GET USAGE ## - usage = litellm.Usage( - prompt_tokens=completion_response["usageMetadata"].get( - "promptTokenCount", 0 - ), - completion_tokens=completion_response["usageMetadata"].get( - "candidatesTokenCount", 0 - ), - total_tokens=completion_response["usageMetadata"].get( - "totalTokenCount", 0 - ), - ) - - setattr(model_response, "usage", usage) - - return model_response - - _candidates = completion_response.get("candidates") - if _candidates and len(_candidates) > 0: - content_policy_violations = ( - VertexGeminiConfig().get_flagged_finish_reasons() - ) - if ( - "finishReason" in _candidates[0] - and _candidates[0]["finishReason"] in content_policy_violations.keys() - ): - ## CONTENT POLICY VIOLATION ERROR - model_response.choices[0].finish_reason = "content_filter" - - _chat_completion_message = { - "role": "assistant", - "content": None, - } - - choice = litellm.Choices( - finish_reason="content_filter", - index=0, - message=_chat_completion_message, - logprobs=None, - enhancements=None, - ) - - model_response.choices = [choice] - - ## GET USAGE ## - usage = litellm.Usage( - prompt_tokens=completion_response["usageMetadata"].get( - "promptTokenCount", 0 - ), - completion_tokens=completion_response["usageMetadata"].get( - "candidatesTokenCount", 0 - ), - total_tokens=completion_response["usageMetadata"].get( - "totalTokenCount", 0 - ), - ) - - setattr(model_response, "usage", usage) - - return model_response - - model_response.choices = [] # type: ignore - - try: - ## CHECK IF GROUNDING METADATA IN REQUEST - grounding_metadata: List[dict] = [] - safety_ratings: List = [] - citation_metadata: List = [] - ## GET TEXT ## - chat_completion_message = {"role": "assistant"} - content_str: str = "" - tools: List[ChatCompletionToolCallChunk] = [] - functions: Optional[ChatCompletionToolCallFunctionChunk] = None - if _candidates: - for idx, candidate in enumerate(_candidates): - if "content" not in candidate: - continue - - if "groundingMetadata" in candidate: - grounding_metadata.append(candidate["groundingMetadata"]) # type: ignore - - if "safetyRatings" in candidate: - safety_ratings.append(candidate["safetyRatings"]) - - if "citationMetadata" in candidate: - citation_metadata.append(candidate["citationMetadata"]) - if "parts" in candidate["content"]: - content_str = ( - VertexGeminiConfig().get_assistant_content_message( - parts=candidate["content"]["parts"] - ) - ) - - if ( - "parts" in candidate["content"] - and "functionCall" in candidate["content"]["parts"][0] - ): - _function_chunk = ChatCompletionToolCallFunctionChunk( - name=candidate["content"]["parts"][0]["functionCall"][ - "name" - ], - arguments=json.dumps( - candidate["content"]["parts"][0]["functionCall"]["args"] - ), - ) - if litellm_params.get("litellm_param_is_function_call") is True: - functions = _function_chunk - else: - _tool_response_chunk = ChatCompletionToolCallChunk( - id=f"call_{str(uuid.uuid4())}", - type="function", - function=_function_chunk, - index=candidate.get("index", idx), - ) - tools.append(_tool_response_chunk) - chat_completion_message["content"] = ( - content_str if len(content_str) > 0 else None - ) - if 
len(tools) > 0: - chat_completion_message["tool_calls"] = tools - - if functions is not None: - chat_completion_message["function_call"] = functions - choice = litellm.Choices( - finish_reason=candidate.get("finishReason", "stop"), - index=candidate.get("index", idx), - message=chat_completion_message, # type: ignore - logprobs=None, - enhancements=None, - ) - - model_response.choices.append(choice) - - ## GET USAGE ## - usage = litellm.Usage( - prompt_tokens=completion_response["usageMetadata"].get( - "promptTokenCount", 0 - ), - completion_tokens=completion_response["usageMetadata"].get( - "candidatesTokenCount", 0 - ), - total_tokens=completion_response["usageMetadata"].get( - "totalTokenCount", 0 - ), - ) - - setattr(model_response, "usage", usage) - - ## ADD GROUNDING METADATA ## - setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata) - model_response._hidden_params[ - "vertex_ai_grounding_metadata" - ] = ( # older approach - maintaining to prevent regressions - grounding_metadata - ) - - ## ADD SAFETY RATINGS ## - setattr(model_response, "vertex_ai_safety_results", safety_ratings) - model_response._hidden_params["vertex_ai_safety_results"] = ( - safety_ratings # older approach - maintaining to prevent regressions - ) - - ## ADD CITATION METADATA ## - setattr(model_response, "vertex_ai_citation_metadata", citation_metadata) - model_response._hidden_params["vertex_ai_citation_metadata"] = ( - citation_metadata # older approach - maintaining to prevent regressions - ) - - except Exception as e: - raise VertexAIError( - message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( - completion_response, str(e) - ), - status_code=422, - ) - - return model_response - async def async_streaming( self, model: str, @@ -1171,7 +1256,7 @@ class VertexLLM(VertexBase): except httpx.TimeoutException: raise VertexAIError(status_code=408, message="Timeout error occurred.") - return self._process_response( + return VertexGeminiConfig()._transform_response( model=model, response=response, model_response=model_response, @@ -1359,7 +1444,7 @@ class VertexLLM(VertexBase): except httpx.TimeoutException: raise VertexAIError(status_code=408, message="Timeout error occurred.") - return self._process_response( + return VertexGeminiConfig()._transform_response( model=model, response=response, model_response=model_response, diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py index 5b50868a8..80295ec40 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py @@ -85,185 +85,6 @@ class TextStreamer: raise StopAsyncIteration # once we run out of data to stream, we raise this error -def _get_image_bytes_from_url(image_url: str) -> bytes: - try: - response = requests.get(image_url) - response.raise_for_status() # Raise an error for bad responses (4xx and 5xx) - image_bytes = response.content - return image_bytes - except requests.exceptions.RequestException as e: - raise Exception(f"An exception occurs with this image - {str(e)}") - - -def _convert_gemini_role(role: str) -> Literal["user", "model"]: - if role == "user": - return "user" - else: - return "model" - - -def _process_gemini_image(image_url: str) -> PartType: - try: - # GCS URIs - if "gs://" in image_url: - # Figure out file type - extension_with_dot = 
os.path.splitext(image_url)[-1] # Ex: ".png" - extension = extension_with_dot[1:] # Ex: "png" - - file_type = get_file_type_from_extension(extension) - - # Validate the file type is supported by Gemini - if not is_gemini_1_5_accepted_file_type(file_type): - raise Exception(f"File type not supported by gemini - {file_type}") - - mime_type = get_file_mime_type_for_file_type(file_type) - file_data = FileDataType(mime_type=mime_type, file_uri=image_url) - - return PartType(file_data=file_data) - - # Direct links - elif "https:/" in image_url or "base64" in image_url: - image = convert_to_anthropic_image_obj(image_url) - _blob = BlobType(data=image["data"], mime_type=image["media_type"]) - return PartType(inline_data=_blob) - raise Exception("Invalid image received - {}".format(image_url)) - except Exception as e: - raise e - - -def _gemini_convert_messages_with_history( # noqa: PLR0915 - messages: List[AllMessageValues], -) -> List[ContentType]: - """ - Converts given messages from OpenAI format to Gemini format - - - Parts must be iterable - - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles) - - Please ensure that function response turn comes immediately after a function call turn - """ - user_message_types = {"user", "system"} - contents: List[ContentType] = [] - - last_message_with_tool_calls = None - - msg_i = 0 - try: - while msg_i < len(messages): - user_content: List[PartType] = [] - init_msg_i = msg_i - ## MERGE CONSECUTIVE USER CONTENT ## - while ( - msg_i < len(messages) and messages[msg_i]["role"] in user_message_types - ): - _message_content = messages[msg_i].get("content") - if _message_content is not None and isinstance(_message_content, list): - _parts: List[PartType] = [] - for element in _message_content: - if ( - element["type"] == "text" - and "text" in element - and len(element["text"]) > 0 - ): - element = cast(ChatCompletionTextObject, element) - _part = PartType(text=element["text"]) - _parts.append(_part) - elif element["type"] == "image_url": - element = cast(ChatCompletionImageObject, element) - img_element = element - if isinstance(img_element["image_url"], dict): - image_url = img_element["image_url"]["url"] - else: - image_url = img_element["image_url"] - _part = _process_gemini_image(image_url=image_url) - _parts.append(_part) - user_content.extend(_parts) - elif ( - _message_content is not None - and isinstance(_message_content, str) - and len(_message_content) > 0 - ): - _part = PartType(text=_message_content) - user_content.append(_part) - - msg_i += 1 - - if user_content: - """ - check that user_content has 'text' parameter. - - Known Vertex Error: Unable to submit request because it must have a text parameter. - - Relevant Issue: https://github.com/BerriAI/litellm/issues/5515 - """ - has_text_in_content = _check_text_in_content(user_content) - if has_text_in_content is False: - verbose_logger.warning( - "No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515" - ) - user_content.append( - PartType(text=" ") - ) # add a blank text, to ensure Gemini doesn't fail the request. 
- contents.append(ContentType(role="user", parts=user_content)) - assistant_content = [] - ## MERGE CONSECUTIVE ASSISTANT CONTENT ## - while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": - if isinstance(messages[msg_i], BaseModel): - msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump() # type: ignore - else: - msg_dict = messages[msg_i] # type: ignore - assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore - _message_content = assistant_msg.get("content", None) - if _message_content is not None and isinstance(_message_content, list): - _parts = [] - for element in _message_content: - if isinstance(element, dict): - if element["type"] == "text": - _part = PartType(text=element["text"]) - _parts.append(_part) - assistant_content.extend(_parts) - elif ( - _message_content is not None - and isinstance(_message_content, str) - and _message_content - ): - assistant_text = _message_content # either string or none - assistant_content.append(PartType(text=assistant_text)) # type: ignore - - ## HANDLE ASSISTANT FUNCTION CALL - if ( - assistant_msg.get("tool_calls", []) is not None - or assistant_msg.get("function_call") is not None - ): # support assistant tool invoke conversion - assistant_content.extend( - convert_to_gemini_tool_call_invoke(assistant_msg) - ) - last_message_with_tool_calls = assistant_msg - - msg_i += 1 - - if assistant_content: - contents.append(ContentType(role="model", parts=assistant_content)) - - ## APPEND TOOL CALL MESSAGES ## - if msg_i < len(messages) and ( - messages[msg_i]["role"] == "tool" - or messages[msg_i]["role"] == "function" - ): - _part = convert_to_gemini_tool_call_result( - messages[msg_i], last_message_with_tool_calls # type: ignore - ) - contents.append(ContentType(parts=[_part])) # type: ignore - msg_i += 1 - - if msg_i == init_msg_i: # prevent infinite loops - raise Exception( - "Invalid Message passed in - {}. 
File an issue https://github.com/BerriAI/litellm/issues".format( - messages[msg_i] - ) - ) - return contents - except Exception as e: - raise e - - def _get_client_cache_key( model: str, vertex_project: Optional[str], vertex_location: Optional[str] ): @@ -487,91 +308,7 @@ def completion( # noqa: PLR0915 return async_completion(**data) completion_response = None - if mode == "vision": - print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call") - print_verbose(f"\nProcessing input messages = {messages}") - tools = optional_params.pop("tools", None) - content = _gemini_convert_messages_with_history(messages=messages) - stream = optional_params.pop("stream", False) - if stream is True: - request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - _model_response = llm_model.generate_content( - contents=content, - generation_config=optional_params, - safety_settings=safety_settings, - stream=True, - tools=tools, - ) - - return _model_response - - request_str += f"response = llm_model.generate_content({content})\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - ## LLM Call - response = llm_model.generate_content( - contents=content, - generation_config=optional_params, - safety_settings=safety_settings, - tools=tools, - ) - - if tools is not None and bool( - getattr(response.candidates[0].content.parts[0], "function_call", None) - ): - function_call = response.candidates[0].content.parts[0].function_call - args_dict = {} - - # Check if it's a RepeatedComposite instance - for key, val in function_call.args.items(): - if isinstance( - val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore - ): - # If so, convert to list - args_dict[key] = [v for v in val] - else: - args_dict[key] = val - - try: - args_str = json.dumps(args_dict) - except Exception as e: - raise VertexAIError(status_code=422, message=str(e)) - message = litellm.Message( - content=None, - tool_calls=[ - { - "id": f"call_{str(uuid.uuid4())}", - "function": { - "arguments": args_str, - "name": function_call.name, - }, - "type": "function", - } - ], - ) - completion_response = message - else: - completion_response = response.text - response_obj = response._raw_response - optional_params["tools"] = tools - elif mode == "chat": + if mode == "chat": chat = llm_model.start_chat() request_str += "chat = llm_model.start_chat()\n" @@ -796,82 +533,7 @@ async def async_completion( # noqa: PLR0915 response_obj = None completion_response = None - if mode == "vision": - print_verbose("\nMaking VertexAI Gemini Pro/Vision Call") - print_verbose(f"\nProcessing input messages = {messages}") - tools = optional_params.pop("tools", None) - optional_params.pop("stream", False) - - content = _gemini_convert_messages_with_history(messages=messages) - - request_str += f"response = llm_model.generate_content({content})\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - ## LLM Call - # print(f"final content: {content}") - response = await llm_model._generate_content_async( - contents=content, - 
generation_config=optional_params, - safety_settings=safety_settings, - tools=tools, - ) - - _cache_key = _get_client_cache_key( - model=model, - vertex_project=vertex_project, - vertex_location=vertex_location, - ) - _set_client_in_cache( - client_cache_key=_cache_key, vertex_llm_model=llm_model - ) - - if tools is not None and bool( - getattr(response.candidates[0].content.parts[0], "function_call", None) - ): - function_call = response.candidates[0].content.parts[0].function_call - args_dict = {} - - # Check if it's a RepeatedComposite instance - for key, val in function_call.args.items(): - if isinstance( - val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore - ): - # If so, convert to list - args_dict[key] = [v for v in val] - else: - args_dict[key] = val - - try: - args_str = json.dumps(args_dict) - except Exception as e: - raise VertexAIError(status_code=422, message=str(e)) - message = litellm.Message( - content=None, - tool_calls=[ - { - "id": f"call_{str(uuid.uuid4())}", - "function": { - "arguments": args_str, - "name": function_call.name, - }, - "type": "function", - } - ], - ) - completion_response = message - else: - completion_response = response.text - response_obj = response._raw_response - optional_params["tools"] = tools - elif mode == "chat": + if mode == "chat": # chat-bison etc. chat = llm_model.start_chat() ## LOGGING @@ -1032,32 +694,7 @@ async def async_streaming( # noqa: PLR0915 Add support for async streaming calls for gemini-pro """ response: Any = None - if mode == "vision": - stream = optional_params.pop("stream") - tools = optional_params.pop("tools", None) - print_verbose("\nMaking VertexAI Gemini Pro Vision Call") - print_verbose(f"\nProcessing input messages = {messages}") - - content = _gemini_convert_messages_with_history(messages=messages) - - request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - response = await llm_model._generate_content_streaming_async( - contents=content, - generation_config=optional_params, - safety_settings=safety_settings, - tools=tools, - ) - - elif mode == "chat": + if mode == "chat": chat = llm_model.start_chat() optional_params.pop( "stream", None diff --git a/litellm/main.py b/litellm/main.py index 34e4ae5bb..a964ba7e6 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -15,7 +15,6 @@ import json import os import random import sys -import threading import time import traceback import uuid diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index ca198d7a3..88de3eb47 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -10,10 +10,9 @@ model_list: output_cost_per_token: 0.000015 # 15$/M api_base: "https://exampleopenaiendpoint-production.up.railway.app" api_key: my-fake-key - - model_name: my-custom-model + - model_name: gemini-1.5-flash-002 litellm_params: - model: my-custom-llm/my-custom-model - api_key: my-fake-key + model: gemini/gemini-1.5-flash-002 litellm_settings: fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }] diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 4cfaf490f..0a7ae541d 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ 
b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -97,9 +97,9 @@ class PassThroughEndpointLogging: if "generateContent" in url_route: model = self.extract_model_from_url(url_route) - instance_of_vertex_llm = VertexLLM() + instance_of_vertex_llm = litellm.VertexGeminiConfig() litellm_model_response: litellm.ModelResponse = ( - instance_of_vertex_llm._process_response( + instance_of_vertex_llm._transform_response( model=model, messages=[ {"role": "user", "content": "no-message-pass-through-endpoint"} diff --git a/litellm/tests/test_langtrace.py b/litellm/tests/test_langtrace.py deleted file mode 100644 index 803bae521..000000000 --- a/litellm/tests/test_langtrace.py +++ /dev/null @@ -1,33 +0,0 @@ -import os -import sys -import time - -import pytest -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from langtrace_python_sdk import langtrace - -import litellm - -sys.path.insert(0, os.path.abspath("../..")) - - -@pytest.fixture() -def exporter(): - exporter = InMemorySpanExporter() - langtrace.init(batch=False, custom_remote_exporter=exporter) - litellm.success_callback = ["langtrace"] - litellm.set_verbose = True - - return exporter - - -@pytest.mark.parametrize("model", ["claude-2.1", "gpt-3.5-turbo"]) -def test_langtrace_logging(exporter, model): - litellm.completion( - model=model, - messages=[{"role": "user", "content": "This is a test"}], - max_tokens=1000, - temperature=0.7, - timeout=5, - mock_response="hi", - ) diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 3b95a3282..c2a78e349 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -500,9 +500,9 @@ ChatCompletionAssistantContentValue = ( class ChatCompletionResponseMessage(TypedDict, total=False): content: Optional[ChatCompletionAssistantContentValue] - tool_calls: List[ChatCompletionToolCallChunk] + tool_calls: Optional[List[ChatCompletionToolCallChunk]] role: Literal["assistant"] - function_call: ChatCompletionToolCallFunctionChunk + function_call: Optional[ChatCompletionToolCallFunctionChunk] class ChatCompletionUsageBlock(TypedDict): diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 1b906023b..d55cf3ec6 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -167,6 +167,8 @@ class GenerationConfig(TypedDict, total=False): response_mime_type: Literal["text/plain", "application/json"] response_schema: dict seed: int + responseLogprobs: bool + logprobs: int class Tools(TypedDict, total=False): @@ -270,6 +272,21 @@ class GroundingMetadata(TypedDict, total=False): groundingAttributions: List[dict] +class LogprobsCandidate(TypedDict): + token: str + tokenId: int + logProbability: float + + +class LogprobsTopCandidate(TypedDict): + candidates: List[LogprobsCandidate] + + +class LogprobsResult(TypedDict, total=False): + topCandidates: List[LogprobsTopCandidate] + chosenCandidates: List[LogprobsCandidate] + + class Candidates(TypedDict, total=False): index: int content: HttpxContentType @@ -288,6 +305,7 @@ class Candidates(TypedDict, total=False): citationMetadata: CitationMetadata groundingMetadata: GroundingMetadata finishMessage: str + logprobsResult: LogprobsResult class PromptFeedback(TypedDict): diff --git a/litellm/utils.py b/litellm/utils.py index 11613e24d..3922dcb58 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -7708,10 +7708,17 @@ class CustomStreamWrapper: continue ## LOGGING ## LOGGING - threading.Thread( - target=self.logging_obj.success_handler, 
- args=(processed_chunk, None, None, cache_hit), - ).start() # log response + executor.submit( + self.logging_obj.success_handler, + result=processed_chunk, + start_time=None, + end_time=None, + cache_hit=cache_hit, + ) + # threading.Thread( + # target=self.logging_obj.success_handler, + # args=(processed_chunk, None, None, cache_hit), + # ).start() # log response asyncio.create_task( self.logging_obj.async_success_handler( processed_chunk, cache_hit=cache_hit diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py index fdda7b171..d921c1c17 100644 --- a/tests/llm_translation/test_optional_params.py +++ b/tests/llm_translation/test_optional_params.py @@ -794,7 +794,7 @@ def test_anthropic_parallel_tool_calls(provider): parallel_tool_calls=True, ) print(f"optional_params: {optional_params}") - assert optional_params["tool_choice"]["disable_parallel_tool_use"] is True + assert optional_params["tool_choice"]["disable_parallel_tool_use"] is False def test_anthropic_computer_tool_use(): diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py index c5699d8ce..467be4ddf 100644 --- a/tests/llm_translation/test_vertex.py +++ b/tests/llm_translation/test_vertex.py @@ -230,3 +230,944 @@ def test_function_calling_with_gemini(): ] } ] + + +def test_multiple_function_call(): + litellm.set_verbose = True + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() + messages = [ + {"role": "user", "content": [{"type": "text", "text": "do test"}]}, + { + "role": "assistant", + "content": [{"type": "text", "text": "test"}], + "tool_calls": [ + { + "index": 0, + "function": {"arguments": '{"arg": "test"}', "name": "test"}, + "id": "call_597e00e6-11d4-4ed2-94b2-27edee250aec", + "type": "function", + }, + { + "index": 1, + "function": {"arguments": '{"arg": "test2"}', "name": "test2"}, + "id": "call_2414e8f9-283a-002b-182a-1290ab912c02", + "type": "function", + }, + ], + }, + { + "tool_call_id": "call_597e00e6-11d4-4ed2-94b2-27edee250aec", + "role": "tool", + "name": "test", + "content": [{"type": "text", "text": "42"}], + }, + { + "tool_call_id": "call_2414e8f9-283a-002b-182a-1290ab912c02", + "role": "tool", + "name": "test2", + "content": [{"type": "text", "text": "15"}], + }, + {"role": "user", "content": [{"type": "text", "text": "tell me the results."}]}, + ] + + response_body = { + "candidates": [ + { + "content": { + "parts": [ + { + "text": 'The `default_api.test` function call returned a JSON object indicating a successful execution. The `fields` key contains a nested dictionary with a `key` of "content" and a `value` with a `string_value` of "42".\n\nSimilarly, the `default_api.test2` function call also returned a JSON object showing successful execution. The `fields` key contains a nested dictionary with a `key` of "content" and a `value` with a `string_value` of "15".\n\nIn short, both test functions executed successfully and returned different numerical string values ("42" and "15"). 
The significance of these numbers depends on the internal logic of the `test` and `test2` functions within the `default_api`.\n' + } + ], + "role": "model", + }, + "finishReason": "STOP", + "avgLogprobs": -0.20577410289219447, + } + ], + "usageMetadata": { + "promptTokenCount": 128, + "candidatesTokenCount": 168, + "totalTokenCount": 296, + }, + "modelVersion": "gemini-1.5-flash-002", + } + + mock_response = MagicMock() + mock_response.json.return_value = response_body + + with patch.object(client, "post", return_value=mock_response) as mock_post: + r = litellm.completion( + messages=messages, model="gemini/gemini-1.5-flash-002", client=client + ) + assert len(r.choices) > 0 + + assert mock_post.call_args.kwargs["json"] == { + "contents": [ + {"role": "user", "parts": [{"text": "do test"}]}, + { + "role": "model", + "parts": [ + {"text": "test"}, + { + "function_call": { + "name": "test", + "args": { + "fields": { + "key": "arg", + "value": {"string_value": "test"}, + } + }, + } + }, + { + "function_call": { + "name": "test2", + "args": { + "fields": { + "key": "arg", + "value": {"string_value": "test2"}, + } + }, + } + }, + ], + }, + { + "parts": [ + { + "function_response": { + "name": "test", + "response": { + "fields": { + "key": "content", + "value": {"string_value": "42"}, + } + }, + } + }, + { + "function_response": { + "name": "test2", + "response": { + "fields": { + "key": "content", + "value": {"string_value": "15"}, + } + }, + } + }, + ] + }, + {"role": "user", "parts": [{"text": "tell me the results."}]}, + ], + "generationConfig": {}, + } + + +def test_multiple_function_call_changed_text_pos(): + litellm.set_verbose = True + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() + messages = [ + {"role": "user", "content": [{"type": "text", "text": "do test"}]}, + { + "tool_calls": [ + { + "index": 0, + "function": {"arguments": '{"arg": "test"}', "name": "test"}, + "id": "call_597e00e6-11d4-4ed2-94b2-27edee250aec", + "type": "function", + }, + { + "index": 1, + "function": {"arguments": '{"arg": "test2"}', "name": "test2"}, + "id": "call_2414e8f9-283a-002b-182a-1290ab912c02", + "type": "function", + }, + ], + "role": "assistant", + "content": [{"type": "text", "text": "test"}], + }, + { + "tool_call_id": "call_2414e8f9-283a-002b-182a-1290ab912c02", + "role": "tool", + "name": "test2", + "content": [{"type": "text", "text": "15"}], + }, + { + "tool_call_id": "call_597e00e6-11d4-4ed2-94b2-27edee250aec", + "role": "tool", + "name": "test", + "content": [{"type": "text", "text": "42"}], + }, + {"role": "user", "content": [{"type": "text", "text": "tell me the results."}]}, + ] + + response_body = { + "candidates": [ + { + "content": { + "parts": [ + { + "text": 'The code executed two functions, `test` and `test2`.\n\n* **`test`**: Returned a dictionary indicating that the "key" field has a "value" field containing a string value of "42". This is likely a response from a function that processed the input "test" and returned a calculated or pre-defined value.\n\n* **`test2`**: Returned a dictionary indicating that the "key" field has a "value" field containing a string value of "15". 
Similar to `test`, this suggests a function that processes the input "test2" and returns a specific result.\n\nIn short, both functions appear to be simple tests that return different hardcoded or calculated values based on their input arguments.\n' + } + ], + "role": "model", + }, + "finishReason": "STOP", + "avgLogprobs": -0.32848488592332409, + } + ], + "usageMetadata": { + "promptTokenCount": 128, + "candidatesTokenCount": 155, + "totalTokenCount": 283, + }, + "modelVersion": "gemini-1.5-flash-002", + } + mock_response = MagicMock() + mock_response.json.return_value = response_body + + with patch.object(client, "post", return_value=mock_response) as mock_post: + resp = litellm.completion( + messages=messages, model="gemini/gemini-1.5-flash-002", client=client + ) + assert len(resp.choices) > 0 + mock_post.assert_called_once() + + assert mock_post.call_args.kwargs["json"]["contents"] == [ + {"role": "user", "parts": [{"text": "do test"}]}, + { + "role": "model", + "parts": [ + {"text": "test"}, + { + "function_call": { + "name": "test", + "args": { + "fields": { + "key": "arg", + "value": {"string_value": "test"}, + } + }, + } + }, + { + "function_call": { + "name": "test2", + "args": { + "fields": { + "key": "arg", + "value": {"string_value": "test2"}, + } + }, + } + }, + ], + }, + { + "parts": [ + { + "function_response": { + "name": "test2", + "response": { + "fields": { + "key": "content", + "value": {"string_value": "15"}, + } + }, + } + }, + { + "function_response": { + "name": "test", + "response": { + "fields": { + "key": "content", + "value": {"string_value": "42"}, + } + }, + } + }, + ] + }, + {"role": "user", "parts": [{"text": "tell me the results."}]}, + ] + + +def test_function_calling_with_gemini_multiple_results(): + litellm.set_verbose = True + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() + # Step 1: send the conversation and available functions to the model + messages = [ + { + "role": "user", + "content": "What's the weather like in San Francisco, Tokyo, and Paris? 
- give me 3 responses", + } + ] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + }, + } + ] + + response_body = { + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "get_current_weather", + "args": {"location": "San Francisco"}, + } + }, + { + "functionCall": { + "name": "get_current_weather", + "args": {"location": "Tokyo"}, + } + }, + { + "functionCall": { + "name": "get_current_weather", + "args": {"location": "Paris"}, + } + }, + ], + "role": "model", + }, + "finishReason": "STOP", + "avgLogprobs": -0.0040788948535919189, + } + ], + "usageMetadata": { + "promptTokenCount": 90, + "candidatesTokenCount": 22, + "totalTokenCount": 112, + }, + "modelVersion": "gemini-1.5-flash-002", + } + + mock_response = MagicMock() + mock_response.json.return_value = response_body + + with patch.object(client, "post", return_value=mock_response): + response = litellm.completion( + model="gemini/gemini-1.5-flash-002", + messages=messages, + tools=tools, + tool_choice="required", + client=client, + ) + print("Response\n", response) + + assert len(response.choices[0].message.tool_calls) == 3 + + expected_locations = ["San Francisco", "Tokyo", "Paris"] + for idx, tool_call in enumerate(response.choices[0].message.tool_calls): + json_args = json.loads(tool_call.function.arguments) + assert json_args["location"] == expected_locations[idx] + + +def test_logprobs_unit_test(): + from litellm import VertexGeminiConfig + + result = VertexGeminiConfig()._transform_logprobs( + logprobs_result={ + "topCandidates": [ + { + "candidates": [ + {"token": "```", "logProbability": -1.5496514e-06}, + {"token": "`", "logProbability": -13.375002}, + {"token": "``", "logProbability": -21.875002}, + ] + }, + { + "candidates": [ + {"token": "tool", "logProbability": 0}, + {"token": "too", "logProbability": -29.031433}, + {"token": "to", "logProbability": -34.11199}, + ] + }, + { + "candidates": [ + {"token": "_", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "code", "logProbability": 0}, + {"token": "co", "logProbability": -28.114716}, + {"token": "c", "logProbability": -29.283161}, + ] + }, + { + "candidates": [ + {"token": "\n", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "print", "logProbability": 0}, + {"token": "p", "logProbability": -19.7494}, + {"token": "prin", "logProbability": -21.117342}, + ] + }, + { + "candidates": [ + {"token": "(", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "default", "logProbability": 0}, + {"token": "get", "logProbability": -16.811178}, + {"token": "ge", "logProbability": -19.031078}, + ] + }, + { + "candidates": [ + {"token": "_", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "api", "logProbability": 0}, + {"token": "ap", 
"logProbability": -26.501019}, + {"token": "a", "logProbability": -30.905857}, + ] + }, + { + "candidates": [ + {"token": ".", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "get", "logProbability": 0}, + {"token": "ge", "logProbability": -19.984676}, + {"token": "g", "logProbability": -20.527714}, + ] + }, + { + "candidates": [ + {"token": "_", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "current", "logProbability": 0}, + {"token": "cur", "logProbability": -28.193565}, + {"token": "cu", "logProbability": -29.636738}, + ] + }, + { + "candidates": [ + {"token": "_", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "weather", "logProbability": 0}, + {"token": "we", "logProbability": -27.887215}, + {"token": "wea", "logProbability": -31.851082}, + ] + }, + { + "candidates": [ + {"token": "(", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "location", "logProbability": 0}, + {"token": "loc", "logProbability": -19.152641}, + {"token": " location", "logProbability": -21.981709}, + ] + }, + { + "candidates": [ + {"token": '="', "logProbability": -0.034490786}, + {"token": "='", "logProbability": -3.398928}, + {"token": "=", "logProbability": -7.6194153}, + ] + }, + { + "candidates": [ + {"token": "San", "logProbability": -6.5561944e-06}, + {"token": '\\"', "logProbability": -12.015556}, + {"token": "Paris", "logProbability": -14.647776}, + ] + }, + { + "candidates": [ + {"token": " Francisco", "logProbability": -3.5760596e-07}, + {"token": " Frans", "logProbability": -14.83527}, + {"token": " francisco", "logProbability": -19.796852}, + ] + }, + { + "candidates": [ + {"token": '"))', "logProbability": -6.079254e-06}, + {"token": ",", "logProbability": -12.106029}, + {"token": '",', "logProbability": -14.56927}, + ] + }, + { + "candidates": [ + {"token": "\n", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "print", "logProbability": -0.04140338}, + {"token": "```", "logProbability": -3.2049975}, + {"token": "p", "logProbability": -22.087523}, + ] + }, + { + "candidates": [ + {"token": "(", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "default", "logProbability": 0}, + {"token": "get", "logProbability": -20.266342}, + {"token": "de", "logProbability": -20.906395}, + ] + }, + { + "candidates": [ + {"token": "_", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "api", "logProbability": 0}, + {"token": "ap", "logProbability": -27.712265}, + {"token": "a", "logProbability": -31.986958}, + ] + }, + { + "candidates": [ + {"token": ".", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "get", "logProbability": 0}, + {"token": "g", 
"logProbability": -23.569286}, + {"token": "ge", "logProbability": -23.829632}, + ] + }, + { + "candidates": [ + {"token": "_", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "current", "logProbability": 0}, + {"token": "cur", "logProbability": -30.125153}, + {"token": "curr", "logProbability": -31.756569}, + ] + }, + { + "candidates": [ + {"token": "_", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "weather", "logProbability": 0}, + {"token": "we", "logProbability": -27.743786}, + {"token": "w", "logProbability": -30.594503}, + ] + }, + { + "candidates": [ + {"token": "(", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "location", "logProbability": 0}, + {"token": "loc", "logProbability": -21.177715}, + {"token": " location", "logProbability": -22.166002}, + ] + }, + { + "candidates": [ + {"token": '="', "logProbability": -1.5617967e-05}, + {"token": "='", "logProbability": -11.080961}, + {"token": "=", "logProbability": -15.164277}, + ] + }, + { + "candidates": [ + {"token": "Tokyo", "logProbability": -3.0041514e-05}, + {"token": "tokyo", "logProbability": -10.650261}, + {"token": "Paris", "logProbability": -12.096886}, + ] + }, + { + "candidates": [ + {"token": '"))', "logProbability": -1.1922384e-07}, + {"token": '",', "logProbability": -16.61921}, + {"token": ",", "logProbability": -17.911102}, + ] + }, + { + "candidates": [ + {"token": "\n", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "print", "logProbability": -3.5760596e-07}, + {"token": "```", "logProbability": -14.949171}, + {"token": "p", "logProbability": -24.321035}, + ] + }, + { + "candidates": [ + {"token": "(", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "default", "logProbability": 0}, + {"token": "de", "logProbability": -27.885206}, + {"token": "def", "logProbability": -28.40597}, + ] + }, + { + "candidates": [ + {"token": "_", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "api", "logProbability": 0}, + {"token": "ap", "logProbability": -25.905933}, + {"token": "a", "logProbability": -30.408901}, + ] + }, + { + "candidates": [ + {"token": ".", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "get", "logProbability": 0}, + {"token": "g", "logProbability": -22.274963}, + {"token": "ge", "logProbability": -23.285828}, + ] + }, + { + "candidates": [ + {"token": "_", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "current", "logProbability": 0}, + {"token": "cur", "logProbability": -28.442535}, + {"token": "curr", "logProbability": -29.95087}, + ] + }, + { + "candidates": [ + {"token": "_", "logProbability": 0}, + {"token": "ont", "logProbability": 
-1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "weather", "logProbability": 0}, + {"token": "we", "logProbability": -27.307909}, + {"token": "w", "logProbability": -31.076736}, + ] + }, + { + "candidates": [ + {"token": "(", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "location", "logProbability": 0}, + {"token": "loc", "logProbability": -21.535915}, + {"token": "lo", "logProbability": -23.028284}, + ] + }, + { + "candidates": [ + {"token": '="', "logProbability": -8.821511e-06}, + {"token": "='", "logProbability": -11.700986}, + {"token": "=", "logProbability": -14.50358}, + ] + }, + { + "candidates": [ + {"token": "Paris", "logProbability": 0}, + {"token": "paris", "logProbability": -18.07075}, + {"token": "Par", "logProbability": -21.911625}, + ] + }, + { + "candidates": [ + {"token": '"))', "logProbability": 0}, + {"token": '")', "logProbability": -17.916853}, + {"token": ",", "logProbability": -18.318272}, + ] + }, + { + "candidates": [ + {"token": "\n", "logProbability": 0}, + {"token": "ont", "logProbability": -1.2676506e30}, + {"token": " п", "logProbability": -1.2676506e30}, + ] + }, + { + "candidates": [ + {"token": "```", "logProbability": -3.5763796e-06}, + {"token": "print", "logProbability": -12.535343}, + {"token": "``", "logProbability": -19.670813}, + ] + }, + ], + "chosenCandidates": [ + {"token": "```", "logProbability": -1.5496514e-06}, + {"token": "tool", "logProbability": 0}, + {"token": "_", "logProbability": 0}, + {"token": "code", "logProbability": 0}, + {"token": "\n", "logProbability": 0}, + {"token": "print", "logProbability": 0}, + {"token": "(", "logProbability": 0}, + {"token": "default", "logProbability": 0}, + {"token": "_", "logProbability": 0}, + {"token": "api", "logProbability": 0}, + {"token": ".", "logProbability": 0}, + {"token": "get", "logProbability": 0}, + {"token": "_", "logProbability": 0}, + {"token": "current", "logProbability": 0}, + {"token": "_", "logProbability": 0}, + {"token": "weather", "logProbability": 0}, + {"token": "(", "logProbability": 0}, + {"token": "location", "logProbability": 0}, + {"token": '="', "logProbability": -0.034490786}, + {"token": "San", "logProbability": -6.5561944e-06}, + {"token": " Francisco", "logProbability": -3.5760596e-07}, + {"token": '"))', "logProbability": -6.079254e-06}, + {"token": "\n", "logProbability": 0}, + {"token": "print", "logProbability": -0.04140338}, + {"token": "(", "logProbability": 0}, + {"token": "default", "logProbability": 0}, + {"token": "_", "logProbability": 0}, + {"token": "api", "logProbability": 0}, + {"token": ".", "logProbability": 0}, + {"token": "get", "logProbability": 0}, + {"token": "_", "logProbability": 0}, + {"token": "current", "logProbability": 0}, + {"token": "_", "logProbability": 0}, + {"token": "weather", "logProbability": 0}, + {"token": "(", "logProbability": 0}, + {"token": "location", "logProbability": 0}, + {"token": '="', "logProbability": -1.5617967e-05}, + {"token": "Tokyo", "logProbability": -3.0041514e-05}, + {"token": '"))', "logProbability": -1.1922384e-07}, + {"token": "\n", "logProbability": 0}, + {"token": "print", "logProbability": -3.5760596e-07}, + {"token": "(", "logProbability": 0}, + {"token": "default", "logProbability": 0}, + {"token": "_", "logProbability": 0}, + {"token": "api", "logProbability": 0}, + {"token": ".", "logProbability": 0}, + {"token": 
"get", "logProbability": 0}, + {"token": "_", "logProbability": 0}, + {"token": "current", "logProbability": 0}, + {"token": "_", "logProbability": 0}, + {"token": "weather", "logProbability": 0}, + {"token": "(", "logProbability": 0}, + {"token": "location", "logProbability": 0}, + {"token": '="', "logProbability": -8.821511e-06}, + {"token": "Paris", "logProbability": 0}, + {"token": '"))', "logProbability": 0}, + {"token": "\n", "logProbability": 0}, + {"token": "```", "logProbability": -3.5763796e-06}, + ], + } + ) + + print(result) + + +def test_logprobs(): + litellm.set_verbose = True + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() + + response_body = { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "I do not have access to real-time information, including current weather conditions. To get the current weather in San Francisco, I recommend checking a reliable weather website or app such as Google Weather, AccuWeather, or the National Weather Service.\n" + } + ], + "role": "model", + }, + "finishReason": "STOP", + "avgLogprobs": -0.04666396617889404, + "logprobsResult": { + "chosenCandidates": [ + {"token": "I", "logProbability": -1.08472495e-05}, + {"token": " do", "logProbability": -0.00012611414}, + {"token": " not", "logProbability": 0}, + {"token": " have", "logProbability": 0}, + {"token": " access", "logProbability": -0.0008849616}, + {"token": " to", "logProbability": 0}, + {"token": " real", "logProbability": -1.1922384e-07}, + {"token": "-", "logProbability": 0}, + {"token": "time", "logProbability": 0}, + {"token": " information", "logProbability": -2.2409657e-05}, + {"token": ",", "logProbability": 0}, + {"token": " including", "logProbability": 0}, + {"token": " current", "logProbability": -0.14274147}, + {"token": " weather", "logProbability": 0}, + {"token": " conditions", "logProbability": -0.0056300927}, + {"token": ".", "logProbability": -3.5760596e-07}, + {"token": " ", "logProbability": -0.06392521}, + {"token": "To", "logProbability": -2.3844768e-07}, + {"token": " get", "logProbability": -0.058974747}, + {"token": " the", "logProbability": 0}, + {"token": " current", "logProbability": 0}, + {"token": " weather", "logProbability": -2.3844768e-07}, + {"token": " in", "logProbability": -2.3844768e-07}, + {"token": " San", "logProbability": 0}, + {"token": " Francisco", "logProbability": 0}, + {"token": ",", "logProbability": 0}, + {"token": " I", "logProbability": -0.6188003}, + {"token": " recommend", "logProbability": -1.0370523e-05}, + {"token": " checking", "logProbability": -0.00014005086}, + {"token": " a", "logProbability": 0}, + {"token": " reliable", "logProbability": -1.5496514e-06}, + {"token": " weather", "logProbability": -8.344534e-07}, + {"token": " website", "logProbability": -0.0078000566}, + {"token": " or", "logProbability": -1.1922384e-07}, + {"token": " app", "logProbability": 0}, + {"token": " such", "logProbability": -0.9289338}, + {"token": " as", "logProbability": 0}, + {"token": " Google", "logProbability": -0.0046935496}, + {"token": " Weather", "logProbability": 0}, + {"token": ",", "logProbability": 0}, + {"token": " Accu", "logProbability": 0}, + {"token": "Weather", "logProbability": -0.00013909786}, + {"token": ",", "logProbability": 0}, + {"token": " or", "logProbability": -0.31303275}, + {"token": " the", "logProbability": -0.17583296}, + {"token": " National", "logProbability": -0.010806266}, + {"token": " Weather", "logProbability": 0}, + {"token": " Service", 
"logProbability": 0}, + {"token": ".", "logProbability": -0.00068947335}, + {"token": "\n", "logProbability": 0}, + ] + }, + } + ], + "usageMetadata": { + "promptTokenCount": 11, + "candidatesTokenCount": 50, + "totalTokenCount": 61, + }, + "modelVersion": "gemini-1.5-flash-002", + } + mock_response = MagicMock() + mock_response.json.return_value = response_body + + with patch.object(client, "post", return_value=mock_response): + + resp = litellm.completion( + model="gemini/gemini-1.5-flash-002", + messages=[ + {"role": "user", "content": "What's the weather like in San Francisco?"} + ], + logprobs=True, + client=client, + ) + print(resp) + + assert resp.choices[0].logprobs is not None diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py index 072d93da4..2de53696f 100644 --- a/tests/local_testing/test_amazing_vertex_completion.py +++ b/tests/local_testing/test_amazing_vertex_completion.py @@ -30,7 +30,7 @@ from litellm import ( completion_cost, embedding, ) -from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( +from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import ( _gemini_convert_messages_with_history, ) from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase @@ -1823,6 +1823,7 @@ async def test_gemini_pro_function_calling_streaming(sync_mode): @pytest.mark.flaky(retries=3, delay=1) async def test_gemini_pro_async_function_calling(): load_vertex_ai_credentials() + litellm.set_verbose = True try: tools = [ { @@ -2925,7 +2926,7 @@ def test_gemini_function_call_parameter_in_messages(): def test_gemini_function_call_parameter_in_messages_2(): - from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import ( + from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import ( _gemini_convert_messages_with_history, ) diff --git a/tests/local_testing/test_bedrock_completion.py b/tests/local_testing/test_bedrock_completion.py index 1e65424d2..35a9fc276 100644 --- a/tests/local_testing/test_bedrock_completion.py +++ b/tests/local_testing/test_bedrock_completion.py @@ -1879,13 +1879,16 @@ def test_bedrock_completion_test_4(modify_params): { "role": "assistant", "content": [ + { + "text": """\nThe user is asking about a specific file: main.py. Based on the environment details provided, this file is located in the computer-vision/hm-open3d/src/ directory and is currently open in a VSCode tab.\n\nTo answer the question of what this file is, the most relevant tool would be the read_file tool. This will allow me to examine the contents of main.py to determine its purpose.\n\nThe read_file tool requires the "path" parameter. 
I can infer this path based on the environment details:\npath: "computer-vision/hm-open3d/src/main.py"\n\nSince I have the necessary parameter, I can proceed with calling the read_file tool.\n""" + }, { "toolUse": { "input": {"path": "computer-vision/hm-open3d/src/main.py"}, "name": "read_file", "toolUseId": "tooluse_qCt-KEyWQlWiyHl26spQVA", } - } + }, ], }, { diff --git a/tests/local_testing/test_function_calling.py b/tests/local_testing/test_function_calling.py index 81d31186c..7946bdfea 100644 --- a/tests/local_testing/test_function_calling.py +++ b/tests/local_testing/test_function_calling.py @@ -473,9 +473,15 @@ def test_anthropic_function_call_with_no_schema(model): completion(model=model, messages=messages, tools=tools, tool_choice="auto") -def test_passing_tool_result_as_list(): +@pytest.mark.parametrize( + "model", + [ + "anthropic/claude-3-5-sonnet-20241022", + "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", + ], +) +def test_passing_tool_result_as_list(model): litellm.set_verbose = True - model = "anthropic/claude-3-5-sonnet-20241022" messages = [ { "content": [ @@ -611,4 +617,5 @@ def test_passing_tool_result_as_list(): resp = completion(model=model, messages=messages, tools=tools) print(resp) - assert resp.usage.prompt_tokens_details.cached_tokens > 0 + if model == "claude-3-5-sonnet-20241022": + assert resp.usage.prompt_tokens_details.cached_tokens > 0 diff --git a/tests/local_testing/test_key_generate_prisma.py b/tests/local_testing/test_key_generate_prisma.py index 4098e524a..74182c09f 100644 --- a/tests/local_testing/test_key_generate_prisma.py +++ b/tests/local_testing/test_key_generate_prisma.py @@ -142,6 +142,7 @@ def prisma_client(): @pytest.mark.asyncio() +@pytest.mark.flaky(retries=6, delay=1) async def test_new_user_response(prisma_client): try: @@ -2891,6 +2892,7 @@ async def test_generate_key_with_guardrails(prisma_client): @pytest.mark.asyncio() +@pytest.mark.flaky(retries=6, delay=1) async def test_team_access_groups(prisma_client): """ Test team based model access groups diff --git a/tests/local_testing/test_mem_leak.py b/tests/local_testing/test_mem_leak.py new file mode 100644 index 000000000..60f228f1e --- /dev/null +++ b/tests/local_testing/test_mem_leak.py @@ -0,0 +1,243 @@ +# import io +# import os +# import sys + +# sys.path.insert(0, os.path.abspath("../..")) + +# import litellm +# from memory_profiler import profile +# from litellm.utils import ( +# ModelResponseIterator, +# ModelResponseListIterator, +# CustomStreamWrapper, +# ) +# from litellm.types.utils import ModelResponse, Choices, Message +# import time +# import pytest + + +# # @app.post("/debug") +# # async def debug(body: ExampleRequest) -> str: +# # return await main_logic(body.query) +# def model_response_list_factory(): +# chunks = [ +# { +# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj", +# "choices": [ +# { +# "delta": {"content": "", "role": "assistant"}, +# "finish_reason": None, +# "index": 0, +# } +# ], +# "created": 1716563849, +# "model": "gpt-4o-2024-05-13", +# "object": "chat.completion.chunk", +# "system_fingerprint": "fp_5f4bad809a", +# }, +# { +# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj", +# "choices": [ +# {"delta": {"content": "This"}, "finish_reason": None, "index": 0} +# ], +# "created": 1716563849, +# "model": "gpt-4o-2024-05-13", +# "object": "chat.completion.chunk", +# "system_fingerprint": "fp_5f4bad809a", +# }, +# { +# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj", +# "choices": [ +# {"delta": {"content": " is"}, "finish_reason": None, "index": 0} +# ], +# 
"created": 1716563849, +# "model": "gpt-4o-2024-05-13", +# "object": "chat.completion.chunk", +# "system_fingerprint": "fp_5f4bad809a", +# }, +# { +# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj", +# "choices": [ +# {"delta": {"content": " a"}, "finish_reason": None, "index": 0} +# ], +# "created": 1716563849, +# "model": "gpt-4o-2024-05-13", +# "object": "chat.completion.chunk", +# "system_fingerprint": "fp_5f4bad809a", +# }, +# { +# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj", +# "choices": [ +# {"delta": {"content": " dummy"}, "finish_reason": None, "index": 0} +# ], +# "created": 1716563849, +# "model": "gpt-4o-2024-05-13", +# "object": "chat.completion.chunk", +# "system_fingerprint": "fp_5f4bad809a", +# }, +# { +# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj", +# "choices": [ +# { +# "delta": {"content": " response"}, +# "finish_reason": None, +# "index": 0, +# } +# ], +# "created": 1716563849, +# "model": "gpt-4o-2024-05-13", +# "object": "chat.completion.chunk", +# "system_fingerprint": "fp_5f4bad809a", +# }, +# { +# "id": "", +# "choices": [ +# { +# "finish_reason": None, +# "index": 0, +# "content_filter_offsets": { +# "check_offset": 35159, +# "start_offset": 35159, +# "end_offset": 36150, +# }, +# "content_filter_results": { +# "hate": {"filtered": False, "severity": "safe"}, +# "self_harm": {"filtered": False, "severity": "safe"}, +# "sexual": {"filtered": False, "severity": "safe"}, +# "violence": {"filtered": False, "severity": "safe"}, +# }, +# } +# ], +# "created": 0, +# "model": "", +# "object": "", +# }, +# { +# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj", +# "choices": [{"delta": {"content": "."}, "finish_reason": None, "index": 0}], +# "created": 1716563849, +# "model": "gpt-4o-2024-05-13", +# "object": "chat.completion.chunk", +# "system_fingerprint": "fp_5f4bad809a", +# }, +# { +# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj", +# "choices": [{"delta": {}, "finish_reason": "stop", "index": 0}], +# "created": 1716563849, +# "model": "gpt-4o-2024-05-13", +# "object": "chat.completion.chunk", +# "system_fingerprint": "fp_5f4bad809a", +# }, +# { +# "id": "", +# "choices": [ +# { +# "finish_reason": None, +# "index": 0, +# "content_filter_offsets": { +# "check_offset": 36150, +# "start_offset": 36060, +# "end_offset": 37029, +# }, +# "content_filter_results": { +# "hate": {"filtered": False, "severity": "safe"}, +# "self_harm": {"filtered": False, "severity": "safe"}, +# "sexual": {"filtered": False, "severity": "safe"}, +# "violence": {"filtered": False, "severity": "safe"}, +# }, +# } +# ], +# "created": 0, +# "model": "", +# "object": "", +# }, +# ] + +# chunk_list = [] +# for chunk in chunks: +# new_chunk = litellm.ModelResponse(stream=True, id=chunk["id"]) +# if "choices" in chunk and isinstance(chunk["choices"], list): +# new_choices = [] +# for choice in chunk["choices"]: +# if isinstance(choice, litellm.utils.StreamingChoices): +# _new_choice = choice +# elif isinstance(choice, dict): +# _new_choice = litellm.utils.StreamingChoices(**choice) +# new_choices.append(_new_choice) +# new_chunk.choices = new_choices +# chunk_list.append(new_chunk) + +# return ModelResponseListIterator(model_responses=chunk_list) + + +# async def mock_completion(*args, **kwargs): +# completion_stream = model_response_list_factory() +# return litellm.CustomStreamWrapper( +# completion_stream=completion_stream, +# model="gpt-4-0613", +# custom_llm_provider="cached_response", +# logging_obj=litellm.Logging( +# model="gpt-4-0613", +# messages=[{"role": "user", "content": "Hey"}], +# 
stream=True, +# call_type="completion", +# start_time=time.time(), +# litellm_call_id="12345", +# function_id="1245", +# ), +# ) + + +# @profile +# async def main_logic() -> str: +# stream = await mock_completion() +# result = "" +# async for chunk in stream: +# result += chunk.choices[0].delta.content or "" +# return result + + +# import asyncio + +# for _ in range(100): +# asyncio.run(main_logic()) + + +# # @pytest.mark.asyncio +# # def test_memory_profile(capsys): +# # # Run the async function +# # result = asyncio.run(main_logic()) + +# # # Verify the result +# # assert result == "This is a dummy response." + +# # # Capture the output +# # captured = capsys.readouterr() + +# # # Print memory output for debugging +# # print("Memory Profiler Output:") +# # print(f"captured out: {captured.out}") + +# # # Basic memory leak checks +# # for idx, line in enumerate(captured.out.split("\n")): +# # if idx % 2 == 0 and "MiB" in line: +# # print(f"line: {line}") + +# # # mem_lines = [line for line in captured.out.split("\n") if "MiB" in line] + +# # print(mem_lines) + +# # # Ensure we have some memory lines +# # assert len(mem_lines) > 0, "No memory profiler output found" + +# # # Optional: Add more specific memory leak detection +# # for line in mem_lines: +# # # Extract memory increment +# # parts = line.split() +# # if len(parts) >= 3: +# # try: +# # mem_increment = float(parts[2].replace("MiB", "")) +# # # Assert that memory increment is below a reasonable threshold +# # assert mem_increment < 1.0, f"Potential memory leak detected: {line}" +# # except (ValueError, IndexError): +# # pass # Skip lines that don't match expected format diff --git a/tests/local_testing/test_prompt_factory.py b/tests/local_testing/test_prompt_factory.py index 7b4e295ce..104997563 100644 --- a/tests/local_testing/test_prompt_factory.py +++ b/tests/local_testing/test_prompt_factory.py @@ -25,7 +25,7 @@ from litellm.llms.prompt_templates.factory import ( from litellm.llms.prompt_templates.common_utils import ( get_completion_messages, ) -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import ( +from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import ( _gemini_convert_messages_with_history, ) from unittest.mock import AsyncMock, MagicMock, patch diff --git a/tests/local_testing/test_user_api_key_auth.py b/tests/local_testing/test_user_api_key_auth.py index 1baddc7dd..668d4cab4 100644 --- a/tests/local_testing/test_user_api_key_auth.py +++ b/tests/local_testing/test_user_api_key_auth.py @@ -306,6 +306,8 @@ async def test_auth_with_allowed_routes(route, should_raise_error): ("/key/delete", "internal_user", True), ("/key/generate", "internal_user", True), ("/key/82akk800000000jjsk/regenerate", "internal_user", True), + # Internal User Viewer + ("/key/generate", "internal_user_viewer", False), # Internal User checks - disallowed routes ("/organization/member_add", "internal_user", False), ], @@ -340,3 +342,41 @@ def test_is_ui_route_allowed(route, user_role, expected_result): pass else: raise e + + +@pytest.mark.parametrize( + "route, user_role, expected_result", + [ + ("/key/generate", "internal_user_viewer", False), + ], +) +def test_is_api_route_allowed(route, user_role, expected_result): + from litellm.proxy.auth.user_api_key_auth import _is_api_route_allowed + from litellm.proxy._types import LiteLLM_UserTable + + user_obj = LiteLLM_UserTable( + user_id="3b803c0e-666e-4e99-bd5c-6e534c07e297", + max_budget=None, + spend=0.0, + model_max_budget={}, + model_spend={}, + 
user_email="my-test-email@1234.com", + models=[], + tpm_limit=None, + rpm_limit=None, + user_role=user_role, + organization_memberships=[], + ) + + received_args: dict = { + "route": route, + "user_obj": user_obj, + } + try: + assert _is_api_route_allowed(**received_args) == expected_result + except Exception as e: + # If expected result is False, we expect an error + if expected_result is False: + pass + else: + raise e
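
Note on the CustomStreamWrapper hunk in litellm/utils.py earlier in this patch: the per-chunk threading.Thread(...).start() call is swapped for executor.submit(...), so streamed chunks reuse one shared pool instead of spawning a fresh thread each. Below is a minimal, self-contained sketch of that pattern under assumed names (log_executor, success_handler, log_chunk are illustrative and not litellm's actual module layout).

from concurrent.futures import ThreadPoolExecutor

# Created once at module import and reused for every chunk; this is the piece
# that avoids the per-chunk thread allocation seen with threading.Thread.
log_executor = ThreadPoolExecutor(max_workers=4)


def success_handler(result, start_time=None, end_time=None, cache_hit=None):
    # Stand-in for logging_obj.success_handler in the real code path.
    print("logged chunk:", result, "cache_hit:", cache_hit)


def log_chunk(processed_chunk, cache_hit):
    # submit() queues the call on the shared pool and returns immediately,
    # so long-lived streams no longer accumulate one thread (and its stack)
    # per streamed chunk.
    log_executor.submit(
        success_handler,
        result=processed_chunk,
        start_time=None,
        end_time=None,
        cache_hit=cache_hit,
    )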
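
The new LogprobsCandidate/LogprobsTopCandidate/LogprobsResult TypedDicts in litellm/types/llms/vertex_ai.py and test_logprobs_unit_test above exercise a Gemini-to-OpenAI logprobs translation. The sketch below shows one plausible shape for that mapping, keyed to the mocked payload in the test; it is not the actual VertexGeminiConfig._transform_logprobs implementation, and the function name is hypothetical.

from typing import Any, Dict, List


def transform_logprobs_sketch(
    logprobs_result: Dict[str, Any]
) -> Dict[str, List[Dict[str, Any]]]:
    # Flatten Gemini's logprobsResult (chosenCandidates + topCandidates, using
    # the camelCase keys from the LogprobsResult TypedDicts) into an
    # OpenAI-style logprobs dict: {"content": [{"token", "logprob", "top_logprobs"}]}.
    content: List[Dict[str, Any]] = []
    top_candidates = logprobs_result.get("topCandidates", [])
    for idx, chosen in enumerate(logprobs_result.get("chosenCandidates", [])):
        top = top_candidates[idx]["candidates"] if idx < len(top_candidates) else []
        content.append(
            {
                "token": chosen["token"],
                "logprob": chosen["logProbability"],
                "top_logprobs": [
                    {"token": c["token"], "logprob": c["logProbability"]} for c in top
                ],
            }
        )
    return {"content": content}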