LiteLLM Minor Fixes & Improvements (10/30/2024) (#6519)

* refactor: move gemini translation logic inside the transformation.py file

easier to isolate the gemini translation logic

* fix(gemini-transformation): support multiple tool calls in message body

Merges https://github.com/BerriAI/litellm/pull/6487/files

* test(test_vertex.py): add remaining tests from https://github.com/BerriAI/litellm/pull/6487

* fix(gemini-transformation): return tool calls for multiple tool calls

* fix: support passing logprobs param for vertex + gemini

* feat(vertex_ai): add logprobs support for gemini calls
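
A minimal usage sketch for the new logprobs support, assuming a Gemini API key is configured in the environment; the model name reuses the "gemini/" prefix seen elsewhere in this PR. Under the hood the OpenAI-style logprobs flag is mapped to Gemini's responseLogprobs field (see the map_openai_params hunk further below).

import litellm

# Hedged sketch: 'logprobs=True' is translated to Gemini's 'responseLogprobs'.
response = litellm.completion(
    model="gemini/gemini-1.5-flash-002",
    messages=[{"role": "user", "content": "Say hello"}],
    logprobs=True,
)
print(response.choices[0].logprobs)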

* fix(anthropic/chat/transformation.py): fix disable parallel tool use flag

* fix: fix linting error

* fix(_logging.py): log stacktrace information in json logs

Closes https://github.com/BerriAI/litellm/issues/6497
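
A standalone sketch of the formatter change (the actual hunk in litellm/_logging.py appears below): when a record carries exc_info, the JSON log line now includes the formatted traceback under a "stacktrace" key. Surrounding fields are illustrative.

import json
import logging
from logging import Formatter

class JsonFormatter(Formatter):
    def format(self, record: logging.LogRecord) -> str:
        json_record = {
            "message": record.getMessage(),
            "level": record.levelname,
            "timestamp": self.formatTime(record),
        }
        if record.exc_info:
            # New behaviour: serialize the traceback into the JSON payload.
            json_record["stacktrace"] = self.formatException(record.exc_info)
        return json.dumps(json_record)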

* fix(utils.py): fix mem leak for async stream + completion

Uses a global executor pool instead of creating a new thread on each request

Fixes https://github.com/BerriAI/litellm/issues/6404
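
The pattern as a self-contained sketch: replace per-chunk threading.Thread(...).start() calls with submissions to a single module-level ThreadPoolExecutor (the real change lives in the CustomStreamWrapper hunk in litellm/utils.py below). Pool size and the handler are illustrative.

from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=10)  # shared, created once per process

def success_handler(result, start_time=None, end_time=None, cache_hit=None):
    # stand-in for logging_obj.success_handler
    print("logged chunk:", result, "cache_hit:", cache_hit)

def log_streamed_chunk(processed_chunk, cache_hit=False):
    # fire-and-forget on the shared pool instead of spawning a new thread per chunk
    executor.submit(
        success_handler,
        result=processed_chunk,
        start_time=None,
        end_time=None,
        cache_hit=cache_hit,
    )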

* fix(factory.py): handle tool call + content in assistant message for bedrock

* fix: fix import

* fix(factory.py): maintain support for content as a str in assistant response

* fix: fix import

* test: cleanup test

* fix(vertex_and_google_ai_studio/): return none for content if no str value

* test: retry flaky tests

* (UI) Fix viewing members, keys in a team + added testing  (#6514)

* fix listing teams on ui

* LiteLLM Minor Fixes & Improvements (10/28/2024)  (#6475)

* fix(anthropic/chat/transformation.py): support anthropic disable_parallel_tool_use param

Fixes https://github.com/BerriAI/litellm/issues/6456
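
Usage stays OpenAI-shaped: callers pass parallel_tool_calls, and the Anthropic transformation inverts it into disable_parallel_tool_use on tool_choice (see the transformation.py hunk below). A hedged sketch, assuming ANTHROPIC_API_KEY is set; the tool definition is a placeholder.

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

# parallel_tool_calls=False becomes
# tool_choice={"type": "auto", "disable_parallel_tool_use": True} for Anthropic.
response = litellm.completion(
    model="claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "What's the weather in SF and Tokyo?"}],
    tools=tools,
    parallel_tool_calls=False,
)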

* feat(anthropic/chat/transformation.py): support anthropic computer tool use

Closes https://github.com/BerriAI/litellm/issues/6427

* fix(vertex_ai/common_utils.py): parse out '$schema' when calling vertex ai

Fixes an issue when calling Vertex AI from the Vercel AI SDK, which includes a '$schema' field in its generated JSON schemas
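
The gist of the fix as an illustrative helper (the actual implementation in vertex_ai/common_utils.py may differ): drop the JSON-Schema '$schema' annotation before forwarding a response schema to Vertex AI, which rejects unknown schema fields.

def strip_schema_key(schema: dict) -> dict:
    # Remove '$schema' recursively; everything else is passed through untouched.
    cleaned = {}
    for key, value in schema.items():
        if key == "$schema":
            continue
        cleaned[key] = strip_schema_key(value) if isinstance(value, dict) else value
    return cleaned

print(strip_schema_key({"$schema": "http://json-schema.org/draft-07/schema#", "type": "object"}))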

* fix(main.py): add 'extra_headers' support for azure on all translation endpoints

Fixes https://github.com/BerriAI/litellm/issues/6465
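
A hedged usage sketch; deployment name, endpoint, key, API version, and header values are placeholders. The point of the fix is that extra_headers is now forwarded to Azure across the translation endpoints, not just chat completion.

import litellm

response = litellm.completion(
    model="azure/my-gpt-4o-deployment",
    api_base="https://my-endpoint.openai.azure.com",
    api_key="my-azure-api-key",
    api_version="2024-06-01",
    messages=[{"role": "user", "content": "ping"}],
    extra_headers={"x-correlation-id": "abc-123"},
)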

* fix: fix linting errors

* fix(transformation.py): handle no beta headers for anthropic

* test: cleanup test

* fix: fix linting error

* fix: fix linting errors

* fix: fix linting errors

* fix(transformation.py): handle dummy tool call

* fix(main.py): fix linting error

* fix(azure.py): pass required param

* LiteLLM Minor Fixes & Improvements (10/24/2024) (#6441)

* fix(azure.py): handle /openai/deployment in azure api base

* fix(factory.py): fix faulty anthropic tool result translation check

Fixes https://github.com/BerriAI/litellm/issues/6422

* fix(gpt_transformation.py): add support for parallel_tool_calls to azure

Fixes https://github.com/BerriAI/litellm/issues/6440

* fix(factory.py): support anthropic prompt caching for tool results

* fix(vertex_ai/common_utils): don't pop non-null required field

Fixes https://github.com/BerriAI/litellm/issues/6426

* feat(vertex_ai.py): support code_execution tool call for vertex ai + gemini

Closes https://github.com/BerriAI/litellm/issues/6434

* build(model_prices_and_context_window.json): Add 'supports_assistant_prefill' for bedrock claude-3-5-sonnet v2 models

Closes https://github.com/BerriAI/litellm/issues/6437

* fix(types/utils.py): fix linting

* test: update test to include required fields

* test: fix test

* test: handle flaky test

* test: remove e2e test - hitting gemini rate limits

* Litellm dev 10 26 2024 (#6472)

* docs(exception_mapping.md): add missing exception types

Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183

* fix(main.py): register custom model pricing with specific key

Ensure custom model pricing is registered to the specific model+provider key combination
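
A hedged sketch of the registration; the model name matches the custom model that appears in the proxy config diff below, and the cost values are illustrative. The important part is that the cost map key is provider-qualified ("my-custom-llm/my-custom-model"), so the pricing applies only to that provider+model pair.

import litellm

litellm.register_model(
    {
        "my-custom-llm/my-custom-model": {
            "input_cost_per_token": 0.000005,
            "output_cost_per_token": 0.000015,
            "litellm_provider": "my-custom-llm",
            "mode": "chat",
        }
    }
)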

* test: make testing more robust for custom pricing

* fix(redis_cache.py): instrument otel logging for sync redis calls

ensures complete coverage for all redis cache calls

* (Testing) Add unit testing for DualCache - ensure in memory cache is used when expected  (#6471)

* test test_dual_cache_get_set

* unit testing for dual cache

* fix async_set_cache_sadd

* test_dual_cache_local_only

* redis otel tracing + async support for latency routing (#6452)

* docs(exception_mapping.md): add missing exception types

Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183

* fix(main.py): register custom model pricing with specific key

Ensure custom model pricing is registered to the specific model+provider key combination

* test: make testing more robust for custom pricing

* fix(redis_cache.py): instrument otel logging for sync redis calls

ensures complete coverage for all redis cache calls

* refactor: pass parent_otel_span for redis caching calls in router

allows for more observability into what calls are causing latency issues

* test: update tests with new params

* refactor: ensure e2e otel tracing for router

* refactor(router.py): add more otel tracing across router

catch all latency issues for router requests

* fix: fix linting error

* fix(router.py): fix linting error

* fix: fix test

* test: fix tests

* fix(dual_cache.py): pass ttl to redis cache

* fix: fix param

* fix(dual_cache.py): set default value for parent_otel_span

* fix(transformation.py): support 'response_format' for anthropic calls

* fix(transformation.py): check for cache_control inside 'function' block

* fix: fix linting error

* fix: fix linting errors

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>

---------

Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com>

* ui new build

* Add retry strat (#6520)

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* (fix) slack alerting - don't spam the failed cost tracking alert for the same model  (#6543)

* fix use failing_model as cache key for failed_tracking_alert

* fix use standard logging payload for getting response cost

* fix  kwargs.get("response_cost")

* fix getting response cost
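
The dedupe idea in isolation (names and TTL are illustrative, not the proxy's actual alerting API): key the "failed cost tracking" alert on the failing model, so repeated cost-tracking failures for the same model within the TTL don't re-alert.

import time

_failed_tracking_alert_sent: dict = {}
ALERT_TTL_SECONDS = 3600  # illustrative

def should_send_failed_tracking_alert(failing_model: str) -> bool:
    now = time.time()
    last_sent = _failed_tracking_alert_sent.get(failing_model)
    if last_sent is not None and (now - last_sent) < ALERT_TTL_SECONDS:
        return False  # already alerted for this model recently
    _failed_tracking_alert_sent[failing_model] = now
    return True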

* (feat) add XAI ChatCompletion Support  (#6373)

* init commit for XAI

* add full logic for xai chat completion

* test_completion_xai

* docs xAI

* add xai/grok-beta

* test_xai_chat_config_get_openai_compatible_provider_info

* test_xai_chat_config_map_openai_params

* add xai streaming test
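
A hedged usage sketch for the new provider; the "xai/" prefix and grok-beta come from this PR, while the XAI_API_KEY environment variable name is an assumption.

import os

import litellm

os.environ["XAI_API_KEY"] = "your-xai-api-key"  # assumed env var name

response = litellm.completion(
    model="xai/grok-beta",
    messages=[{"role": "user", "content": "Hello from LiteLLM"}],
)
print(response.choices[0].message.content)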

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Krish Dholakia 2024-11-01 23:14:32 +04:00 committed by GitHub
parent 5652c375b3
commit f79365df6e
24 changed files with 1851 additions and 700 deletions

.gitignore

@ -65,3 +65,4 @@ litellm/tests/log.txt
litellm/tests/langfuse.log
litellm/tests/langfuse.log
litellm/proxy/google-cloud-sdk/*
tests/llm_translation/log.txt

litellm/_logging.py

@ -35,6 +35,9 @@ class JsonFormatter(Formatter):
"timestamp": self.formatTime(record),
}
if record.exc_info:
json_record["stacktrace"] = self.formatException(record.exc_info)
return json.dumps(json_record)

litellm/llms/anthropic/chat/transformation.py

@ -123,7 +123,7 @@ class AnthropicConfig:
return headers
def _map_tool_choice(
self, tool_choice: Optional[str], disable_parallel_tool_use: Optional[bool]
self, tool_choice: Optional[str], parallel_tool_use: Optional[bool]
) -> Optional[AnthropicMessagesToolChoice]:
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
if tool_choice == "auto":
@ -138,13 +138,15 @@ class AnthropicConfig:
if _tool_name is not None:
_tool_choice["name"] = _tool_name
if disable_parallel_tool_use is not None:
if parallel_tool_use is not None:
# Anthropic uses 'disable_parallel_tool_use' flag to determine if parallel tool use is allowed
# this is the inverse of the openai flag.
if _tool_choice is not None:
_tool_choice["disable_parallel_tool_use"] = disable_parallel_tool_use
_tool_choice["disable_parallel_tool_use"] = not parallel_tool_use
else: # use anthropic defaults and make sure to send the disable_parallel_tool_use flag
_tool_choice = AnthropicMessagesToolChoice(
type="auto",
disable_parallel_tool_use=disable_parallel_tool_use,
disable_parallel_tool_use=not parallel_tool_use,
)
return _tool_choice
@ -255,9 +257,7 @@ class AnthropicConfig:
_tool_choice: Optional[AnthropicMessagesToolChoice] = (
self._map_tool_choice(
tool_choice=non_default_params.get("tool_choice"),
disable_parallel_tool_use=non_default_params.get(
"parallel_tool_calls"
),
parallel_tool_use=non_default_params.get("parallel_tool_calls"),
)
)

litellm/llms/prompt_templates/factory.py

@ -2552,18 +2552,20 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
BedrockContentBlock(image=assistants_part) # type: ignore
)
assistant_content.extend(assistants_parts)
elif messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke convertion
assistant_content.extend(
_convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
)
else:
elif messages[msg_i].get("content", None) is not None and isinstance(
messages[msg_i]["content"], str
):
assistant_text = (
messages[msg_i].get("content") or ""
) # either string or none
if assistant_text:
assistant_content.append(BedrockContentBlock(text=assistant_text))
if messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke convertion [TODO]:
assistant_content.extend(
_convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
)
msg_i += 1

litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py

@ -11,9 +11,9 @@ from litellm.types.llms.vertex_ai import CachedContentRequestBody, SystemInstruc
from litellm.utils import is_cached_message
from ..common_utils import VertexAIError, get_supports_system_message
from ..gemini.transformation import _transform_system_message
from ..gemini.vertex_and_google_ai_studio_gemini import (
from ..gemini.transformation import (
_gemini_convert_messages_with_history,
_transform_system_message,
)

litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py

@ -4,14 +4,34 @@ Transformation logic from OpenAI format to Gemini format.
Why separate file? Make it easy to see how transformation works
"""
from typing import List, Literal, Optional, Tuple, Union
import os
from typing import List, Literal, Optional, Tuple, Union, cast
import httpx
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import response_schema_prompt
from litellm.types.llms.openai import AllMessageValues
from litellm.llms.prompt_templates.factory import (
convert_to_anthropic_image_obj,
convert_to_gemini_tool_call_invoke,
convert_to_gemini_tool_call_result,
response_schema_prompt,
)
from litellm.types.files import (
get_file_mime_type_for_file_type,
get_file_type_from_extension,
is_gemini_1_5_accepted_file_type,
is_video_file_type,
)
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionImageObject,
ChatCompletionTextObject,
)
from litellm.types.llms.vertex_ai import *
from litellm.types.llms.vertex_ai import (
GenerationConfig,
PartType,
@ -21,9 +41,185 @@ from litellm.types.llms.vertex_ai import (
ToolConfig,
Tools,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ..common_utils import get_supports_response_schema, get_supports_system_message
from ..vertex_ai_non_gemini import _gemini_convert_messages_with_history
from ..common_utils import (
_check_text_in_content,
get_supports_response_schema,
get_supports_system_message,
)
def _process_gemini_image(image_url: str) -> PartType:
try:
# GCS URIs
if "gs://" in image_url:
# Figure out file type
extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
extension = extension_with_dot[1:] # Ex: "png"
file_type = get_file_type_from_extension(extension)
# Validate the file type is supported by Gemini
if not is_gemini_1_5_accepted_file_type(file_type):
raise Exception(f"File type not supported by gemini - {file_type}")
mime_type = get_file_mime_type_for_file_type(file_type)
file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
return PartType(file_data=file_data)
# Direct links
elif "https:/" in image_url or "base64" in image_url:
image = convert_to_anthropic_image_obj(image_url)
_blob = BlobType(data=image["data"], mime_type=image["media_type"])
return PartType(inline_data=_blob)
raise Exception("Invalid image received - {}".format(image_url))
except Exception as e:
raise e
def _gemini_convert_messages_with_history( # noqa: PLR0915
messages: List[AllMessageValues],
) -> List[ContentType]:
"""
Converts given messages from OpenAI format to Gemini format
- Parts must be iterable
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
- Please ensure that function response turn comes immediately after a function call turn
"""
user_message_types = {"user", "system"}
contents: List[ContentType] = []
last_message_with_tool_calls = None
msg_i = 0
tool_call_responses = []
try:
while msg_i < len(messages):
user_content: List[PartType] = []
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ##
while (
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
):
_message_content = messages[msg_i].get("content")
if _message_content is not None and isinstance(_message_content, list):
_parts: List[PartType] = []
for element in _message_content:
if (
element["type"] == "text"
and "text" in element
and len(element["text"]) > 0
):
element = cast(ChatCompletionTextObject, element)
_part = PartType(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
element = cast(ChatCompletionImageObject, element)
img_element = element
if isinstance(img_element["image_url"], dict):
image_url = img_element["image_url"]["url"]
else:
image_url = img_element["image_url"]
_part = _process_gemini_image(image_url=image_url)
_parts.append(_part)
user_content.extend(_parts)
elif (
_message_content is not None
and isinstance(_message_content, str)
and len(_message_content) > 0
):
_part = PartType(text=_message_content)
user_content.append(_part)
msg_i += 1
if user_content:
"""
check that user_content has 'text' parameter.
- Known Vertex Error: Unable to submit request because it must have a text parameter.
- Relevant Issue: https://github.com/BerriAI/litellm/issues/5515
"""
has_text_in_content = _check_text_in_content(user_content)
if has_text_in_content is False:
verbose_logger.warning(
"No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515"
)
user_content.append(
PartType(text=" ")
) # add a blank text, to ensure Gemini doesn't fail the request.
contents.append(ContentType(role="user", parts=user_content))
assistant_content = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
if isinstance(messages[msg_i], BaseModel):
msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump() # type: ignore
else:
msg_dict = messages[msg_i] # type: ignore
assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore
_message_content = assistant_msg.get("content", None)
if _message_content is not None and isinstance(_message_content, list):
_parts = []
for element in _message_content:
if isinstance(element, dict):
if element["type"] == "text":
_part = PartType(text=element["text"])
_parts.append(_part)
assistant_content.extend(_parts)
elif (
_message_content is not None
and isinstance(_message_content, str)
and _message_content
):
assistant_text = _message_content # either string or none
assistant_content.append(PartType(text=assistant_text)) # type: ignore
## HANDLE ASSISTANT FUNCTION CALL
if (
assistant_msg.get("tool_calls", []) is not None
or assistant_msg.get("function_call") is not None
): # support assistant tool invoke conversion
assistant_content.extend(
convert_to_gemini_tool_call_invoke(assistant_msg)
)
last_message_with_tool_calls = assistant_msg
msg_i += 1
if assistant_content:
contents.append(ContentType(role="model", parts=assistant_content))
## APPEND TOOL CALL MESSAGES ##
tool_call_message_roles = ["tool", "function"]
if (
msg_i < len(messages)
and messages[msg_i]["role"] in tool_call_message_roles
):
_part = convert_to_gemini_tool_call_result(
messages[msg_i], last_message_with_tool_calls # type: ignore
)
msg_i += 1
tool_call_responses.append(_part)
if msg_i < len(messages) and (
messages[msg_i]["role"] not in tool_call_message_roles
):
if len(tool_call_responses) > 0:
contents.append(ContentType(parts=tool_call_responses))
tool_call_responses = []
if msg_i == init_msg_i: # prevent infinite loops
raise Exception(
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
messages[msg_i]
)
)
if len(tool_call_responses) > 0:
contents.append(ContentType(parts=tool_call_responses))
return contents
except Exception as e:
raise e
def _transform_request_body(

litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py

@ -35,13 +35,6 @@ from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.prompt_templates.factory import (
convert_url_to_base64,
response_schema_prompt,
)
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
_gemini_convert_messages_with_history,
)
from litellm.types.llms.openai import (
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
@ -57,6 +50,7 @@ from litellm.types.llms.vertex_ai import (
GenerateContentResponseBody,
GenerationConfig,
HttpxPartType,
LogprobsResult,
PartType,
RequestBody,
SafetSettingsConfig,
@ -64,7 +58,12 @@ from litellm.types.llms.vertex_ai import (
ToolConfig,
Tools,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import (
ChatCompletionTokenLogprob,
ChoiceLogprobs,
GenericStreamingChunk,
TopLogprob,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ....utils import _remove_additional_properties, _remove_strict_from_schema
@ -365,6 +364,7 @@ class VertexGeminiConfig:
"presence_penalty",
"extra_headers",
"seed",
"logprobs",
]
def map_tool_choice_values(
@ -454,6 +454,16 @@ class VertexGeminiConfig:
_tools["code_execution"] = code_execution
return [_tools]
def _map_response_schema(self, value: dict) -> dict:
old_schema = deepcopy(value)
if isinstance(old_schema, list):
for item in old_schema:
if isinstance(item, dict):
item = _build_vertex_schema(parameters=item)
elif isinstance(old_schema, dict):
old_schema = _build_vertex_schema(parameters=old_schema)
return old_schema
def map_openai_params(
self,
model: str,
@ -461,6 +471,7 @@ class VertexGeminiConfig:
optional_params: dict,
drop_params: bool,
):
for param, value in non_default_params.items():
if param == "temperature":
optional_params["temperature"] = value
@ -499,19 +510,15 @@ class VertexGeminiConfig:
if "response_schema" in optional_params and isinstance(
optional_params["response_schema"], dict
):
old_schema = deepcopy(optional_params["response_schema"])
if isinstance(old_schema, list):
for item in old_schema:
if isinstance(item, dict):
item = _build_vertex_schema(parameters=item)
elif isinstance(old_schema, dict):
old_schema = _build_vertex_schema(parameters=old_schema)
optional_params["response_schema"] = old_schema
optional_params["response_schema"] = self._map_response_schema(
value=optional_params["response_schema"]
)
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "logprobs":
optional_params["responseLogprobs"] = value
if (param == "tools" or param == "functions") and isinstance(value, list):
optional_params["tools"] = self._map_function(value=value)
optional_params["litellm_param_is_function_call"] = (
@ -527,6 +534,7 @@ class VertexGeminiConfig:
optional_params["tool_choice"] = _tool_choice_value
if param == "seed":
optional_params["seed"] = value
return optional_params
def get_mapped_special_auth_params(self) -> dict:
@ -584,12 +592,325 @@ class VertexGeminiConfig:
)
return exception_string
def get_assistant_content_message(self, parts: List[HttpxPartType]) -> str:
content_str = ""
def get_assistant_content_message(
self, parts: List[HttpxPartType]
) -> Optional[str]:
_content_str = ""
for part in parts:
if "text" in part:
content_str += part["text"]
return content_str
_content_str += part["text"]
if _content_str:
return _content_str
return None
def _transform_parts(
self,
parts: List[HttpxPartType],
index: int,
is_function_call: Optional[bool],
) -> Tuple[
Optional[ChatCompletionToolCallFunctionChunk],
Optional[List[ChatCompletionToolCallChunk]],
]:
function: Optional[ChatCompletionToolCallFunctionChunk] = None
_tools: List[ChatCompletionToolCallChunk] = []
for part in parts:
if "functionCall" in part:
_function_chunk = ChatCompletionToolCallFunctionChunk(
name=part["functionCall"]["name"],
arguments=json.dumps(part["functionCall"]["args"]),
)
if is_function_call is True:
function = _function_chunk
else:
_tool_response_chunk = ChatCompletionToolCallChunk(
id=f"call_{str(uuid.uuid4())}",
type="function",
function=_function_chunk,
index=index,
)
_tools.append(_tool_response_chunk)
if len(_tools) == 0:
tools: Optional[List[ChatCompletionToolCallChunk]] = None
else:
tools = _tools
return function, tools
def _transform_logprobs(
self, logprobs_result: Optional[LogprobsResult]
) -> Optional[ChoiceLogprobs]:
if logprobs_result is None:
return None
if "chosenCandidates" not in logprobs_result:
return None
logprobs_list: List[ChatCompletionTokenLogprob] = []
for index, candidate in enumerate(logprobs_result["chosenCandidates"]):
top_logprobs: List[TopLogprob] = []
if "topCandidates" in logprobs_result and index < len(
logprobs_result["topCandidates"]
):
top_candidates_for_index = logprobs_result["topCandidates"][index][
"candidates"
]
for options in top_candidates_for_index:
top_logprobs.append(
TopLogprob(
token=options["token"], logprob=options["logProbability"]
)
)
logprobs_list.append(
ChatCompletionTokenLogprob(
token=candidate["token"],
logprob=candidate["logProbability"],
top_logprobs=top_logprobs,
)
)
return ChoiceLogprobs(content=logprobs_list)
def _handle_blocked_response(
self,
model_response: ModelResponse,
completion_response: GenerateContentResponseBody,
) -> ModelResponse:
# If set, the prompt was blocked and no candidates are returned. Rephrase your prompt
model_response.choices[0].finish_reason = "content_filter"
chat_completion_message: ChatCompletionResponseMessage = {
"role": "assistant",
"content": None,
}
choice = litellm.Choices(
finish_reason="content_filter",
index=0,
message=chat_completion_message, # type: ignore
logprobs=None,
enhancements=None,
)
model_response.choices = [choice]
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0
),
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
)
setattr(model_response, "usage", usage)
return model_response
def _handle_content_policy_violation(
self,
model_response: ModelResponse,
completion_response: GenerateContentResponseBody,
) -> ModelResponse:
## CONTENT POLICY VIOLATION ERROR
model_response.choices[0].finish_reason = "content_filter"
_chat_completion_message = {
"role": "assistant",
"content": None,
}
choice = litellm.Choices(
finish_reason="content_filter",
index=0,
message=_chat_completion_message,
logprobs=None,
enhancements=None,
)
model_response.choices = [choice]
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0
),
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
)
setattr(model_response, "usage", usage)
return model_response
def _transform_response(
self,
model: str,
response: httpx.Response,
model_response: ModelResponse,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict,
litellm_params: dict,
api_key: str,
data: Union[dict, str, RequestBody],
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=response.text,
additional_args={"complete_input_dict": data},
)
## RESPONSE OBJECT
try:
completion_response = GenerateContentResponseBody(**response.json()) # type: ignore
except Exception as e:
raise VertexAIError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
response.text, str(e)
),
status_code=422,
)
## GET MODEL ##
model_response.model = model
## CHECK IF RESPONSE FLAGGED
if (
"promptFeedback" in completion_response
and "blockReason" in completion_response["promptFeedback"]
):
return self._handle_blocked_response(
model_response=model_response,
completion_response=completion_response,
)
_candidates = completion_response.get("candidates")
if _candidates and len(_candidates) > 0:
content_policy_violations = (
VertexGeminiConfig().get_flagged_finish_reasons()
)
if (
"finishReason" in _candidates[0]
and _candidates[0]["finishReason"] in content_policy_violations.keys()
):
return self._handle_content_policy_violation(
model_response=model_response,
completion_response=completion_response,
)
model_response.choices = [] # type: ignore
try:
## CHECK IF GROUNDING METADATA IN REQUEST
grounding_metadata: List[dict] = []
safety_ratings: List = []
citation_metadata: List = []
## GET TEXT ##
chat_completion_message: ChatCompletionResponseMessage = {
"role": "assistant"
}
chat_completion_logprobs: Optional[ChoiceLogprobs] = None
tools: Optional[List[ChatCompletionToolCallChunk]] = []
functions: Optional[ChatCompletionToolCallFunctionChunk] = None
if _candidates:
for idx, candidate in enumerate(_candidates):
if "content" not in candidate:
continue
if "groundingMetadata" in candidate:
grounding_metadata.append(candidate["groundingMetadata"]) # type: ignore
if "safetyRatings" in candidate:
safety_ratings.append(candidate["safetyRatings"])
if "citationMetadata" in candidate:
citation_metadata.append(candidate["citationMetadata"])
if "parts" in candidate["content"]:
chat_completion_message[
"content"
] = VertexGeminiConfig().get_assistant_content_message(
parts=candidate["content"]["parts"]
)
functions, tools = self._transform_parts(
parts=candidate["content"]["parts"],
index=candidate.get("index", idx),
is_function_call=litellm_params.get(
"litellm_param_is_function_call"
),
)
if "logprobsResult" in candidate:
chat_completion_logprobs = self._transform_logprobs(
logprobs_result=candidate["logprobsResult"]
)
if tools:
chat_completion_message["tool_calls"] = tools
if functions is not None:
chat_completion_message["function_call"] = functions
choice = litellm.Choices(
finish_reason=candidate.get("finishReason", "stop"),
index=candidate.get("index", idx),
message=chat_completion_message, # type: ignore
logprobs=chat_completion_logprobs,
enhancements=None,
)
model_response.choices.append(choice)
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0
),
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"].get(
"totalTokenCount", 0
),
)
setattr(model_response, "usage", usage)
## ADD GROUNDING METADATA ##
setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata)
model_response._hidden_params[
"vertex_ai_grounding_metadata"
] = ( # older approach - maintaining to prevent regressions
grounding_metadata
)
## ADD SAFETY RATINGS ##
setattr(model_response, "vertex_ai_safety_results", safety_ratings)
model_response._hidden_params["vertex_ai_safety_results"] = (
safety_ratings # older approach - maintaining to prevent regressions
)
## ADD CITATION METADATA ##
setattr(model_response, "vertex_ai_citation_metadata", citation_metadata)
model_response._hidden_params["vertex_ai_citation_metadata"] = (
citation_metadata # older approach - maintaining to prevent regressions
)
except Exception as e:
raise VertexAIError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
completion_response, str(e)
),
status_code=422,
)
return model_response
class GoogleAIStudioGeminiConfig(
@ -675,6 +996,7 @@ class GoogleAIStudioGeminiConfig(
"response_format",
"n",
"stop",
"logprobs",
]
def map_openai_params(
@ -771,243 +1093,6 @@ class VertexLLM(VertexBase):
def __init__(self) -> None:
super().__init__()
def _process_response( # noqa: PLR0915
self,
model: str,
response: httpx.Response,
model_response: ModelResponse,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict,
litellm_params: dict,
api_key: str,
data: Union[dict, str, RequestBody],
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = GenerateContentResponseBody(**response.json()) # type: ignore
except Exception as e:
raise VertexAIError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
response.text, str(e)
),
status_code=422,
)
## GET MODEL ##
model_response.model = model
## CHECK IF RESPONSE FLAGGED
if "promptFeedback" in completion_response:
if "blockReason" in completion_response["promptFeedback"]:
# If set, the prompt was blocked and no candidates are returned. Rephrase your prompt
model_response.choices[0].finish_reason = "content_filter"
chat_completion_message: ChatCompletionResponseMessage = {
"role": "assistant",
"content": None,
}
choice = litellm.Choices(
finish_reason="content_filter",
index=0,
message=chat_completion_message, # type: ignore
logprobs=None,
enhancements=None,
)
model_response.choices = [choice]
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0
),
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"].get(
"totalTokenCount", 0
),
)
setattr(model_response, "usage", usage)
return model_response
_candidates = completion_response.get("candidates")
if _candidates and len(_candidates) > 0:
content_policy_violations = (
VertexGeminiConfig().get_flagged_finish_reasons()
)
if (
"finishReason" in _candidates[0]
and _candidates[0]["finishReason"] in content_policy_violations.keys()
):
## CONTENT POLICY VIOLATION ERROR
model_response.choices[0].finish_reason = "content_filter"
_chat_completion_message = {
"role": "assistant",
"content": None,
}
choice = litellm.Choices(
finish_reason="content_filter",
index=0,
message=_chat_completion_message,
logprobs=None,
enhancements=None,
)
model_response.choices = [choice]
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0
),
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"].get(
"totalTokenCount", 0
),
)
setattr(model_response, "usage", usage)
return model_response
model_response.choices = [] # type: ignore
try:
## CHECK IF GROUNDING METADATA IN REQUEST
grounding_metadata: List[dict] = []
safety_ratings: List = []
citation_metadata: List = []
## GET TEXT ##
chat_completion_message = {"role": "assistant"}
content_str: str = ""
tools: List[ChatCompletionToolCallChunk] = []
functions: Optional[ChatCompletionToolCallFunctionChunk] = None
if _candidates:
for idx, candidate in enumerate(_candidates):
if "content" not in candidate:
continue
if "groundingMetadata" in candidate:
grounding_metadata.append(candidate["groundingMetadata"]) # type: ignore
if "safetyRatings" in candidate:
safety_ratings.append(candidate["safetyRatings"])
if "citationMetadata" in candidate:
citation_metadata.append(candidate["citationMetadata"])
if "parts" in candidate["content"]:
content_str = (
VertexGeminiConfig().get_assistant_content_message(
parts=candidate["content"]["parts"]
)
)
if (
"parts" in candidate["content"]
and "functionCall" in candidate["content"]["parts"][0]
):
_function_chunk = ChatCompletionToolCallFunctionChunk(
name=candidate["content"]["parts"][0]["functionCall"][
"name"
],
arguments=json.dumps(
candidate["content"]["parts"][0]["functionCall"]["args"]
),
)
if litellm_params.get("litellm_param_is_function_call") is True:
functions = _function_chunk
else:
_tool_response_chunk = ChatCompletionToolCallChunk(
id=f"call_{str(uuid.uuid4())}",
type="function",
function=_function_chunk,
index=candidate.get("index", idx),
)
tools.append(_tool_response_chunk)
chat_completion_message["content"] = (
content_str if len(content_str) > 0 else None
)
if len(tools) > 0:
chat_completion_message["tool_calls"] = tools
if functions is not None:
chat_completion_message["function_call"] = functions
choice = litellm.Choices(
finish_reason=candidate.get("finishReason", "stop"),
index=candidate.get("index", idx),
message=chat_completion_message, # type: ignore
logprobs=None,
enhancements=None,
)
model_response.choices.append(choice)
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"].get(
"promptTokenCount", 0
),
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"].get(
"totalTokenCount", 0
),
)
setattr(model_response, "usage", usage)
## ADD GROUNDING METADATA ##
setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata)
model_response._hidden_params[
"vertex_ai_grounding_metadata"
] = ( # older approach - maintaining to prevent regressions
grounding_metadata
)
## ADD SAFETY RATINGS ##
setattr(model_response, "vertex_ai_safety_results", safety_ratings)
model_response._hidden_params["vertex_ai_safety_results"] = (
safety_ratings # older approach - maintaining to prevent regressions
)
## ADD CITATION METADATA ##
setattr(model_response, "vertex_ai_citation_metadata", citation_metadata)
model_response._hidden_params["vertex_ai_citation_metadata"] = (
citation_metadata # older approach - maintaining to prevent regressions
)
except Exception as e:
raise VertexAIError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
completion_response, str(e)
),
status_code=422,
)
return model_response
async def async_streaming(
self,
model: str,
@ -1171,7 +1256,7 @@ class VertexLLM(VertexBase):
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")
return self._process_response(
return VertexGeminiConfig()._transform_response(
model=model,
response=response,
model_response=model_response,
@ -1359,7 +1444,7 @@ class VertexLLM(VertexBase):
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")
return self._process_response(
return VertexGeminiConfig()._transform_response(
model=model,
response=response,
model_response=model_response,

litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py

@ -85,185 +85,6 @@ class TextStreamer:
raise StopAsyncIteration # once we run out of data to stream, we raise this error
def _get_image_bytes_from_url(image_url: str) -> bytes:
try:
response = requests.get(image_url)
response.raise_for_status() # Raise an error for bad responses (4xx and 5xx)
image_bytes = response.content
return image_bytes
except requests.exceptions.RequestException as e:
raise Exception(f"An exception occurs with this image - {str(e)}")
def _convert_gemini_role(role: str) -> Literal["user", "model"]:
if role == "user":
return "user"
else:
return "model"
def _process_gemini_image(image_url: str) -> PartType:
try:
# GCS URIs
if "gs://" in image_url:
# Figure out file type
extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
extension = extension_with_dot[1:] # Ex: "png"
file_type = get_file_type_from_extension(extension)
# Validate the file type is supported by Gemini
if not is_gemini_1_5_accepted_file_type(file_type):
raise Exception(f"File type not supported by gemini - {file_type}")
mime_type = get_file_mime_type_for_file_type(file_type)
file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
return PartType(file_data=file_data)
# Direct links
elif "https:/" in image_url or "base64" in image_url:
image = convert_to_anthropic_image_obj(image_url)
_blob = BlobType(data=image["data"], mime_type=image["media_type"])
return PartType(inline_data=_blob)
raise Exception("Invalid image received - {}".format(image_url))
except Exception as e:
raise e
def _gemini_convert_messages_with_history( # noqa: PLR0915
messages: List[AllMessageValues],
) -> List[ContentType]:
"""
Converts given messages from OpenAI format to Gemini format
- Parts must be iterable
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
- Please ensure that function response turn comes immediately after a function call turn
"""
user_message_types = {"user", "system"}
contents: List[ContentType] = []
last_message_with_tool_calls = None
msg_i = 0
try:
while msg_i < len(messages):
user_content: List[PartType] = []
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ##
while (
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
):
_message_content = messages[msg_i].get("content")
if _message_content is not None and isinstance(_message_content, list):
_parts: List[PartType] = []
for element in _message_content:
if (
element["type"] == "text"
and "text" in element
and len(element["text"]) > 0
):
element = cast(ChatCompletionTextObject, element)
_part = PartType(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
element = cast(ChatCompletionImageObject, element)
img_element = element
if isinstance(img_element["image_url"], dict):
image_url = img_element["image_url"]["url"]
else:
image_url = img_element["image_url"]
_part = _process_gemini_image(image_url=image_url)
_parts.append(_part)
user_content.extend(_parts)
elif (
_message_content is not None
and isinstance(_message_content, str)
and len(_message_content) > 0
):
_part = PartType(text=_message_content)
user_content.append(_part)
msg_i += 1
if user_content:
"""
check that user_content has 'text' parameter.
- Known Vertex Error: Unable to submit request because it must have a text parameter.
- Relevant Issue: https://github.com/BerriAI/litellm/issues/5515
"""
has_text_in_content = _check_text_in_content(user_content)
if has_text_in_content is False:
verbose_logger.warning(
"No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515"
)
user_content.append(
PartType(text=" ")
) # add a blank text, to ensure Gemini doesn't fail the request.
contents.append(ContentType(role="user", parts=user_content))
assistant_content = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
if isinstance(messages[msg_i], BaseModel):
msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump() # type: ignore
else:
msg_dict = messages[msg_i] # type: ignore
assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore
_message_content = assistant_msg.get("content", None)
if _message_content is not None and isinstance(_message_content, list):
_parts = []
for element in _message_content:
if isinstance(element, dict):
if element["type"] == "text":
_part = PartType(text=element["text"])
_parts.append(_part)
assistant_content.extend(_parts)
elif (
_message_content is not None
and isinstance(_message_content, str)
and _message_content
):
assistant_text = _message_content # either string or none
assistant_content.append(PartType(text=assistant_text)) # type: ignore
## HANDLE ASSISTANT FUNCTION CALL
if (
assistant_msg.get("tool_calls", []) is not None
or assistant_msg.get("function_call") is not None
): # support assistant tool invoke conversion
assistant_content.extend(
convert_to_gemini_tool_call_invoke(assistant_msg)
)
last_message_with_tool_calls = assistant_msg
msg_i += 1
if assistant_content:
contents.append(ContentType(role="model", parts=assistant_content))
## APPEND TOOL CALL MESSAGES ##
if msg_i < len(messages) and (
messages[msg_i]["role"] == "tool"
or messages[msg_i]["role"] == "function"
):
_part = convert_to_gemini_tool_call_result(
messages[msg_i], last_message_with_tool_calls # type: ignore
)
contents.append(ContentType(parts=[_part])) # type: ignore
msg_i += 1
if msg_i == init_msg_i: # prevent infinite loops
raise Exception(
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
messages[msg_i]
)
)
return contents
except Exception as e:
raise e
def _get_client_cache_key(
model: str, vertex_project: Optional[str], vertex_location: Optional[str]
):
@ -487,91 +308,7 @@ def completion( # noqa: PLR0915
return async_completion(**data)
completion_response = None
if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
content = _gemini_convert_messages_with_history(messages=messages)
stream = optional_params.pop("stream", False)
if stream is True:
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
_model_response = llm_model.generate_content(
contents=content,
generation_config=optional_params,
safety_settings=safety_settings,
stream=True,
tools=tools,
)
return _model_response
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
## LLM Call
response = llm_model.generate_content(
contents=content,
generation_config=optional_params,
safety_settings=safety_settings,
tools=tools,
)
if tools is not None and bool(
getattr(response.candidates[0].content.parts[0], "function_call", None)
):
function_call = response.candidates[0].content.parts[0].function_call
args_dict = {}
# Check if it's a RepeatedComposite instance
for key, val in function_call.args.items():
if isinstance(
val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore
):
# If so, convert to list
args_dict[key] = [v for v in val]
else:
args_dict[key] = val
try:
args_str = json.dumps(args_dict)
except Exception as e:
raise VertexAIError(status_code=422, message=str(e))
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": args_str,
"name": function_call.name,
},
"type": "function",
}
],
)
completion_response = message
else:
completion_response = response.text
response_obj = response._raw_response
optional_params["tools"] = tools
elif mode == "chat":
if mode == "chat":
chat = llm_model.start_chat()
request_str += "chat = llm_model.start_chat()\n"
@ -796,82 +533,7 @@ async def async_completion( # noqa: PLR0915
response_obj = None
completion_response = None
if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
optional_params.pop("stream", False)
content = _gemini_convert_messages_with_history(messages=messages)
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
## LLM Call
# print(f"final content: {content}")
response = await llm_model._generate_content_async(
contents=content,
generation_config=optional_params,
safety_settings=safety_settings,
tools=tools,
)
_cache_key = _get_client_cache_key(
model=model,
vertex_project=vertex_project,
vertex_location=vertex_location,
)
_set_client_in_cache(
client_cache_key=_cache_key, vertex_llm_model=llm_model
)
if tools is not None and bool(
getattr(response.candidates[0].content.parts[0], "function_call", None)
):
function_call = response.candidates[0].content.parts[0].function_call
args_dict = {}
# Check if it's a RepeatedComposite instance
for key, val in function_call.args.items():
if isinstance(
val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore
):
# If so, convert to list
args_dict[key] = [v for v in val]
else:
args_dict[key] = val
try:
args_str = json.dumps(args_dict)
except Exception as e:
raise VertexAIError(status_code=422, message=str(e))
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": args_str,
"name": function_call.name,
},
"type": "function",
}
],
)
completion_response = message
else:
completion_response = response.text
response_obj = response._raw_response
optional_params["tools"] = tools
elif mode == "chat":
if mode == "chat":
# chat-bison etc.
chat = llm_model.start_chat()
## LOGGING
@ -1032,32 +694,7 @@ async def async_streaming( # noqa: PLR0915
Add support for async streaming calls for gemini-pro
"""
response: Any = None
if mode == "vision":
stream = optional_params.pop("stream")
tools = optional_params.pop("tools", None)
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
content = _gemini_convert_messages_with_history(messages=messages)
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = await llm_model._generate_content_streaming_async(
contents=content,
generation_config=optional_params,
safety_settings=safety_settings,
tools=tools,
)
elif mode == "chat":
if mode == "chat":
chat = llm_model.start_chat()
optional_params.pop(
"stream", None


@ -15,7 +15,6 @@ import json
import os
import random
import sys
import threading
import time
import traceback
import uuid


@ -10,10 +10,9 @@ model_list:
output_cost_per_token: 0.000015 # 15$/M
api_base: "https://exampleopenaiendpoint-production.up.railway.app"
api_key: my-fake-key
- model_name: my-custom-model
- model_name: gemini-1.5-flash-002
litellm_params:
model: my-custom-llm/my-custom-model
api_key: my-fake-key
model: gemini/gemini-1.5-flash-002
litellm_settings:
fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]


@ -97,9 +97,9 @@ class PassThroughEndpointLogging:
if "generateContent" in url_route:
model = self.extract_model_from_url(url_route)
instance_of_vertex_llm = VertexLLM()
instance_of_vertex_llm = litellm.VertexGeminiConfig()
litellm_model_response: litellm.ModelResponse = (
instance_of_vertex_llm._process_response(
instance_of_vertex_llm._transform_response(
model=model,
messages=[
{"role": "user", "content": "no-message-pass-through-endpoint"}


@ -1,33 +0,0 @@
import os
import sys
import time
import pytest
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from langtrace_python_sdk import langtrace
import litellm
sys.path.insert(0, os.path.abspath("../.."))
@pytest.fixture()
def exporter():
exporter = InMemorySpanExporter()
langtrace.init(batch=False, custom_remote_exporter=exporter)
litellm.success_callback = ["langtrace"]
litellm.set_verbose = True
return exporter
@pytest.mark.parametrize("model", ["claude-2.1", "gpt-3.5-turbo"])
def test_langtrace_logging(exporter, model):
litellm.completion(
model=model,
messages=[{"role": "user", "content": "This is a test"}],
max_tokens=1000,
temperature=0.7,
timeout=5,
mock_response="hi",
)

litellm/types/llms/openai.py

@ -500,9 +500,9 @@ ChatCompletionAssistantContentValue = (
class ChatCompletionResponseMessage(TypedDict, total=False):
content: Optional[ChatCompletionAssistantContentValue]
tool_calls: List[ChatCompletionToolCallChunk]
tool_calls: Optional[List[ChatCompletionToolCallChunk]]
role: Literal["assistant"]
function_call: ChatCompletionToolCallFunctionChunk
function_call: Optional[ChatCompletionToolCallFunctionChunk]
class ChatCompletionUsageBlock(TypedDict):

litellm/types/llms/vertex_ai.py

@ -167,6 +167,8 @@ class GenerationConfig(TypedDict, total=False):
response_mime_type: Literal["text/plain", "application/json"]
response_schema: dict
seed: int
responseLogprobs: bool
logprobs: int
class Tools(TypedDict, total=False):
@ -270,6 +272,21 @@ class GroundingMetadata(TypedDict, total=False):
groundingAttributions: List[dict]
class LogprobsCandidate(TypedDict):
token: str
tokenId: int
logProbability: float
class LogprobsTopCandidate(TypedDict):
candidates: List[LogprobsCandidate]
class LogprobsResult(TypedDict, total=False):
topCandidates: List[LogprobsTopCandidate]
chosenCandidates: List[LogprobsCandidate]
class Candidates(TypedDict, total=False):
index: int
content: HttpxContentType
@ -288,6 +305,7 @@ class Candidates(TypedDict, total=False):
citationMetadata: CitationMetadata
groundingMetadata: GroundingMetadata
finishMessage: str
logprobsResult: LogprobsResult
class PromptFeedback(TypedDict):

litellm/utils.py

@ -7708,10 +7708,17 @@ class CustomStreamWrapper:
continue
## LOGGING
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit),
).start() # log response
executor.submit(
self.logging_obj.success_handler,
result=processed_chunk,
start_time=None,
end_time=None,
cache_hit=cache_hit,
)
# threading.Thread(
# target=self.logging_obj.success_handler,
# args=(processed_chunk, None, None, cache_hit),
# ).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk, cache_hit=cache_hit


@ -794,7 +794,7 @@ def test_anthropic_parallel_tool_calls(provider):
parallel_tool_calls=True,
)
print(f"optional_params: {optional_params}")
assert optional_params["tool_choice"]["disable_parallel_tool_use"] is True
assert optional_params["tool_choice"]["disable_parallel_tool_use"] is False
def test_anthropic_computer_tool_use():


@ -230,3 +230,944 @@ def test_function_calling_with_gemini():
]
}
]
def test_multiple_function_call():
litellm.set_verbose = True
from litellm.llms.custom_httpx.http_handler import HTTPHandler
client = HTTPHandler()
messages = [
{"role": "user", "content": [{"type": "text", "text": "do test"}]},
{
"role": "assistant",
"content": [{"type": "text", "text": "test"}],
"tool_calls": [
{
"index": 0,
"function": {"arguments": '{"arg": "test"}', "name": "test"},
"id": "call_597e00e6-11d4-4ed2-94b2-27edee250aec",
"type": "function",
},
{
"index": 1,
"function": {"arguments": '{"arg": "test2"}', "name": "test2"},
"id": "call_2414e8f9-283a-002b-182a-1290ab912c02",
"type": "function",
},
],
},
{
"tool_call_id": "call_597e00e6-11d4-4ed2-94b2-27edee250aec",
"role": "tool",
"name": "test",
"content": [{"type": "text", "text": "42"}],
},
{
"tool_call_id": "call_2414e8f9-283a-002b-182a-1290ab912c02",
"role": "tool",
"name": "test2",
"content": [{"type": "text", "text": "15"}],
},
{"role": "user", "content": [{"type": "text", "text": "tell me the results."}]},
]
response_body = {
"candidates": [
{
"content": {
"parts": [
{
"text": 'The `default_api.test` function call returned a JSON object indicating a successful execution. The `fields` key contains a nested dictionary with a `key` of "content" and a `value` with a `string_value` of "42".\n\nSimilarly, the `default_api.test2` function call also returned a JSON object showing successful execution. The `fields` key contains a nested dictionary with a `key` of "content" and a `value` with a `string_value` of "15".\n\nIn short, both test functions executed successfully and returned different numerical string values ("42" and "15"). The significance of these numbers depends on the internal logic of the `test` and `test2` functions within the `default_api`.\n'
}
],
"role": "model",
},
"finishReason": "STOP",
"avgLogprobs": -0.20577410289219447,
}
],
"usageMetadata": {
"promptTokenCount": 128,
"candidatesTokenCount": 168,
"totalTokenCount": 296,
},
"modelVersion": "gemini-1.5-flash-002",
}
mock_response = MagicMock()
mock_response.json.return_value = response_body
with patch.object(client, "post", return_value=mock_response) as mock_post:
r = litellm.completion(
messages=messages, model="gemini/gemini-1.5-flash-002", client=client
)
assert len(r.choices) > 0
assert mock_post.call_args.kwargs["json"] == {
"contents": [
{"role": "user", "parts": [{"text": "do test"}]},
{
"role": "model",
"parts": [
{"text": "test"},
{
"function_call": {
"name": "test",
"args": {
"fields": {
"key": "arg",
"value": {"string_value": "test"},
}
},
}
},
{
"function_call": {
"name": "test2",
"args": {
"fields": {
"key": "arg",
"value": {"string_value": "test2"},
}
},
}
},
],
},
{
"parts": [
{
"function_response": {
"name": "test",
"response": {
"fields": {
"key": "content",
"value": {"string_value": "42"},
}
},
}
},
{
"function_response": {
"name": "test2",
"response": {
"fields": {
"key": "content",
"value": {"string_value": "15"},
}
},
}
},
]
},
{"role": "user", "parts": [{"text": "tell me the results."}]},
],
"generationConfig": {},
}
def test_multiple_function_call_changed_text_pos():
litellm.set_verbose = True
from litellm.llms.custom_httpx.http_handler import HTTPHandler
client = HTTPHandler()
messages = [
{"role": "user", "content": [{"type": "text", "text": "do test"}]},
{
"tool_calls": [
{
"index": 0,
"function": {"arguments": '{"arg": "test"}', "name": "test"},
"id": "call_597e00e6-11d4-4ed2-94b2-27edee250aec",
"type": "function",
},
{
"index": 1,
"function": {"arguments": '{"arg": "test2"}', "name": "test2"},
"id": "call_2414e8f9-283a-002b-182a-1290ab912c02",
"type": "function",
},
],
"role": "assistant",
"content": [{"type": "text", "text": "test"}],
},
{
"tool_call_id": "call_2414e8f9-283a-002b-182a-1290ab912c02",
"role": "tool",
"name": "test2",
"content": [{"type": "text", "text": "15"}],
},
{
"tool_call_id": "call_597e00e6-11d4-4ed2-94b2-27edee250aec",
"role": "tool",
"name": "test",
"content": [{"type": "text", "text": "42"}],
},
{"role": "user", "content": [{"type": "text", "text": "tell me the results."}]},
]
response_body = {
"candidates": [
{
"content": {
"parts": [
{
"text": 'The code executed two functions, `test` and `test2`.\n\n* **`test`**: Returned a dictionary indicating that the "key" field has a "value" field containing a string value of "42". This is likely a response from a function that processed the input "test" and returned a calculated or pre-defined value.\n\n* **`test2`**: Returned a dictionary indicating that the "key" field has a "value" field containing a string value of "15". Similar to `test`, this suggests a function that processes the input "test2" and returns a specific result.\n\nIn short, both functions appear to be simple tests that return different hardcoded or calculated values based on their input arguments.\n'
}
],
"role": "model",
},
"finishReason": "STOP",
"avgLogprobs": -0.32848488592332409,
}
],
"usageMetadata": {
"promptTokenCount": 128,
"candidatesTokenCount": 155,
"totalTokenCount": 283,
},
"modelVersion": "gemini-1.5-flash-002",
}
mock_response = MagicMock()
mock_response.json.return_value = response_body
with patch.object(client, "post", return_value=mock_response) as mock_post:
resp = litellm.completion(
messages=messages, model="gemini/gemini-1.5-flash-002", client=client
)
assert len(resp.choices) > 0
mock_post.assert_called_once()
assert mock_post.call_args.kwargs["json"]["contents"] == [
{"role": "user", "parts": [{"text": "do test"}]},
{
"role": "model",
"parts": [
{"text": "test"},
{
"function_call": {
"name": "test",
"args": {
"fields": {
"key": "arg",
"value": {"string_value": "test"},
}
},
}
},
{
"function_call": {
"name": "test2",
"args": {
"fields": {
"key": "arg",
"value": {"string_value": "test2"},
}
},
}
},
],
},
{
"parts": [
{
"function_response": {
"name": "test2",
"response": {
"fields": {
"key": "content",
"value": {"string_value": "15"},
}
},
}
},
{
"function_response": {
"name": "test",
"response": {
"fields": {
"key": "content",
"value": {"string_value": "42"},
}
},
}
},
]
},
{"role": "user", "parts": [{"text": "tell me the results."}]},
]
def test_function_calling_with_gemini_multiple_results():
litellm.set_verbose = True
from litellm.llms.custom_httpx.http_handler import HTTPHandler
client = HTTPHandler()
# Step 1: send the conversation and available functions to the model
messages = [
{
"role": "user",
"content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
}
]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
},
}
]
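    # mocked Gemini REST response: a single candidate whose parts are three functionCall entries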
response_body = {
"candidates": [
{
"content": {
"parts": [
{
"functionCall": {
"name": "get_current_weather",
"args": {"location": "San Francisco"},
}
},
{
"functionCall": {
"name": "get_current_weather",
"args": {"location": "Tokyo"},
}
},
{
"functionCall": {
"name": "get_current_weather",
"args": {"location": "Paris"},
}
},
],
"role": "model",
},
"finishReason": "STOP",
"avgLogprobs": -0.0040788948535919189,
}
],
"usageMetadata": {
"promptTokenCount": 90,
"candidatesTokenCount": 22,
"totalTokenCount": 112,
},
"modelVersion": "gemini-1.5-flash-002",
}
mock_response = MagicMock()
mock_response.json.return_value = response_body
with patch.object(client, "post", return_value=mock_response):
response = litellm.completion(
model="gemini/gemini-1.5-flash-002",
messages=messages,
tools=tools,
tool_choice="required",
client=client,
)
print("Response\n", response)
assert len(response.choices[0].message.tool_calls) == 3
expected_locations = ["San Francisco", "Tokyo", "Paris"]
for idx, tool_call in enumerate(response.choices[0].message.tool_calls):
json_args = json.loads(tool_call.function.arguments)
assert json_args["location"] == expected_locations[idx]

def test_logprobs_unit_test():
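    # Feeds a raw Gemini logprobsResult payload (topCandidates + chosenCandidates)
    # straight into the transformation helper; there is no assertion, the point is
    # simply that the conversion handles real-shaped data without raising.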
from litellm import VertexGeminiConfig
result = VertexGeminiConfig()._transform_logprobs(
logprobs_result={
"topCandidates": [
{
"candidates": [
{"token": "```", "logProbability": -1.5496514e-06},
{"token": "`", "logProbability": -13.375002},
{"token": "``", "logProbability": -21.875002},
]
},
{
"candidates": [
{"token": "tool", "logProbability": 0},
{"token": "too", "logProbability": -29.031433},
{"token": "to", "logProbability": -34.11199},
]
},
{
"candidates": [
{"token": "_", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "code", "logProbability": 0},
{"token": "co", "logProbability": -28.114716},
{"token": "c", "logProbability": -29.283161},
]
},
{
"candidates": [
{"token": "\n", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "print", "logProbability": 0},
{"token": "p", "logProbability": -19.7494},
{"token": "prin", "logProbability": -21.117342},
]
},
{
"candidates": [
{"token": "(", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "default", "logProbability": 0},
{"token": "get", "logProbability": -16.811178},
{"token": "ge", "logProbability": -19.031078},
]
},
{
"candidates": [
{"token": "_", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "api", "logProbability": 0},
{"token": "ap", "logProbability": -26.501019},
{"token": "a", "logProbability": -30.905857},
]
},
{
"candidates": [
{"token": ".", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "get", "logProbability": 0},
{"token": "ge", "logProbability": -19.984676},
{"token": "g", "logProbability": -20.527714},
]
},
{
"candidates": [
{"token": "_", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "current", "logProbability": 0},
{"token": "cur", "logProbability": -28.193565},
{"token": "cu", "logProbability": -29.636738},
]
},
{
"candidates": [
{"token": "_", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "weather", "logProbability": 0},
{"token": "we", "logProbability": -27.887215},
{"token": "wea", "logProbability": -31.851082},
]
},
{
"candidates": [
{"token": "(", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "location", "logProbability": 0},
{"token": "loc", "logProbability": -19.152641},
{"token": " location", "logProbability": -21.981709},
]
},
{
"candidates": [
{"token": '="', "logProbability": -0.034490786},
{"token": "='", "logProbability": -3.398928},
{"token": "=", "logProbability": -7.6194153},
]
},
{
"candidates": [
{"token": "San", "logProbability": -6.5561944e-06},
{"token": '\\"', "logProbability": -12.015556},
{"token": "Paris", "logProbability": -14.647776},
]
},
{
"candidates": [
{"token": " Francisco", "logProbability": -3.5760596e-07},
{"token": " Frans", "logProbability": -14.83527},
{"token": " francisco", "logProbability": -19.796852},
]
},
{
"candidates": [
{"token": '"))', "logProbability": -6.079254e-06},
{"token": ",", "logProbability": -12.106029},
{"token": '",', "logProbability": -14.56927},
]
},
{
"candidates": [
{"token": "\n", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "print", "logProbability": -0.04140338},
{"token": "```", "logProbability": -3.2049975},
{"token": "p", "logProbability": -22.087523},
]
},
{
"candidates": [
{"token": "(", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "default", "logProbability": 0},
{"token": "get", "logProbability": -20.266342},
{"token": "de", "logProbability": -20.906395},
]
},
{
"candidates": [
{"token": "_", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "api", "logProbability": 0},
{"token": "ap", "logProbability": -27.712265},
{"token": "a", "logProbability": -31.986958},
]
},
{
"candidates": [
{"token": ".", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "get", "logProbability": 0},
{"token": "g", "logProbability": -23.569286},
{"token": "ge", "logProbability": -23.829632},
]
},
{
"candidates": [
{"token": "_", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "current", "logProbability": 0},
{"token": "cur", "logProbability": -30.125153},
{"token": "curr", "logProbability": -31.756569},
]
},
{
"candidates": [
{"token": "_", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "weather", "logProbability": 0},
{"token": "we", "logProbability": -27.743786},
{"token": "w", "logProbability": -30.594503},
]
},
{
"candidates": [
{"token": "(", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "location", "logProbability": 0},
{"token": "loc", "logProbability": -21.177715},
{"token": " location", "logProbability": -22.166002},
]
},
{
"candidates": [
{"token": '="', "logProbability": -1.5617967e-05},
{"token": "='", "logProbability": -11.080961},
{"token": "=", "logProbability": -15.164277},
]
},
{
"candidates": [
{"token": "Tokyo", "logProbability": -3.0041514e-05},
{"token": "tokyo", "logProbability": -10.650261},
{"token": "Paris", "logProbability": -12.096886},
]
},
{
"candidates": [
{"token": '"))', "logProbability": -1.1922384e-07},
{"token": '",', "logProbability": -16.61921},
{"token": ",", "logProbability": -17.911102},
]
},
{
"candidates": [
{"token": "\n", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "print", "logProbability": -3.5760596e-07},
{"token": "```", "logProbability": -14.949171},
{"token": "p", "logProbability": -24.321035},
]
},
{
"candidates": [
{"token": "(", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "default", "logProbability": 0},
{"token": "de", "logProbability": -27.885206},
{"token": "def", "logProbability": -28.40597},
]
},
{
"candidates": [
{"token": "_", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "api", "logProbability": 0},
{"token": "ap", "logProbability": -25.905933},
{"token": "a", "logProbability": -30.408901},
]
},
{
"candidates": [
{"token": ".", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "get", "logProbability": 0},
{"token": "g", "logProbability": -22.274963},
{"token": "ge", "logProbability": -23.285828},
]
},
{
"candidates": [
{"token": "_", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "current", "logProbability": 0},
{"token": "cur", "logProbability": -28.442535},
{"token": "curr", "logProbability": -29.95087},
]
},
{
"candidates": [
{"token": "_", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "weather", "logProbability": 0},
{"token": "we", "logProbability": -27.307909},
{"token": "w", "logProbability": -31.076736},
]
},
{
"candidates": [
{"token": "(", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "location", "logProbability": 0},
{"token": "loc", "logProbability": -21.535915},
{"token": "lo", "logProbability": -23.028284},
]
},
{
"candidates": [
{"token": '="', "logProbability": -8.821511e-06},
{"token": "='", "logProbability": -11.700986},
{"token": "=", "logProbability": -14.50358},
]
},
{
"candidates": [
{"token": "Paris", "logProbability": 0},
{"token": "paris", "logProbability": -18.07075},
{"token": "Par", "logProbability": -21.911625},
]
},
{
"candidates": [
{"token": '"))', "logProbability": 0},
{"token": '")', "logProbability": -17.916853},
{"token": ",", "logProbability": -18.318272},
]
},
{
"candidates": [
{"token": "\n", "logProbability": 0},
{"token": "ont", "logProbability": -1.2676506e30},
{"token": " п", "logProbability": -1.2676506e30},
]
},
{
"candidates": [
{"token": "```", "logProbability": -3.5763796e-06},
{"token": "print", "logProbability": -12.535343},
{"token": "``", "logProbability": -19.670813},
]
},
],
"chosenCandidates": [
{"token": "```", "logProbability": -1.5496514e-06},
{"token": "tool", "logProbability": 0},
{"token": "_", "logProbability": 0},
{"token": "code", "logProbability": 0},
{"token": "\n", "logProbability": 0},
{"token": "print", "logProbability": 0},
{"token": "(", "logProbability": 0},
{"token": "default", "logProbability": 0},
{"token": "_", "logProbability": 0},
{"token": "api", "logProbability": 0},
{"token": ".", "logProbability": 0},
{"token": "get", "logProbability": 0},
{"token": "_", "logProbability": 0},
{"token": "current", "logProbability": 0},
{"token": "_", "logProbability": 0},
{"token": "weather", "logProbability": 0},
{"token": "(", "logProbability": 0},
{"token": "location", "logProbability": 0},
{"token": '="', "logProbability": -0.034490786},
{"token": "San", "logProbability": -6.5561944e-06},
{"token": " Francisco", "logProbability": -3.5760596e-07},
{"token": '"))', "logProbability": -6.079254e-06},
{"token": "\n", "logProbability": 0},
{"token": "print", "logProbability": -0.04140338},
{"token": "(", "logProbability": 0},
{"token": "default", "logProbability": 0},
{"token": "_", "logProbability": 0},
{"token": "api", "logProbability": 0},
{"token": ".", "logProbability": 0},
{"token": "get", "logProbability": 0},
{"token": "_", "logProbability": 0},
{"token": "current", "logProbability": 0},
{"token": "_", "logProbability": 0},
{"token": "weather", "logProbability": 0},
{"token": "(", "logProbability": 0},
{"token": "location", "logProbability": 0},
{"token": '="', "logProbability": -1.5617967e-05},
{"token": "Tokyo", "logProbability": -3.0041514e-05},
{"token": '"))', "logProbability": -1.1922384e-07},
{"token": "\n", "logProbability": 0},
{"token": "print", "logProbability": -3.5760596e-07},
{"token": "(", "logProbability": 0},
{"token": "default", "logProbability": 0},
{"token": "_", "logProbability": 0},
{"token": "api", "logProbability": 0},
{"token": ".", "logProbability": 0},
{"token": "get", "logProbability": 0},
{"token": "_", "logProbability": 0},
{"token": "current", "logProbability": 0},
{"token": "_", "logProbability": 0},
{"token": "weather", "logProbability": 0},
{"token": "(", "logProbability": 0},
{"token": "location", "logProbability": 0},
{"token": '="', "logProbability": -8.821511e-06},
{"token": "Paris", "logProbability": 0},
{"token": '"))', "logProbability": 0},
{"token": "\n", "logProbability": 0},
{"token": "```", "logProbability": -3.5763796e-06},
],
}
)
print(result)

def test_logprobs():
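    # End-to-end mock: the Gemini response includes a logprobsResult with
    # chosenCandidates, and the test checks that logprobs=True yields a non-None
    # choices[0].logprobs on the completion response.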
litellm.set_verbose = True
from litellm.llms.custom_httpx.http_handler import HTTPHandler
client = HTTPHandler()
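    # mocked Gemini REST response carrying a logprobsResult block alongside the text candidate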
response_body = {
"candidates": [
{
"content": {
"parts": [
{
"text": "I do not have access to real-time information, including current weather conditions. To get the current weather in San Francisco, I recommend checking a reliable weather website or app such as Google Weather, AccuWeather, or the National Weather Service.\n"
}
],
"role": "model",
},
"finishReason": "STOP",
"avgLogprobs": -0.04666396617889404,
"logprobsResult": {
"chosenCandidates": [
{"token": "I", "logProbability": -1.08472495e-05},
{"token": " do", "logProbability": -0.00012611414},
{"token": " not", "logProbability": 0},
{"token": " have", "logProbability": 0},
{"token": " access", "logProbability": -0.0008849616},
{"token": " to", "logProbability": 0},
{"token": " real", "logProbability": -1.1922384e-07},
{"token": "-", "logProbability": 0},
{"token": "time", "logProbability": 0},
{"token": " information", "logProbability": -2.2409657e-05},
{"token": ",", "logProbability": 0},
{"token": " including", "logProbability": 0},
{"token": " current", "logProbability": -0.14274147},
{"token": " weather", "logProbability": 0},
{"token": " conditions", "logProbability": -0.0056300927},
{"token": ".", "logProbability": -3.5760596e-07},
{"token": " ", "logProbability": -0.06392521},
{"token": "To", "logProbability": -2.3844768e-07},
{"token": " get", "logProbability": -0.058974747},
{"token": " the", "logProbability": 0},
{"token": " current", "logProbability": 0},
{"token": " weather", "logProbability": -2.3844768e-07},
{"token": " in", "logProbability": -2.3844768e-07},
{"token": " San", "logProbability": 0},
{"token": " Francisco", "logProbability": 0},
{"token": ",", "logProbability": 0},
{"token": " I", "logProbability": -0.6188003},
{"token": " recommend", "logProbability": -1.0370523e-05},
{"token": " checking", "logProbability": -0.00014005086},
{"token": " a", "logProbability": 0},
{"token": " reliable", "logProbability": -1.5496514e-06},
{"token": " weather", "logProbability": -8.344534e-07},
{"token": " website", "logProbability": -0.0078000566},
{"token": " or", "logProbability": -1.1922384e-07},
{"token": " app", "logProbability": 0},
{"token": " such", "logProbability": -0.9289338},
{"token": " as", "logProbability": 0},
{"token": " Google", "logProbability": -0.0046935496},
{"token": " Weather", "logProbability": 0},
{"token": ",", "logProbability": 0},
{"token": " Accu", "logProbability": 0},
{"token": "Weather", "logProbability": -0.00013909786},
{"token": ",", "logProbability": 0},
{"token": " or", "logProbability": -0.31303275},
{"token": " the", "logProbability": -0.17583296},
{"token": " National", "logProbability": -0.010806266},
{"token": " Weather", "logProbability": 0},
{"token": " Service", "logProbability": 0},
{"token": ".", "logProbability": -0.00068947335},
{"token": "\n", "logProbability": 0},
]
},
}
],
"usageMetadata": {
"promptTokenCount": 11,
"candidatesTokenCount": 50,
"totalTokenCount": 61,
},
"modelVersion": "gemini-1.5-flash-002",
}
mock_response = MagicMock()
mock_response.json.return_value = response_body
with patch.object(client, "post", return_value=mock_response):
resp = litellm.completion(
model="gemini/gemini-1.5-flash-002",
messages=[
{"role": "user", "content": "What's the weather like in San Francisco?"}
],
logprobs=True,
client=client,
)
print(resp)
assert resp.choices[0].logprobs is not None


@@ -30,7 +30,7 @@ from litellm import (
completion_cost,
embedding,
)
-from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
_gemini_convert_messages_with_history,
)
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
@@ -1823,6 +1823,7 @@ async def test_gemini_pro_function_calling_streaming(sync_mode):
@pytest.mark.flaky(retries=3, delay=1)
async def test_gemini_pro_async_function_calling():
load_vertex_ai_credentials()
litellm.set_verbose = True
try:
tools = [
{
@@ -2925,7 +2926,7 @@ def test_gemini_function_call_parameter_in_messages():
def test_gemini_function_call_parameter_in_messages_2():
-from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
_gemini_convert_messages_with_history,
)


@@ -1879,13 +1879,16 @@ def test_bedrock_completion_test_4(modify_params):
{
"role": "assistant",
"content": [
{
"text": """<thinking>\nThe user is asking about a specific file: main.py. Based on the environment details provided, this file is located in the computer-vision/hm-open3d/src/ directory and is currently open in a VSCode tab.\n\nTo answer the question of what this file is, the most relevant tool would be the read_file tool. This will allow me to examine the contents of main.py to determine its purpose.\n\nThe read_file tool requires the "path" parameter. I can infer this path based on the environment details:\npath: "computer-vision/hm-open3d/src/main.py"\n\nSince I have the necessary parameter, I can proceed with calling the read_file tool.\n</thinking>"""
},
{
"toolUse": {
"input": {"path": "computer-vision/hm-open3d/src/main.py"},
"name": "read_file",
"toolUseId": "tooluse_qCt-KEyWQlWiyHl26spQVA",
}
}
},
],
},
{


@@ -473,9 +473,15 @@ def test_anthropic_function_call_with_no_schema(model):
completion(model=model, messages=messages, tools=tools, tool_choice="auto")
def test_passing_tool_result_as_list():
@pytest.mark.parametrize(
"model",
[
"anthropic/claude-3-5-sonnet-20241022",
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
],
)
def test_passing_tool_result_as_list(model):
litellm.set_verbose = True
model = "anthropic/claude-3-5-sonnet-20241022"
messages = [
{
"content": [
@@ -611,4 +617,5 @@ def test_passing_tool_result_as_list():
resp = completion(model=model, messages=messages, tools=tools)
print(resp)
if model == "claude-3-5-sonnet-20241022":
assert resp.usage.prompt_tokens_details.cached_tokens > 0


@@ -142,6 +142,7 @@ def prisma_client():
@pytest.mark.asyncio()
@pytest.mark.flaky(retries=6, delay=1)
async def test_new_user_response(prisma_client):
try:
@@ -2891,6 +2892,7 @@ async def test_generate_key_with_guardrails(prisma_client):
@pytest.mark.asyncio()
@pytest.mark.flaky(retries=6, delay=1)
async def test_team_access_groups(prisma_client):
"""
Test team based model access groups


@@ -0,0 +1,243 @@
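# NOTE: the whole file below is commented out. It is a manual memory-profiling
# harness (memory_profiler) that replays a mocked streaming completion in a loop
# to watch for leaks, and appears to be kept for reference rather than run in CI.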
# import io
# import os
# import sys
# sys.path.insert(0, os.path.abspath("../.."))
# import litellm
# from memory_profiler import profile
# from litellm.utils import (
# ModelResponseIterator,
# ModelResponseListIterator,
# CustomStreamWrapper,
# )
# from litellm.types.utils import ModelResponse, Choices, Message
# import time
# import pytest
# # @app.post("/debug")
# # async def debug(body: ExampleRequest) -> str:
# # return await main_logic(body.query)
# def model_response_list_factory():
# chunks = [
# {
# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj",
# "choices": [
# {
# "delta": {"content": "", "role": "assistant"},
# "finish_reason": None,
# "index": 0,
# }
# ],
# "created": 1716563849,
# "model": "gpt-4o-2024-05-13",
# "object": "chat.completion.chunk",
# "system_fingerprint": "fp_5f4bad809a",
# },
# {
# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj",
# "choices": [
# {"delta": {"content": "This"}, "finish_reason": None, "index": 0}
# ],
# "created": 1716563849,
# "model": "gpt-4o-2024-05-13",
# "object": "chat.completion.chunk",
# "system_fingerprint": "fp_5f4bad809a",
# },
# {
# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj",
# "choices": [
# {"delta": {"content": " is"}, "finish_reason": None, "index": 0}
# ],
# "created": 1716563849,
# "model": "gpt-4o-2024-05-13",
# "object": "chat.completion.chunk",
# "system_fingerprint": "fp_5f4bad809a",
# },
# {
# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj",
# "choices": [
# {"delta": {"content": " a"}, "finish_reason": None, "index": 0}
# ],
# "created": 1716563849,
# "model": "gpt-4o-2024-05-13",
# "object": "chat.completion.chunk",
# "system_fingerprint": "fp_5f4bad809a",
# },
# {
# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj",
# "choices": [
# {"delta": {"content": " dummy"}, "finish_reason": None, "index": 0}
# ],
# "created": 1716563849,
# "model": "gpt-4o-2024-05-13",
# "object": "chat.completion.chunk",
# "system_fingerprint": "fp_5f4bad809a",
# },
# {
# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj",
# "choices": [
# {
# "delta": {"content": " response"},
# "finish_reason": None,
# "index": 0,
# }
# ],
# "created": 1716563849,
# "model": "gpt-4o-2024-05-13",
# "object": "chat.completion.chunk",
# "system_fingerprint": "fp_5f4bad809a",
# },
# {
# "id": "",
# "choices": [
# {
# "finish_reason": None,
# "index": 0,
# "content_filter_offsets": {
# "check_offset": 35159,
# "start_offset": 35159,
# "end_offset": 36150,
# },
# "content_filter_results": {
# "hate": {"filtered": False, "severity": "safe"},
# "self_harm": {"filtered": False, "severity": "safe"},
# "sexual": {"filtered": False, "severity": "safe"},
# "violence": {"filtered": False, "severity": "safe"},
# },
# }
# ],
# "created": 0,
# "model": "",
# "object": "",
# },
# {
# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj",
# "choices": [{"delta": {"content": "."}, "finish_reason": None, "index": 0}],
# "created": 1716563849,
# "model": "gpt-4o-2024-05-13",
# "object": "chat.completion.chunk",
# "system_fingerprint": "fp_5f4bad809a",
# },
# {
# "id": "chatcmpl-9SQxdH5hODqkWyJopWlaVOOUnFwlj",
# "choices": [{"delta": {}, "finish_reason": "stop", "index": 0}],
# "created": 1716563849,
# "model": "gpt-4o-2024-05-13",
# "object": "chat.completion.chunk",
# "system_fingerprint": "fp_5f4bad809a",
# },
# {
# "id": "",
# "choices": [
# {
# "finish_reason": None,
# "index": 0,
# "content_filter_offsets": {
# "check_offset": 36150,
# "start_offset": 36060,
# "end_offset": 37029,
# },
# "content_filter_results": {
# "hate": {"filtered": False, "severity": "safe"},
# "self_harm": {"filtered": False, "severity": "safe"},
# "sexual": {"filtered": False, "severity": "safe"},
# "violence": {"filtered": False, "severity": "safe"},
# },
# }
# ],
# "created": 0,
# "model": "",
# "object": "",
# },
# ]
# chunk_list = []
# for chunk in chunks:
# new_chunk = litellm.ModelResponse(stream=True, id=chunk["id"])
# if "choices" in chunk and isinstance(chunk["choices"], list):
# new_choices = []
# for choice in chunk["choices"]:
# if isinstance(choice, litellm.utils.StreamingChoices):
# _new_choice = choice
# elif isinstance(choice, dict):
# _new_choice = litellm.utils.StreamingChoices(**choice)
# new_choices.append(_new_choice)
# new_chunk.choices = new_choices
# chunk_list.append(new_chunk)
# return ModelResponseListIterator(model_responses=chunk_list)
# async def mock_completion(*args, **kwargs):
# completion_stream = model_response_list_factory()
# return litellm.CustomStreamWrapper(
# completion_stream=completion_stream,
# model="gpt-4-0613",
# custom_llm_provider="cached_response",
# logging_obj=litellm.Logging(
# model="gpt-4-0613",
# messages=[{"role": "user", "content": "Hey"}],
# stream=True,
# call_type="completion",
# start_time=time.time(),
# litellm_call_id="12345",
# function_id="1245",
# ),
# )
# @profile
# async def main_logic() -> str:
# stream = await mock_completion()
# result = ""
# async for chunk in stream:
# result += chunk.choices[0].delta.content or ""
# return result
# import asyncio
# for _ in range(100):
# asyncio.run(main_logic())
# # @pytest.mark.asyncio
# # def test_memory_profile(capsys):
# # # Run the async function
# # result = asyncio.run(main_logic())
# # # Verify the result
# # assert result == "This is a dummy response."
# # # Capture the output
# # captured = capsys.readouterr()
# # # Print memory output for debugging
# # print("Memory Profiler Output:")
# # print(f"captured out: {captured.out}")
# # # Basic memory leak checks
# # for idx, line in enumerate(captured.out.split("\n")):
# # if idx % 2 == 0 and "MiB" in line:
# # print(f"line: {line}")
# # # mem_lines = [line for line in captured.out.split("\n") if "MiB" in line]
# # print(mem_lines)
# # # Ensure we have some memory lines
# # assert len(mem_lines) > 0, "No memory profiler output found"
# # # Optional: Add more specific memory leak detection
# # for line in mem_lines:
# # # Extract memory increment
# # parts = line.split()
# # if len(parts) >= 3:
# # try:
# # mem_increment = float(parts[2].replace("MiB", ""))
# # # Assert that memory increment is below a reasonable threshold
# # assert mem_increment < 1.0, f"Potential memory leak detected: {line}"
# # except (ValueError, IndexError):
# # pass # Skip lines that don't match expected format


@@ -25,7 +25,7 @@ from litellm.llms.prompt_templates.factory import (
from litellm.llms.prompt_templates.common_utils import (
get_completion_messages,
)
-from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
_gemini_convert_messages_with_history,
)
from unittest.mock import AsyncMock, MagicMock, patch


@@ -306,6 +306,8 @@ async def test_auth_with_allowed_routes(route, should_raise_error):
("/key/delete", "internal_user", True),
("/key/generate", "internal_user", True),
("/key/82akk800000000jjsk/regenerate", "internal_user", True),
# Internal User Viewer
("/key/generate", "internal_user_viewer", False),
# Internal User checks - disallowed routes
("/organization/member_add", "internal_user", False),
],
@@ -340,3 +342,41 @@ def test_is_ui_route_allowed(route, user_role, expected_result):
pass
else:
raise e
@pytest.mark.parametrize(
"route, user_role, expected_result",
[
("/key/generate", "internal_user_viewer", False),
],
)
def test_is_api_route_allowed(route, user_role, expected_result):
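    # Each case asserts that _is_api_route_allowed returns the expected result for
    # the given route/user_role pair; when access is expected to be denied, the
    # raised exception is treated as a pass in the except branch below.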
from litellm.proxy.auth.user_api_key_auth import _is_api_route_allowed
from litellm.proxy._types import LiteLLM_UserTable
user_obj = LiteLLM_UserTable(
user_id="3b803c0e-666e-4e99-bd5c-6e534c07e297",
max_budget=None,
spend=0.0,
model_max_budget={},
model_spend={},
user_email="my-test-email@1234.com",
models=[],
tpm_limit=None,
rpm_limit=None,
user_role=user_role,
organization_memberships=[],
)
received_args: dict = {
"route": route,
"user_obj": user_obj,
}
try:
assert _is_api_route_allowed(**received_args) == expected_result
except Exception as e:
# If expected result is False, we expect an error
if expected_result is False:
pass
else:
raise e