LiteLLM Minor Fixes and Improvements (09/12/2024) (#5658)

* fix(factory.py): handle tool call content as list

Fixes https://github.com/BerriAI/litellm/issues/5652
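
For context, this is the message shape that previously broke the prompt factories: a tool-result message whose `content` is a list of text parts rather than a plain string. A minimal sketch of the now-supported shape, mirroring the new Bedrock test added in this commit (the tool-call id and weather payload are illustrative):

```python
from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt

messages = [
    {"role": "user", "content": "What's the weather like in San Francisco?"},
    {
        "role": "tool",
        "tool_call_id": "tooluse_example-id",  # illustrative id
        "name": "get_current_weather",
        "content": [  # list-form content; the factory now concatenates the text parts
            {
                "type": "text",
                "text": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}',
            }
        ],
    },
]

# The factory flattens the text parts before building the provider-specific tool result block.
translated = _bedrock_converse_messages_pt(messages=messages, model="", llm_provider="")
print(translated)
```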

* fix(factory.py): enforce stronger typing

* fix(router.py): return model alias in /v1/model/info and /v1/model_group/info
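
A minimal sketch of the alias setup this touches, taken from the new proxy test added below (model names are illustrative): "gpt-4" is a model group alias pointing at the "gpt-3.5-turbo" deployment, and `/v1/models`, `/v1/model/info`, and `/v1/model_group/info` now report the alias under its own name instead of echoing the underlying group.

```python
import litellm

# Router with a model group alias: requests for "gpt-4" are served by the
# "gpt-3.5-turbo" deployment. The proxy's model-info endpoints now list
# "gpt-4" as its own entry rather than the underlying model name.
router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    model_group_alias={"gpt-4": "gpt-3.5-turbo"},
)
```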

* fix(user_api_key_auth.py): move noisy warning message to debug

cleanup logs

* fix(types.py): cleanup pydantic v2 deprecated param

Fixes https://github.com/BerriAI/litellm/issues/5649
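
The deprecated parameter is Pydantic v1's `underscore_attrs_are_private` config option, which Pydantic v2 no longer accepts; v2 already treats underscore-prefixed attributes as private, so the inner `Config` class on `RerankResponse` can simply be dropped. A rough before/after sketch, with the field list abbreviated to what the diff below shows:

```python
from pydantic import BaseModel


# Before (Pydantic v1 style) - rejected by Pydantic v2:
#
# class RerankResponse(BaseModel):
#     meta: dict
#     _hidden_params: dict = {}
#
#     class Config:
#         underscore_attrs_are_private = True


# After - Pydantic v2 treats underscore-prefixed attributes as private by default,
# so no Config class is needed.
class RerankResponse(BaseModel):
    meta: dict  # Contains api_version and billed_units
    _hidden_params: dict = {}
```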

* docs(gemini.md): show how to pass inline data to gemini api

Fixes https://github.com/BerriAI/litellm/issues/5674
Krish Dholakia, 2024-09-12 23:04:06 -07:00, committed by GitHub
parent 749bac2053
commit 91c918fd70
14 changed files with 324 additions and 41 deletions


@@ -1,12 +1,12 @@
 repos:
   - repo: local
     hooks:
-      # - id: mypy
-      #   name: mypy
-      #   entry: python3 -m mypy --ignore-missing-imports
-      #   language: system
-      #   types: [python]
-      #   files: ^litellm/
+      - id: mypy
+        name: mypy
+        entry: python3 -m mypy --ignore-missing-imports
+        language: system
+        types: [python]
+        files: ^litellm/
       - id: isort
         name: isort
         entry: isort


@@ -708,6 +708,131 @@ response = await client.chat.completions.create(
</TabItem>
</Tabs>
## Usage - PDFs, Videos, and Other Files
### Inline Data (e.g. audio stream)
LiteLLM follows the OpenAI format and accepts inline data as a base64-encoded string.
The format to follow is:
```python
data:<mime_type>;base64,<encoded_data>
```
**LiteLLM Call**
```python
import litellm
from pathlib import Path
import base64
import os
os.environ["GEMINI_API_KEY"] = ""
litellm.set_verbose = True # 👈 See Raw call
audio_bytes = Path("speech_vertex.mp3").read_bytes()
encoded_data = base64.b64encode(audio_bytes).decode("utf-8")
print("Audio Bytes = {}".format(audio_bytes))
model = "gemini/gemini-1.5-flash"
response = litellm.completion(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Please summarize the audio."},
{
"type": "image_url",
"image_url": "data:audio/mp3;base64,{}".format(encoded_data), # 👈 SET MIME_TYPE + DATA
},
],
}
],
)
```
**Equivalent Google API Call**
```python
import os
import pathlib

import google.generativeai as genai

genai.configure(api_key=os.environ["GEMINI_API_KEY"])

# Initialize a Gemini model appropriate for your use case.
model = genai.GenerativeModel('models/gemini-1.5-flash')
# Create the prompt.
prompt = "Please summarize the audio."
# Load the samplesmall.mp3 file into a Python Blob object containing the audio
# file's bytes and then pass the prompt and the audio to Gemini.
response = model.generate_content([
prompt,
{
"mime_type": "audio/mp3",
"data": pathlib.Path('samplesmall.mp3').read_bytes()
}
])
# Output Gemini's response to the prompt and the inline audio.
print(response.text)
```
### https:// file
```python
import litellm
import os
os.environ["GEMINI_API_KEY"] = ""
litellm.set_verbose = True # 👈 See Raw call
model = "gemini/gemini-1.5-flash"
response = litellm.completion(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Please summarize the file."},
{
"type": "image_url",
"image_url": "https://storage..." # 👈 SET THE IMG URL
},
],
}
],
)
```
### gs:// file
```python
import litellm
import os
os.environ["GEMINI_API_KEY"] = ""
litellm.set_verbose = True # 👈 See Raw call
model = "gemini/gemini-1.5-flash"
response = litellm.completion(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Please summarize the file."},
{
"type": "image_url",
"image_url": "gs://..." # 👈 SET THE cloud storage bucket url
},
],
}
],
)
```
## Chat Models
:::tip


@@ -1131,7 +1131,14 @@ def convert_to_gemini_tool_call_result(
         "content": "function result goes here",
     }
     """
-    content = message.get("content", "")
+    content_str: str = ""
+    if isinstance(message["content"], str):
+        content_str = message["content"]
+    elif isinstance(message["content"], List):
+        content_list = message["content"]
+        for content in content_list:
+            if content["type"] == "text":
+                content_str += content["text"]
     name: Optional[str] = message.get("name", "")  # type: ignore
     # Recover name from last message with tool calls
@@ -1156,10 +1163,10 @@ def convert_to_gemini_tool_call_result(
     # We can't determine from openai message format whether it's a successful or
     # error call result so default to the successful result template
-    inferred_content_value = infer_protocol_value(value=content)
+    inferred_content_value = infer_protocol_value(value=content_str)
     _field = litellm.types.llms.vertex_ai.Field(
-        key="content", value={inferred_content_value: content}
+        key="content", value={inferred_content_value: content_str}
     )
     _function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
@@ -1174,7 +1181,7 @@ def convert_to_gemini_tool_call_result(
 def convert_to_anthropic_tool_result(
-    message: Union[dict, ChatCompletionToolMessage, ChatCompletionFunctionMessage]
+    message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage]
 ) -> AnthropicMessagesToolResultParam:
     """
     OpenAI message with a tool result looks like:
@@ -1207,21 +1214,29 @@ def convert_to_anthropic_tool_result(
         ]
     }
     """
+    content_str: str = ""
+    if isinstance(message["content"], str):
+        content_str = message["content"]
+    elif isinstance(message["content"], List):
+        content_list = message["content"]
+        for content in content_list:
+            if content["type"] == "text":
+                content_str += content["text"]
     if message["role"] == "tool":
-        tool_call_id: str = message.get("tool_call_id")  # type: ignore
-        content: str = message.get("content")  # type: ignore
+        tool_message: ChatCompletionToolMessage = message
+        tool_call_id: str = tool_message["tool_call_id"]
         # We can't determine from openai message format whether it's a successful or
         # error call result so default to the successful result template
         anthropic_tool_result = AnthropicMessagesToolResultParam(
-            type="tool_result", tool_use_id=tool_call_id, content=content
+            type="tool_result", tool_use_id=tool_call_id, content=content_str
         )
         return anthropic_tool_result
     if message["role"] == "function":
-        content = message.get("content")  # type: ignore
-        tool_call_id = message.get("tool_call_id") or str(uuid.uuid4())  # type: ignore
+        function_message: ChatCompletionFunctionMessage = message
+        tool_call_id = function_message.get("tool_call_id") or str(uuid.uuid4())
         anthropic_tool_result = AnthropicMessagesToolResultParam(
-            type="tool_result", tool_use_id=tool_call_id, content=content
+            type="tool_result", tool_use_id=tool_call_id, content=content_str
         )
         return anthropic_tool_result
@@ -1624,7 +1639,8 @@ from litellm.types.llms.cohere import (
 def convert_openai_message_to_cohere_tool_result(
-    message, tool_calls: List
+    message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
+    tool_calls: List,
 ) -> ToolResultObject:
     """
     OpenAI message with a tool result looks like:
@@ -1660,7 +1676,15 @@ def convert_openai_message_to_cohere_tool_result(
         ]
     },
     """
-    content_str: str = message.get("content", "")
+    content_str: str = ""
+    if isinstance(message["content"], str):
+        content_str = message["content"]
+    elif isinstance(message["content"], List):
+        content_list = message["content"]
+        for content in content_list:
+            if content["type"] == "text":
+                content_str += content["text"]
     if len(content_str) > 0:
         try:
             content = json.loads(content_str)
@@ -1687,7 +1711,8 @@ def convert_openai_message_to_cohere_tool_result(
         arguments = json.loads(arguments_str)
     if message["role"] == "function":
-        name = message.get("name")
+        function_message: ChatCompletionFunctionMessage = message
+        name = function_message["name"]
         cohere_tool_result: ToolResultObject = {
             "call": CallObject(name=name, parameters=arguments),
             "outputs": [content],
@@ -2292,7 +2317,7 @@ def _convert_to_bedrock_tool_call_invoke(
 def _convert_to_bedrock_tool_call_result(
-    message: dict,
+    message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage]
 ) -> BedrockContentBlock:
     """
     OpenAI message with a tool result looks like:
@@ -2334,11 +2359,18 @@ def _convert_to_bedrock_tool_call_result(
     """
     -
     """
-    content = message.get("content", "")
+    content_str: str = ""
+    if isinstance(message["content"], str):
+        content_str = message["content"]
+    elif isinstance(message["content"], List):
+        content_list = message["content"]
+        for content in content_list:
+            if content["type"] == "text":
+                content_str += content["text"]
     name = message.get("name", "")
-    id = message.get("tool_call_id", str(uuid.uuid4()))
+    id = str(message.get("tool_call_id", str(uuid.uuid4())))
-    tool_result_content_block = BedrockToolResultContentBlock(text=content)
+    tool_result_content_block = BedrockToolResultContentBlock(text=content_str)
     tool_result = BedrockToolResultBlock(
         content=[tool_result_content_block],
         toolUseId=id,


@@ -160,7 +160,11 @@ def _gemini_convert_messages_with_history(
                 _part = PartType(text=element["text"])  # type: ignore
                 _parts.append(_part)
             elif element["type"] == "image_url":
-                image_url = element["image_url"]["url"]  # type: ignore
+                img_element: ChatCompletionImageObject = element  # type: ignore
+                if isinstance(img_element["image_url"], dict):
+                    image_url = img_element["image_url"]["url"]
+                else:
+                    image_url = img_element["image_url"]
                 _part = _process_gemini_image(image_url=image_url)
                 _parts.append(_part)  # type: ignore
         user_content.extend(_parts)


@@ -496,7 +496,7 @@ async def get_team_object(
     if check_cache_only:
         raise Exception(
-            f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}. Create team via `/team/new` call."
+            f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
         )
     # else, check db


@@ -568,7 +568,9 @@ async def user_api_key_auth(
                 if field_name in valid_token.__fields__:
                     setattr(valid_token, field_name, v)
         except Exception as e:
-            verbose_logger.warning(e)
+            verbose_logger.debug(
+                e
+            )  # moving from .warning to .debug as it spams logs when team missing from cache.
         try:
             is_master_key_valid = secrets.compare_digest(api_key, master_key)  # type: ignore


@@ -3021,7 +3021,7 @@ async def startup_event():
 @router.get(
     "/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
 )  # if project requires model list
-def model_list(
+async def model_list(
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     """


@@ -25,9 +25,6 @@ class RerankResponse(BaseModel):
     meta: dict  # Contains api_version and billed_units
     _hidden_params: dict = {}
-    class Config:
-        underscore_attrs_are_private = True
     def __getitem__(self, key):
         return self.__dict__[key]


@@ -4648,13 +4648,12 @@ class Router:
         Used for accurate 'get_model_list'.
         """
         returned_models: List[DeploymentTypedDict] = []
         for model in self.model_list:
             if model["model_name"] == model_name:
                 if model_alias is not None:
                     alias_model = copy.deepcopy(model)
-                    alias_model["model_name"] = model_name
+                    alias_model["model_name"] = model_alias
                     returned_models.append(alias_model)
                 else:
                     returned_models.append(model)


@@ -5,6 +5,8 @@ import traceback
 from dotenv import load_dotenv
+
+import litellm.types
 load_dotenv()
 import io
 import os
@@ -1232,3 +1234,56 @@ def test_bedrock_cross_region_inference():
         max_tokens=10,
         temperature=0.1,
     )
+
+
+from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
+
+
+def test_bedrock_converse_translation_tool_message():
+    from litellm.types.utils import ChatCompletionMessageToolCall, Function
+
+    litellm.set_verbose = True
+
+    messages = [
+        {
+            "role": "user",
+            "content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
+        },
+        {
+            "tool_call_id": "tooluse_DnqEmD5qR6y2-aJ-Xd05xw",
+            "role": "tool",
+            "name": "get_current_weather",
+            "content": [
+                {
+                    "text": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}',
+                    "type": "text",
+                }
+            ],
+        },
+    ]
+
+    translated_msg = _bedrock_converse_messages_pt(
+        messages=messages, model="", llm_provider=""
+    )
+
+    print(translated_msg)
+
+    assert translated_msg == [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "text": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"
+                },
+                {
+                    "toolResult": {
+                        "content": [
+                            {
+                                "text": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'
+                            }
+                        ],
+                        "toolUseId": "tooluse_DnqEmD5qR6y2-aJ-Xd05xw",
+                    }
+                },
+            ],
+        }
+    ]


@@ -48,8 +48,8 @@ def get_current_weather(location, unit="fahrenheit"):
         # "gpt-3.5-turbo-1106",
         # "mistral/mistral-large-latest",
         # "claude-3-haiku-20240307",
-        "gemini/gemini-1.5-pro",
-        # "anthropic.claude-3-sonnet-20240229-v1:0",
+        # "gemini/gemini-1.5-pro",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
     ],
 )
 @pytest.mark.flaky(retries=3, delay=1)


@@ -1432,3 +1432,72 @@ async def test_gemini_pass_through_endpoint():
     )
     print(resp.body)
+
+
+@pytest.mark.asyncio
+async def test_proxy_model_group_alias_checks(prisma_client):
+    """
+    Check if model group alias is returned on
+
+    `/v1/models`
+    `/v1/model/info`
+    `/v1/model_group/info`
+    """
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
+
+    _model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "gpt-3.5-turbo"},
+        }
+    ]
+    model_alias = "gpt-4"
+    router = litellm.Router(
+        model_list=_model_list,
+        model_group_alias={model_alias: "gpt-3.5-turbo"},
+    )
+    setattr(litellm.proxy.proxy_server, "llm_router", router)
+    setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)
+
+    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+    request._url = URL(url="/v1/models")
+
+    resp = await model_list(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+
+    assert len(resp) == 2
+    print(resp)
+
+    resp = await model_info_v1(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+    models = resp["data"]
+    is_model_alias_in_list = False
+    for item in models:
+        if model_alias == item["model_name"]:
+            is_model_alias_in_list = True
+
+    assert is_model_alias_in_list
+
+    resp = await model_group_info(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+    models = resp["data"]
+    is_model_alias_in_list = False
+    for item in models:
+        if model_alias == item.model_group:
+            is_model_alias_in_list = True
+
+    assert is_model_alias_in_list


@@ -1,7 +1,6 @@
-from typing import List, Optional, Union, Iterable
+from typing import Iterable, List, Optional, Union
 from pydantic import BaseModel, ConfigDict, validator
 from typing_extensions import Literal, Required, TypedDict
@@ -94,7 +93,7 @@ class Function(TypedDict, total=False):
 class ChatCompletionToolMessageParam(TypedDict, total=False):
-    content: Required[str]
+    content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]]
     """The contents of the tool message."""
     role: Required[Literal["tool"]]
@@ -105,7 +104,7 @@ class ChatCompletionToolMessageParam(TypedDict, total=False):
 class ChatCompletionFunctionMessageParam(TypedDict, total=False):
-    content: Required[Optional[str]]
+    content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]]
     """The contents of the function message."""
     name: Required[str]


@@ -340,7 +340,7 @@ class ChatCompletionImageUrlObject(TypedDict, total=False):
 class ChatCompletionImageObject(TypedDict):
     type: Literal["image_url"]
-    image_url: ChatCompletionImageUrlObject
+    image_url: Union[str, ChatCompletionImageUrlObject]
 class OpenAIChatCompletionUserMessage(TypedDict):
@@ -368,14 +368,15 @@ class ChatCompletionAssistantMessage(OpenAIChatCompletionAssistantMessage, total
 class ChatCompletionToolMessage(TypedDict):
     role: Literal["tool"]
-    content: str
+    content: Union[str, Iterable[ChatCompletionTextObject]]
     tool_call_id: str
 class ChatCompletionFunctionMessage(TypedDict):
     role: Literal["function"]
-    content: Optional[str]
+    content: Optional[Union[str, Iterable[ChatCompletionTextObject]]]
     name: str
+    tool_call_id: Optional[str]
 class OpenAIChatCompletionSystemMessage(TypedDict, total=False):