Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-26 11:14:04 +00:00

LiteLLM Minor Fixes and Improvements (09/12/2024) (#5658)

* fix(factory.py): handle tool call content as list
  Fixes https://github.com/BerriAI/litellm/issues/5652
* fix(factory.py): enforce stronger typing
* fix(router.py): return model alias in /v1/model/info and /v1/model_group/info
* fix(user_api_key_auth.py): move noisy warning message to debug; cleanup logs
* fix(types.py): cleanup pydantic v2 deprecated param
  Fixes https://github.com/BerriAI/litellm/issues/5649
* docs(gemini.md): show how to pass inline data to gemini api
  Fixes https://github.com/BerriAI/litellm/issues/5674

This commit is contained in:
parent 749bac2053
commit 91c918fd70

14 changed files with 324 additions and 41 deletions
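The headline fix is in factory.py: tool and function result messages may now carry OpenAI-style list content instead of a plain string, and each provider converter in the diff below flattens the text parts into one string before building its tool result. A minimal sketch of that flattening, using a hypothetical `flatten_tool_content` helper (not a function in this repo):

```python
from typing import List, Union


def flatten_tool_content(content: Union[str, List[dict]]) -> str:
    # Mirrors the logic this commit adds to the converters:
    # accept a plain string, or concatenate the "text" parts of a list.
    if isinstance(content, str):
        return content
    text = ""
    for part in content:
        if part.get("type") == "text":
            text += part["text"]
    return text


# A tool result in the list form that previously broke the converters
# (the Bedrock test added further down uses the same shape).
tool_message = {
    "role": "tool",
    "tool_call_id": "tooluse_DnqEmD5qR6y2-aJ-Xd05xw",
    "content": [
        {"type": "text", "text": '{"location": "San Francisco", "temperature": "72"}'},
    ],
}
print(flatten_tool_content(tool_message["content"]))
```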
@@ -1,12 +1,12 @@
 repos:
 - repo: local
   hooks:
-    # - id: mypy
-    #   name: mypy
-    #   entry: python3 -m mypy --ignore-missing-imports
-    #   language: system
-    #   types: [python]
-    #   files: ^litellm/
+    - id: mypy
+      name: mypy
+      entry: python3 -m mypy --ignore-missing-imports
+      language: system
+      types: [python]
+      files: ^litellm/
     - id: isort
       name: isort
       entry: isort
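This re-enables the previously commented-out mypy hook alongside isort. Assuming `pre-commit` is installed, the new hook can be exercised locally with `pre-commit run mypy --all-files`, or `pre-commit run --all-files` to run isort as well.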
@@ -708,6 +708,131 @@ response = await client.chat.completions.create(
 </TabItem>
 </Tabs>

+## Usage - PDF / Videos / etc. Files
+
+### Inline Data (e.g. audio stream)
+
+LiteLLM follows the OpenAI format and accepts sending inline data as an encoded base64 string.
+
+The format to follow is
+
+```python
+data:<mime_type>;base64,<encoded_data>
+```
+
+** LITELLM CALL **
+
+```python
+import litellm
+from pathlib import Path
+import base64
+import os
+
+os.environ["GEMINI_API_KEY"] = ""
+
+litellm.set_verbose = True  # 👈 See Raw call
+
+audio_bytes = Path("speech_vertex.mp3").read_bytes()
+encoded_data = base64.b64encode(audio_bytes).decode("utf-8")
+print("Audio Bytes = {}".format(audio_bytes))
+model = "gemini/gemini-1.5-flash"
+response = litellm.completion(
+    model=model,
+    messages=[
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Please summarize the audio."},
+                {
+                    "type": "image_url",
+                    "image_url": "data:audio/mp3;base64,{}".format(encoded_data),  # 👈 SET MIME_TYPE + DATA
+                },
+            ],
+        }
+    ],
+)
+```
+
+** Equivalent GOOGLE API CALL **
+
+```python
+# Initialize a Gemini model appropriate for your use case.
+model = genai.GenerativeModel('models/gemini-1.5-flash')
+
+# Create the prompt.
+prompt = "Please summarize the audio."
+
+# Load the samplesmall.mp3 file into a Python Blob object containing the audio
+# file's bytes and then pass the prompt and the audio to Gemini.
+response = model.generate_content([
+    prompt,
+    {
+        "mime_type": "audio/mp3",
+        "data": pathlib.Path('samplesmall.mp3').read_bytes()
+    }
+])
+
+# Output Gemini's response to the prompt and the inline audio.
+print(response.text)
+```
+
+### https:// file
+
+```python
+import litellm
+import os
+
+os.environ["GEMINI_API_KEY"] = ""
+
+litellm.set_verbose = True  # 👈 See Raw call
+
+model = "gemini/gemini-1.5-flash"
+response = litellm.completion(
+    model=model,
+    messages=[
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Please summarize the file."},
+                {
+                    "type": "image_url",
+                    "image_url": "https://storage..."  # 👈 SET THE IMG URL
+                },
+            ],
+        }
+    ],
+)
+```
+
+### gs:// file
+
+```python
+import litellm
+import os
+
+os.environ["GEMINI_API_KEY"] = ""
+
+litellm.set_verbose = True  # 👈 See Raw call
+
+model = "gemini/gemini-1.5-flash"
+response = litellm.completion(
+    model=model,
+    messages=[
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Please summarize the file."},
+                {
+                    "type": "image_url",
+                    "image_url": "gs://..."  # 👈 SET THE cloud storage bucket url
+                },
+            ],
+        }
+    ],
+)
+```
+
 ## Chat Models
 :::tip
@@ -1131,7 +1131,14 @@ def convert_to_gemini_tool_call_result(
         "content": "function result goes here",
     }
     """
-    content = message.get("content", "")
+    content_str: str = ""
+    if isinstance(message["content"], str):
+        content_str = message["content"]
+    elif isinstance(message["content"], List):
+        content_list = message["content"]
+        for content in content_list:
+            if content["type"] == "text":
+                content_str += content["text"]
     name: Optional[str] = message.get("name", "")  # type: ignore

     # Recover name from last message with tool calls
@@ -1156,10 +1163,10 @@ def convert_to_gemini_tool_call_result(

     # We can't determine from openai message format whether it's a successful or
     # error call result so default to the successful result template
-    inferred_content_value = infer_protocol_value(value=content)
+    inferred_content_value = infer_protocol_value(value=content_str)

     _field = litellm.types.llms.vertex_ai.Field(
-        key="content", value={inferred_content_value: content}
+        key="content", value={inferred_content_value: content_str}
     )

     _function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
@@ -1174,7 +1181,7 @@ def convert_to_gemini_tool_call_result(


 def convert_to_anthropic_tool_result(
-    message: Union[dict, ChatCompletionToolMessage, ChatCompletionFunctionMessage]
+    message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage]
 ) -> AnthropicMessagesToolResultParam:
     """
     OpenAI message with a tool result looks like:
@@ -1207,21 +1214,29 @@ def convert_to_anthropic_tool_result(
         ]
     }
     """
+    content_str: str = ""
+    if isinstance(message["content"], str):
+        content_str = message["content"]
+    elif isinstance(message["content"], List):
+        content_list = message["content"]
+        for content in content_list:
+            if content["type"] == "text":
+                content_str += content["text"]
     if message["role"] == "tool":
-        tool_call_id: str = message.get("tool_call_id")  # type: ignore
-        content: str = message.get("content")  # type: ignore
+        tool_message: ChatCompletionToolMessage = message
+        tool_call_id: str = tool_message["tool_call_id"]

         # We can't determine from openai message format whether it's a successful or
         # error call result so default to the successful result template
         anthropic_tool_result = AnthropicMessagesToolResultParam(
-            type="tool_result", tool_use_id=tool_call_id, content=content
+            type="tool_result", tool_use_id=tool_call_id, content=content_str
         )
         return anthropic_tool_result
     if message["role"] == "function":
-        content = message.get("content")  # type: ignore
-        tool_call_id = message.get("tool_call_id") or str(uuid.uuid4())  # type: ignore
+        function_message: ChatCompletionFunctionMessage = message
+        tool_call_id = function_message.get("tool_call_id") or str(uuid.uuid4())
         anthropic_tool_result = AnthropicMessagesToolResultParam(
-            type="tool_result", tool_use_id=tool_call_id, content=content
+            type="tool_result", tool_use_id=tool_call_id, content=content_str
         )

         return anthropic_tool_result
@@ -1624,7 +1639,8 @@ from litellm.types.llms.cohere import (


 def convert_openai_message_to_cohere_tool_result(
-    message, tool_calls: List
+    message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
+    tool_calls: List,
 ) -> ToolResultObject:
     """
     OpenAI message with a tool result looks like:
@@ -1660,7 +1676,15 @@ def convert_openai_message_to_cohere_tool_result(
         ]
     },
     """
-    content_str: str = message.get("content", "")
+
+    content_str: str = ""
+    if isinstance(message["content"], str):
+        content_str = message["content"]
+    elif isinstance(message["content"], List):
+        content_list = message["content"]
+        for content in content_list:
+            if content["type"] == "text":
+                content_str += content["text"]
     if len(content_str) > 0:
         try:
             content = json.loads(content_str)
@@ -1687,7 +1711,8 @@ def convert_openai_message_to_cohere_tool_result(
             arguments = json.loads(arguments_str)

     if message["role"] == "function":
-        name = message.get("name")
+        function_message: ChatCompletionFunctionMessage = message
+        name = function_message["name"]
         cohere_tool_result: ToolResultObject = {
             "call": CallObject(name=name, parameters=arguments),
             "outputs": [content],
@@ -2292,7 +2317,7 @@ def _convert_to_bedrock_tool_call_invoke(


 def _convert_to_bedrock_tool_call_result(
-    message: dict,
+    message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage]
 ) -> BedrockContentBlock:
     """
     OpenAI message with a tool result looks like:
@@ -2334,11 +2359,18 @@ def _convert_to_bedrock_tool_call_result(
     """
     -
     """
-    content = message.get("content", "")
+    content_str: str = ""
+    if isinstance(message["content"], str):
+        content_str = message["content"]
+    elif isinstance(message["content"], List):
+        content_list = message["content"]
+        for content in content_list:
+            if content["type"] == "text":
+                content_str += content["text"]
     name = message.get("name", "")
-    id = message.get("tool_call_id", str(uuid.uuid4()))
+    id = str(message.get("tool_call_id", str(uuid.uuid4())))

-    tool_result_content_block = BedrockToolResultContentBlock(text=content)
+    tool_result_content_block = BedrockToolResultContentBlock(text=content_str)
     tool_result = BedrockToolResultBlock(
         content=[tool_result_content_block],
         toolUseId=id,
@@ -160,7 +160,11 @@ def _gemini_convert_messages_with_history(
                     _part = PartType(text=element["text"])  # type: ignore
                     _parts.append(_part)
                 elif element["type"] == "image_url":
-                    image_url = element["image_url"]["url"]  # type: ignore
+                    img_element: ChatCompletionImageObject = element  # type: ignore
+                    if isinstance(img_element["image_url"], dict):
+                        image_url = img_element["image_url"]["url"]
+                    else:
+                        image_url = img_element["image_url"]
                     _part = _process_gemini_image(image_url=image_url)
                     _parts.append(_part)  # type: ignore
             user_content.extend(_parts)
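Together with the widened `ChatCompletionImageObject` type at the end of this diff, `image_url` may now be either a bare string or an object with a `url` key; previously only the dict form was handled here. A minimal sketch of the two shapes, assuming a valid `GEMINI_API_KEY` and a hypothetical image URL:

```python
import os

import litellm

os.environ["GEMINI_API_KEY"] = ""  # set your key

# Both forms are accepted after this change; the old code only
# handled the {"url": ...} dict form for Gemini.
string_form = {"type": "image_url", "image_url": "https://example.com/cat.png"}
dict_form = {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}

for img in (string_form, dict_form):
    response = litellm.completion(
        model="gemini/gemini-1.5-flash",
        messages=[
            {
                "role": "user",
                "content": [{"type": "text", "text": "Describe this image."}, img],
            }
        ],
    )
    print(response.choices[0].message.content)
```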
@@ -496,7 +496,7 @@ async def get_team_object(

     if check_cache_only:
         raise Exception(
-            f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}. Create team via `/team/new` call."
+            f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
         )

     # else, check db
@@ -568,7 +568,9 @@ async def user_api_key_auth(
                    if field_name in valid_token.__fields__:
                        setattr(valid_token, field_name, v)
            except Exception as e:
-                verbose_logger.warning(e)
+                verbose_logger.debug(
+                    e
+                )  # moving from .warning to .debug as it spams logs when team missing from cache.

        try:
            is_master_key_valid = secrets.compare_digest(api_key, master_key)  # type: ignore
@@ -3021,7 +3021,7 @@ async def startup_event():
 @router.get(
     "/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
 )  # if project requires model list
-def model_list(
+async def model_list(
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     """
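`model_list` is now a coroutine, so any code that calls the handler directly (outside FastAPI's routing) must await it; the new proxy test further down does exactly that. A minimal sketch, assuming the proxy globals (`llm_router`, `llm_model_list`) are already configured and that `UserAPIKeyAuth` is importable from `litellm.proxy._types`:

```python
import asyncio

from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.proxy_server import model_list


async def show_models() -> None:
    # Direct callers must now await the handler; FastAPI awaits it
    # automatically when serving /models.
    resp = await model_list(user_api_key_dict=UserAPIKeyAuth(models=[]))
    print(resp)


asyncio.run(show_models())
```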
@@ -25,9 +25,6 @@ class RerankResponse(BaseModel):
     meta: dict  # Contains api_version and billed_units
     _hidden_params: dict = {}

-    class Config:
-        underscore_attrs_are_private = True
-
     def __getitem__(self, key):
         return self.__dict__[key]
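`underscore_attrs_are_private` is a pydantic v1 config option that no longer exists in pydantic v2, where underscore-prefixed attributes are private by default; this is the `fix(types.py): cleanup pydantic v2 deprecated param` item from the commit message. A minimal sketch of the v2 idiom (illustrative only, not code from the repo):

```python
from pydantic import BaseModel, PrivateAttr


class RerankResponseSketch(BaseModel):
    meta: dict  # Contains api_version and billed_units
    # pydantic v2: no `class Config: underscore_attrs_are_private = True` needed.
    # A leading underscore (optionally declared with PrivateAttr) stays private.
    _hidden_params: dict = PrivateAttr(default_factory=dict)
```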
@@ -4648,13 +4648,12 @@ class Router:

         Used for accurate 'get_model_list'.
         """
-
         returned_models: List[DeploymentTypedDict] = []
         for model in self.model_list:
             if model["model_name"] == model_name:
                 if model_alias is not None:
                     alias_model = copy.deepcopy(model)
-                    alias_model["model_name"] = model_name
+                    alias_model["model_name"] = model_alias
                     returned_models.append(alias_model)
                 else:
                     returned_models.append(model)
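Before this change the alias copy had its `model_name` overwritten with the original group name, so `model_group_alias` entries never showed up in `/v1/models`, `/v1/model/info`, or `/v1/model_group/info`; the fix keeps the alias name. A minimal sketch of the setup the new proxy test (further down) asserts against:

```python
import litellm

# One real deployment, plus an alias "gpt-4" that routes to it.
router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    model_group_alias={"gpt-4": "gpt-3.5-turbo"},
)

# After this fix the alias is reported under its own name, so the
# proxy's model listing endpoints include both entries.
print(router.get_model_list())
```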
@@ -5,6 +5,8 @@ import traceback

 from dotenv import load_dotenv

+import litellm.types
+
 load_dotenv()
 import io
 import os
@@ -1232,3 +1234,56 @@ def test_bedrock_cross_region_inference():
         max_tokens=10,
         temperature=0.1,
     )
+
+
+from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
+
+
+def test_bedrock_converse_translation_tool_message():
+    from litellm.types.utils import ChatCompletionMessageToolCall, Function
+
+    litellm.set_verbose = True
+
+    messages = [
+        {
+            "role": "user",
+            "content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
+        },
+        {
+            "tool_call_id": "tooluse_DnqEmD5qR6y2-aJ-Xd05xw",
+            "role": "tool",
+            "name": "get_current_weather",
+            "content": [
+                {
+                    "text": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}',
+                    "type": "text",
+                }
+            ],
+        },
+    ]
+
+    translated_msg = _bedrock_converse_messages_pt(
+        messages=messages, model="", llm_provider=""
+    )
+
+    print(translated_msg)
+    assert translated_msg == [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "text": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"
+                },
+                {
+                    "toolResult": {
+                        "content": [
+                            {
+                                "text": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'
+                            }
+                        ],
+                        "toolUseId": "tooluse_DnqEmD5qR6y2-aJ-Xd05xw",
+                    }
+                },
+            ],
+        }
+    ]
@@ -48,8 +48,8 @@ def get_current_weather(location, unit="fahrenheit"):
         # "gpt-3.5-turbo-1106",
         # "mistral/mistral-large-latest",
         # "claude-3-haiku-20240307",
-        "gemini/gemini-1.5-pro",
-        # "anthropic.claude-3-sonnet-20240229-v1:0",
+        # "gemini/gemini-1.5-pro",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
     ],
 )
 @pytest.mark.flaky(retries=3, delay=1)
@@ -1432,3 +1432,72 @@ async def test_gemini_pass_through_endpoint():
     )

     print(resp.body)
+
+
+@pytest.mark.asyncio
+async def test_proxy_model_group_alias_checks(prisma_client):
+    """
+    Check if model group alias is returned on
+
+    `/v1/models`
+    `/v1/model/info`
+    `/v1/model_group/info`
+    """
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
+
+    _model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "gpt-3.5-turbo"},
+        }
+    ]
+    model_alias = "gpt-4"
+    router = litellm.Router(
+        model_list=_model_list,
+        model_group_alias={model_alias: "gpt-3.5-turbo"},
+    )
+    setattr(litellm.proxy.proxy_server, "llm_router", router)
+    setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)
+
+    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+    request._url = URL(url="/v1/models")
+
+    resp = await model_list(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+
+    assert len(resp) == 2
+    print(resp)
+
+    resp = await model_info_v1(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+    models = resp["data"]
+    is_model_alias_in_list = False
+    for item in models:
+        if model_alias == item["model_name"]:
+            is_model_alias_in_list = True
+
+    assert is_model_alias_in_list
+
+    resp = await model_group_info(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+    models = resp["data"]
+    is_model_alias_in_list = False
+    for item in models:
+        if model_alias == item.model_group:
+            is_model_alias_in_list = True
+
+    assert is_model_alias_in_list
@@ -1,7 +1,6 @@
-from typing import List, Optional, Union, Iterable
+from typing import Iterable, List, Optional, Union

 from pydantic import BaseModel, ConfigDict, validator
-
 from typing_extensions import Literal, Required, TypedDict


@@ -94,7 +93,7 @@ class Function(TypedDict, total=False):


 class ChatCompletionToolMessageParam(TypedDict, total=False):
-    content: Required[str]
+    content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]]
     """The contents of the tool message."""

     role: Required[Literal["tool"]]
@@ -105,7 +104,7 @@ class ChatCompletionToolMessageParam(TypedDict, total=False):


 class ChatCompletionFunctionMessageParam(TypedDict, total=False):
-    content: Required[Optional[str]]
+    content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]]
     """The contents of the function message."""

     name: Required[str]
@@ -340,7 +340,7 @@ class ChatCompletionImageUrlObject(TypedDict, total=False):

 class ChatCompletionImageObject(TypedDict):
     type: Literal["image_url"]
-    image_url: ChatCompletionImageUrlObject
+    image_url: Union[str, ChatCompletionImageUrlObject]


 class OpenAIChatCompletionUserMessage(TypedDict):
@@ -368,14 +368,15 @@ class ChatCompletionAssistantMessage(OpenAIChatCompletionAssistantMessage, total

 class ChatCompletionToolMessage(TypedDict):
     role: Literal["tool"]
-    content: str
+    content: Union[str, Iterable[ChatCompletionTextObject]]
     tool_call_id: str


 class ChatCompletionFunctionMessage(TypedDict):
     role: Literal["function"]
-    content: Optional[str]
+    content: Optional[Union[str, Iterable[ChatCompletionTextObject]]]
     name: str
+    tool_call_id: Optional[str]


 class OpenAIChatCompletionSystemMessage(TypedDict, total=False):