LiteLLM Minor Fixes and Improvements (09/12/2024) (#5658)

* fix(factory.py): handle tool call content as list

Fixes https://github.com/BerriAI/litellm/issues/5652

* fix(factory.py): enforce stronger typing

* fix(router.py): return model alias in /v1/model/info and /v1/model_group/info

* fix(user_api_key_auth.py): move noisy warning message to debug

cleanup logs

* fix(types.py): cleanup pydantic v2 deprecated param

Fixes https://github.com/BerriAI/litellm/issues/5649

* docs(gemini.md): show how to pass inline data to gemini api

Fixes https://github.com/BerriAI/litellm/issues/5674
This commit is contained in:
Krish Dholakia 2024-09-12 23:04:06 -07:00 committed by GitHub
parent 795047c37f
commit 4657a40ef1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 324 additions and 41 deletions

View file

@ -1,12 +1,12 @@
repos:
- repo: local
hooks:
# - id: mypy
# name: mypy
# entry: python3 -m mypy --ignore-missing-imports
# language: system
# types: [python]
# files: ^litellm/
- id: mypy
name: mypy
entry: python3 -m mypy --ignore-missing-imports
language: system
types: [python]
files: ^litellm/
- id: isort
name: isort
entry: isort

View file

@ -708,6 +708,131 @@ response = await client.chat.completions.create(
</TabItem>
</Tabs>
## Usage - PDF / Videos / etc. Files
### Inline Data (e.g. audio stream)
LiteLLM follows the OpenAI format and accepts sending inline data as an encoded base64 string.
The format to follow is
```python
data:<mime_type>;base64,<encoded_data>
```
** LITELLM CALL **
```python
import litellm
from pathlib import Path
import base64
import os
os.environ["GEMINI_API_KEY"] = ""
litellm.set_verbose = True # 👈 See Raw call
audio_bytes = Path("speech_vertex.mp3").read_bytes()
encoded_data = base64.b64encode(audio_bytes).decode("utf-8")
print("Audio Bytes = {}".format(audio_bytes))
model = "gemini/gemini-1.5-flash"
response = litellm.completion(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Please summarize the audio."},
{
"type": "image_url",
"image_url": "data:audio/mp3;base64,{}".format(encoded_data), # 👈 SET MIME_TYPE + DATA
},
],
}
],
)
```
** Equivalent GOOGLE API CALL **
```python
# Initialize a Gemini model appropriate for your use case.
model = genai.GenerativeModel('models/gemini-1.5-flash')
# Create the prompt.
prompt = "Please summarize the audio."
# Load the samplesmall.mp3 file into a Python Blob object containing the audio
# file's bytes and then pass the prompt and the audio to Gemini.
response = model.generate_content([
prompt,
{
"mime_type": "audio/mp3",
"data": pathlib.Path('samplesmall.mp3').read_bytes()
}
])
# Output Gemini's response to the prompt and the inline audio.
print(response.text)
```
### https:// file
```python
import litellm
import os
os.environ["GEMINI_API_KEY"] = ""
litellm.set_verbose = True # 👈 See Raw call
model = "gemini/gemini-1.5-flash"
response = litellm.completion(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Please summarize the file."},
{
"type": "image_url",
"image_url": "https://storage..." # 👈 SET THE IMG URL
},
],
}
],
)
```
### gs:// file
```python
import litellm
import os
os.environ["GEMINI_API_KEY"] = ""
litellm.set_verbose = True # 👈 See Raw call
model = "gemini/gemini-1.5-flash"
response = litellm.completion(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Please summarize the file."},
{
"type": "image_url",
"image_url": "gs://..." # 👈 SET THE cloud storage bucket url
},
],
}
],
)
```
## Chat Models
:::tip

View file

@ -1131,7 +1131,14 @@ def convert_to_gemini_tool_call_result(
"content": "function result goes here",
}
"""
content = message.get("content", "")
content_str: str = ""
if isinstance(message["content"], str):
content_str = message["content"]
elif isinstance(message["content"], List):
content_list = message["content"]
for content in content_list:
if content["type"] == "text":
content_str += content["text"]
name: Optional[str] = message.get("name", "") # type: ignore
# Recover name from last message with tool calls
@ -1156,10 +1163,10 @@ def convert_to_gemini_tool_call_result(
# We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template
inferred_content_value = infer_protocol_value(value=content)
inferred_content_value = infer_protocol_value(value=content_str)
_field = litellm.types.llms.vertex_ai.Field(
key="content", value={inferred_content_value: content}
key="content", value={inferred_content_value: content_str}
)
_function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
@ -1174,7 +1181,7 @@ def convert_to_gemini_tool_call_result(
def convert_to_anthropic_tool_result(
message: Union[dict, ChatCompletionToolMessage, ChatCompletionFunctionMessage]
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage]
) -> AnthropicMessagesToolResultParam:
"""
OpenAI message with a tool result looks like:
@ -1207,21 +1214,29 @@ def convert_to_anthropic_tool_result(
]
}
"""
content_str: str = ""
if isinstance(message["content"], str):
content_str = message["content"]
elif isinstance(message["content"], List):
content_list = message["content"]
for content in content_list:
if content["type"] == "text":
content_str += content["text"]
if message["role"] == "tool":
tool_call_id: str = message.get("tool_call_id") # type: ignore
content: str = message.get("content") # type: ignore
tool_message: ChatCompletionToolMessage = message
tool_call_id: str = tool_message["tool_call_id"]
# We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template
anthropic_tool_result = AnthropicMessagesToolResultParam(
type="tool_result", tool_use_id=tool_call_id, content=content
type="tool_result", tool_use_id=tool_call_id, content=content_str
)
return anthropic_tool_result
if message["role"] == "function":
content = message.get("content") # type: ignore
tool_call_id = message.get("tool_call_id") or str(uuid.uuid4()) # type: ignore
function_message: ChatCompletionFunctionMessage = message
tool_call_id = function_message.get("tool_call_id") or str(uuid.uuid4())
anthropic_tool_result = AnthropicMessagesToolResultParam(
type="tool_result", tool_use_id=tool_call_id, content=content
type="tool_result", tool_use_id=tool_call_id, content=content_str
)
return anthropic_tool_result
@ -1624,7 +1639,8 @@ from litellm.types.llms.cohere import (
def convert_openai_message_to_cohere_tool_result(
message, tool_calls: List
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
tool_calls: List,
) -> ToolResultObject:
"""
OpenAI message with a tool result looks like:
@ -1660,7 +1676,15 @@ def convert_openai_message_to_cohere_tool_result(
]
},
"""
content_str: str = message.get("content", "")
content_str: str = ""
if isinstance(message["content"], str):
content_str = message["content"]
elif isinstance(message["content"], List):
content_list = message["content"]
for content in content_list:
if content["type"] == "text":
content_str += content["text"]
if len(content_str) > 0:
try:
content = json.loads(content_str)
@ -1687,7 +1711,8 @@ def convert_openai_message_to_cohere_tool_result(
arguments = json.loads(arguments_str)
if message["role"] == "function":
name = message.get("name")
function_message: ChatCompletionFunctionMessage = message
name = function_message["name"]
cohere_tool_result: ToolResultObject = {
"call": CallObject(name=name, parameters=arguments),
"outputs": [content],
@ -2292,7 +2317,7 @@ def _convert_to_bedrock_tool_call_invoke(
def _convert_to_bedrock_tool_call_result(
message: dict,
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage]
) -> BedrockContentBlock:
"""
OpenAI message with a tool result looks like:
@ -2334,11 +2359,18 @@ def _convert_to_bedrock_tool_call_result(
"""
-
"""
content = message.get("content", "")
content_str: str = ""
if isinstance(message["content"], str):
content_str = message["content"]
elif isinstance(message["content"], List):
content_list = message["content"]
for content in content_list:
if content["type"] == "text":
content_str += content["text"]
name = message.get("name", "")
id = message.get("tool_call_id", str(uuid.uuid4()))
id = str(message.get("tool_call_id", str(uuid.uuid4())))
tool_result_content_block = BedrockToolResultContentBlock(text=content)
tool_result_content_block = BedrockToolResultContentBlock(text=content_str)
tool_result = BedrockToolResultBlock(
content=[tool_result_content_block],
toolUseId=id,

View file

@ -160,7 +160,11 @@ def _gemini_convert_messages_with_history(
_part = PartType(text=element["text"]) # type: ignore
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"] # type: ignore
img_element: ChatCompletionImageObject = element # type: ignore
if isinstance(img_element["image_url"], dict):
image_url = img_element["image_url"]["url"]
else:
image_url = img_element["image_url"]
_part = _process_gemini_image(image_url=image_url)
_parts.append(_part) # type: ignore
user_content.extend(_parts)

View file

@ -496,7 +496,7 @@ async def get_team_object(
if check_cache_only:
raise Exception(
f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}. Create team via `/team/new` call."
f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
)
# else, check db

View file

@ -568,7 +568,9 @@ async def user_api_key_auth(
if field_name in valid_token.__fields__:
setattr(valid_token, field_name, v)
except Exception as e:
verbose_logger.warning(e)
verbose_logger.debug(
e
) # moving from .warning to .debug as it spams logs when team missing from cache.
try:
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore

View file

@ -3021,7 +3021,7 @@ async def startup_event():
@router.get(
"/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
) # if project requires model list
def model_list(
async def model_list(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""

View file

@ -25,9 +25,6 @@ class RerankResponse(BaseModel):
meta: dict # Contains api_version and billed_units
_hidden_params: dict = {}
class Config:
underscore_attrs_are_private = True
def __getitem__(self, key):
return self.__dict__[key]

View file

@ -4648,13 +4648,12 @@ class Router:
Used for accurate 'get_model_list'.
"""
returned_models: List[DeploymentTypedDict] = []
for model in self.model_list:
if model["model_name"] == model_name:
if model_alias is not None:
alias_model = copy.deepcopy(model)
alias_model["model_name"] = model_name
alias_model["model_name"] = model_alias
returned_models.append(alias_model)
else:
returned_models.append(model)

View file

@ -5,6 +5,8 @@ import traceback
from dotenv import load_dotenv
import litellm.types
load_dotenv()
import io
import os
@ -1232,3 +1234,56 @@ def test_bedrock_cross_region_inference():
max_tokens=10,
temperature=0.1,
)
from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
def test_bedrock_converse_translation_tool_message():
from litellm.types.utils import ChatCompletionMessageToolCall, Function
litellm.set_verbose = True
messages = [
{
"role": "user",
"content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
},
{
"tool_call_id": "tooluse_DnqEmD5qR6y2-aJ-Xd05xw",
"role": "tool",
"name": "get_current_weather",
"content": [
{
"text": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}',
"type": "text",
}
],
},
]
translated_msg = _bedrock_converse_messages_pt(
messages=messages, model="", llm_provider=""
)
print(translated_msg)
assert translated_msg == [
{
"role": "user",
"content": [
{
"text": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"
},
{
"toolResult": {
"content": [
{
"text": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'
}
],
"toolUseId": "tooluse_DnqEmD5qR6y2-aJ-Xd05xw",
}
},
],
}
]

View file

@ -48,8 +48,8 @@ def get_current_weather(location, unit="fahrenheit"):
# "gpt-3.5-turbo-1106",
# "mistral/mistral-large-latest",
# "claude-3-haiku-20240307",
"gemini/gemini-1.5-pro",
# "anthropic.claude-3-sonnet-20240229-v1:0",
# "gemini/gemini-1.5-pro",
"anthropic.claude-3-sonnet-20240229-v1:0",
],
)
@pytest.mark.flaky(retries=3, delay=1)

View file

@ -1432,3 +1432,72 @@ async def test_gemini_pass_through_endpoint():
)
print(resp.body)
@pytest.mark.asyncio
async def test_proxy_model_group_alias_checks(prisma_client):
"""
Check if model group alias is returned on
`/v1/models`
`/v1/model/info`
`/v1/model_group/info`
"""
import json
from fastapi import HTTPException, Request, Response
from starlette.datastructures import URL
from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
_model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
}
]
model_alias = "gpt-4"
router = litellm.Router(
model_list=_model_list,
model_group_alias={model_alias: "gpt-3.5-turbo"},
)
setattr(litellm.proxy.proxy_server, "llm_router", router)
setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)
request = Request(scope={"type": "http", "method": "POST", "headers": {}})
request._url = URL(url="/v1/models")
resp = await model_list(
user_api_key_dict=UserAPIKeyAuth(models=[]),
)
assert len(resp) == 2
print(resp)
resp = await model_info_v1(
user_api_key_dict=UserAPIKeyAuth(models=[]),
)
models = resp["data"]
is_model_alias_in_list = False
for item in models:
if model_alias == item["model_name"]:
is_model_alias_in_list = True
assert is_model_alias_in_list
resp = await model_group_info(
user_api_key_dict=UserAPIKeyAuth(models=[]),
)
models = resp["data"]
is_model_alias_in_list = False
for item in models:
if model_alias == item.model_group:
is_model_alias_in_list = True
assert is_model_alias_in_list

View file

@ -1,7 +1,6 @@
from typing import List, Optional, Union, Iterable
from typing import Iterable, List, Optional, Union
from pydantic import BaseModel, ConfigDict, validator
from typing_extensions import Literal, Required, TypedDict
@ -94,7 +93,7 @@ class Function(TypedDict, total=False):
class ChatCompletionToolMessageParam(TypedDict, total=False):
content: Required[str]
content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]]
"""The contents of the tool message."""
role: Required[Literal["tool"]]
@ -105,7 +104,7 @@ class ChatCompletionToolMessageParam(TypedDict, total=False):
class ChatCompletionFunctionMessageParam(TypedDict, total=False):
content: Required[Optional[str]]
content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]]
"""The contents of the function message."""
name: Required[str]

View file

@ -340,7 +340,7 @@ class ChatCompletionImageUrlObject(TypedDict, total=False):
class ChatCompletionImageObject(TypedDict):
type: Literal["image_url"]
image_url: ChatCompletionImageUrlObject
image_url: Union[str, ChatCompletionImageUrlObject]
class OpenAIChatCompletionUserMessage(TypedDict):
@ -368,14 +368,15 @@ class ChatCompletionAssistantMessage(OpenAIChatCompletionAssistantMessage, total
class ChatCompletionToolMessage(TypedDict):
role: Literal["tool"]
content: str
content: Union[str, Iterable[ChatCompletionTextObject]]
tool_call_id: str
class ChatCompletionFunctionMessage(TypedDict):
role: Literal["function"]
content: Optional[str]
content: Optional[Union[str, Iterable[ChatCompletionTextObject]]]
name: str
tool_call_id: Optional[str]
class OpenAIChatCompletionSystemMessage(TypedDict, total=False):