[Feat] Allow setting supports_vision for Custom OpenAI endpoints + Added testing (#5821)

* add test for using images with custom openai endpoints

* run all otel tests

* update name of test

* add custom openai model to test config

* add test for setting supports_vision=True for model

* fix test guardrails aporia

* docs supports vision

* fix yaml

* fix yaml

* docs supports vision

* fix bedrock guardrail test

* fix cohere rerank test

* update model_group doc string

* add better prints on test
Author: Ishaan Jaff, 2024-09-21 11:35:55 -07:00 (committed by GitHub)
Parent: 4069942dd8
Commit: 1973ae8fb8
10 changed files with 477 additions and 39 deletions
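
For context on what this change enables from a client's point of view, here is a minimal sketch (not part of the commit) that queries the proxy's `/model/info` endpoint the same way the new `test_custom_model_supports_vision` test below does. It assumes a proxy running locally on port 4000 with the `sk-1234` master key and the `llava-hf` model entry from this commit's test config:

```python
import httpx

# Read /model/info from a locally running LiteLLM proxy (assumed at
# http://localhost:4000 with master key sk-1234, as in the test config below).
resp = httpx.get(
    "http://localhost:4000/model/info",
    headers={"Authorization": "Bearer sk-1234"},
)
resp.raise_for_status()

for model in resp.json()["data"]:
    if model["model_name"] == "llava-hf":
        # supports_vision is surfaced from the model_info block in config.yaml
        print("llava-hf supports_vision:", model["model_info"]["supports_vision"])
```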

View file

@@ -404,7 +404,7 @@ jobs:
       # Store test results
       - store_test_results:
           path: test-results
-  proxy_log_to_otel_tests:
+  proxy_logging_guardrails_model_info_tests:
     machine:
       image: ubuntu-2204:2023.10.1
     resource_class: xlarge
@@ -476,6 +476,7 @@ jobs:
            -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
            -e AWS_REGION_NAME=$AWS_REGION_NAME \
            -e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
+           -e COHERE_API_KEY=$COHERE_API_KEY \
            --name my-app \
            -v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
            -v $(pwd)/litellm/proxy/example_config_yaml/custom_guardrail.py:/app/custom_guardrail.py \
@@ -503,7 +504,7 @@ jobs:
          command: |
            pwd
            ls
-           python -m pytest -vv tests/otel_tests/test_otel.py -x --junitxml=test-results/junit.xml --durations=5
+           python -m pytest -vv tests/otel_tests -x --junitxml=test-results/junit.xml --durations=5
          no_output_timeout: 120m
       # Store test results
@@ -711,7 +712,7 @@ workflows:
            only:
              - main
              - /litellm_.*/
-      - proxy_log_to_otel_tests:
+      - proxy_logging_guardrails_model_info_tests:
          filters:
            branches:
              only:
@@ -751,7 +752,7 @@ workflows:
        - litellm_assistants_api_testing
        - ui_endpoint_testing
        - installing_litellm_on_python
-       - proxy_log_to_otel_tests
+       - proxy_logging_guardrails_model_info_tests
        - proxy_pass_through_endpoint_tests
        filters:
          branches:

View file

@@ -1,8 +1,16 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
 # Using Vision Models
 
 ## Quick Start
 Example passing images to a model
 
+<Tabs>
+<TabItem label="LiteLLM Python SDK" value="Python">
+
 ```python
 import os
 from litellm import completion
@@ -33,8 +41,80 @@ response = completion(
 ```
 
+</TabItem>
+<TabItem label="LiteLLM Proxy Server" value="proxy">
+
+1. Define vision models on config.yaml
+
+```yaml
+model_list:
+  - model_name: gpt-4-vision-preview # OpenAI gpt-4-vision-preview
+    litellm_params:
+      model: openai/gpt-4-vision-preview
+      api_key: os.environ/OPENAI_API_KEY
+  - model_name: llava-hf # Custom OpenAI compatible model
+    litellm_params:
+      model: openai/llava-hf/llava-v1.6-vicuna-7b-hf
+      api_base: http://localhost:8000
+      api_key: fake-key
+    model_info:
+      supports_vision: True # set supports_vision to True so /model/info returns this attribute as True
+```
+
+2. Run proxy server
+
+```bash
+litellm --config config.yaml
+```
+
+3. Test it using the OpenAI Python SDK
+
+```python
+import os
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="sk-1234",  # your litellm proxy api key
+)
+
+response = client.chat.completions.create(
+    model="gpt-4-vision-preview",  # use model="llava-hf" to test your custom OpenAI endpoint
+    messages=[
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What's in this image?"
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+                    }
+                }
+            ]
+        }
+    ],
+)
+```
+
+</TabItem>
+</Tabs>
+
 ## Checking if a model supports `vision`
 
+<Tabs>
+<TabItem label="LiteLLM Python SDK" value="Python">
+
 Use `litellm.supports_vision(model="")` -> returns `True` if model supports `vision` and `False` if not
 
 ```python
@@ -42,4 +122,69 @@ assert litellm.supports_vision(model="gpt-4-vision-preview") == True
 assert litellm.supports_vision(model="gemini-1.0-pro-vision") == True
 assert litellm.supports_vision(model="gpt-3.5-turbo") == False
 ```
+
+</TabItem>
+<TabItem label="LiteLLM Proxy Server" value="proxy">
+
+1. Define vision models on config.yaml
+
+```yaml
+model_list:
+  - model_name: gpt-4-vision-preview # OpenAI gpt-4-vision-preview
+    litellm_params:
+      model: openai/gpt-4-vision-preview
+      api_key: os.environ/OPENAI_API_KEY
+  - model_name: llava-hf # Custom OpenAI compatible model
+    litellm_params:
+      model: openai/llava-hf/llava-v1.6-vicuna-7b-hf
+      api_base: http://localhost:8000
+      api_key: fake-key
+    model_info:
+      supports_vision: True # set supports_vision to True so /model/info returns this attribute as True
+```
+
+2. Run proxy server
+
+```bash
+litellm --config config.yaml
+```
+
+3. Call `/model_group/info` to check if your model supports `vision`
+
+```shell
+curl -X 'GET' \
+  'http://localhost:4000/model_group/info' \
+  -H 'accept: application/json' \
+  -H 'x-api-key: sk-1234'
+```
+
+Expected Response
+
+```json
+{
+  "data": [
+    {
+      "model_group": "gpt-4-vision-preview",
+      "providers": ["openai"],
+      "max_input_tokens": 128000,
+      "max_output_tokens": 4096,
+      "mode": "chat",
+      "supports_vision": true, # 👈 supports_vision is true
+      "supports_function_calling": false
+    },
+    {
+      "model_group": "llava-hf",
+      "providers": ["openai"],
+      "max_input_tokens": null,
+      "max_output_tokens": null,
+      "mode": null,
+      "supports_vision": true, # 👈 supports_vision is true
+      "supports_function_calling": false
+    }
+  ]
+}
+```
+
+</TabItem>
+</Tabs>
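
For readers following the docs change above, a Python equivalent of the `curl` call to `/model_group/info` may be useful. This is a sketch assuming the same local proxy and `sk-1234` key used in the example:

```python
import httpx

# Same request as the curl example in the docs above, using the x-api-key header.
resp = httpx.get(
    "http://localhost:4000/model_group/info",
    headers={"accept": "application/json", "x-api-key": "sk-1234"},
)
resp.raise_for_status()

for group in resp.json()["data"]:
    # Mirrors the "Expected Response" shown in the docs diff above.
    print(group["model_group"], "supports_vision:", group["supports_vision"])
```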

View file

@@ -1236,7 +1236,7 @@
     },
     "deepseek-chat": {
         "max_tokens": 4096,
-        "max_input_tokens": 32000,
+        "max_input_tokens": 128000,
         "max_output_tokens": 4096,
         "input_cost_per_token": 0.00000014,
         "input_cost_per_token_cache_hit": 0.000000014,

View file

@@ -15,10 +15,17 @@ model_list:
       tags: ["teamB"]
       model_info:
         id: "team-b-model"
-  - model_name: rerank-english-v3.0 # Fixed indentation here
+  - model_name: rerank-english-v3.0
     litellm_params:
       model: cohere/rerank-english-v3.0
       api_key: os.environ/COHERE_API_KEY
+  - model_name: llava-hf
+    litellm_params:
+      model: openai/llava-hf/llava-v1.6-vicuna-7b-hf
+      api_base: http://localhost:8000
+      api_key: fake-key
+    model_info:
+      supports_vision: True
 
 litellm_settings:
@@ -41,7 +48,7 @@ guardrails:
   - guardrail_name: "bedrock-pre-guard"
     litellm_params:
       guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
-      mode: "pre_call"
+      mode: "during_call"
       guardrailIdentifier: ff6ujrregl1q
       guardrailVersion: "DRAFT"
   - guardrail_name: "custom-pre-guard"
@@ -56,3 +63,6 @@ guardrails:
     litellm_params:
       guardrail: custom_guardrail.myCustomGuardrail
       mode: "post_call"
+
+router_settings:
+  enable_tag_filtering: True # 👈 Key Change

View file

@@ -1,32 +1,57 @@
 model_list:
-  - model_name: gemini-vision
+  - model_name: gpt-3.5-turbo
     litellm_params:
-      model: vertex_ai/gemini-1.5-pro
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro-vision-001
-      vertex_project: "adroit-crow-413218"
-      vertex_location: "us-central1"
-      vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"
-  - model_name: gemini-vision
-    litellm_params:
-      model: vertex_ai/gemini-1.0-pro-vision-001
-      api_base: https://exampleopenaiendpoint-production-c715.up.railway.app/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro-vision-001
-      vertex_project: "adroit-crow-413218"
-      vertex_location: "us-central1"
-      vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"
+      model: openai/gpt-3.5-turbo
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      tags: ["teamB"]
+    model_info:
+      id: "team-b-model"
+  - model_name: rerank-english-v3.0
+    litellm_params:
+      model: cohere/rerank-english-v3.0
+      api_key: os.environ/COHERE_API_KEY
+  - model_name: llava-hf
+    litellm_params:
+      model: openai/llava-hf/llava-v1.6-vicuna-7b-hf
+      api_base: http://localhost:8000
+      api_key: fake-key
+    model_info:
+      supports_vision: True
+  - model_name: fake-azure-endpoint
+    litellm_params:
+      model: openai/429
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app
+
+general_settings:
+  master_key: sk-1234
+  default_team_disabled: true
+  custom_sso: custom_sso.custom_sso_handler
 
 litellm_settings:
-  success_callback: ["prometheus"]
+  cache: true
+  # callbacks: ["otel"]
+
+guardrails:
+  - guardrail_name: "aporia-pre-guard"
+    litellm_params:
+      guardrail: aporia  # supported values: "aporia", "bedrock", "lakera"
+      mode: "post_call"
+      api_key: os.environ/APORIA_API_KEY_1
+      api_base: os.environ/APORIA_API_BASE_1
+  - guardrail_name: "aporia-post-guard"
+    litellm_params:
+      guardrail: aporia  # supported values: "aporia", "bedrock", "lakera"
+      mode: "post_call"
+      api_key: os.environ/APORIA_API_KEY_2
+      api_base: os.environ/APORIA_API_BASE_2
+  - guardrail_name: "bedrock-pre-guard"
+    litellm_params:
+      guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
+      mode: "during_call"
+      guardrailIdentifier: ff6ujrregl1q
+      guardrailVersion: "DRAFT"
+  - guardrail_name: "custom-pre-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "pre_call"
+  - guardrail_name: "custom-during-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "during_call"
+  - guardrail_name: "custom-post-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "post_call"

View file

@@ -7595,7 +7595,6 @@ async def model_info_v1(
 @router.get(
     "/model_group/info",
-    description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
     tags=["model management"],
     dependencies=[Depends(user_api_key_auth)],
 )
@@ -7603,7 +7602,134 @@ async def model_group_info(
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     """
-    Returns model info at the model group level.
+    Get information about all the deployments on litellm proxy, including config.yaml descriptions (except api key and api base)
+
+    - /models returns all deployments. Proxy Admins can use this to list all deployments setup on the proxy
+    - /model_group/info returns all model groups. End users of proxy should use /model_group/info since those models will be used for /chat/completions, /embeddings, etc.
+
+    ```shell
+    curl -X 'GET' \
+        'http://localhost:4000/model_group/info' \
+        -H 'accept: application/json' \
+        -H 'x-api-key: sk-1234'
+    ```
+
+    Example Response:
+    ```json
+    {
+        "data": [
+            {
+                "model_group": "rerank-english-v3.0",
+                "providers": [
+                    "cohere"
+                ],
+                "max_input_tokens": null,
+                "max_output_tokens": null,
+                "input_cost_per_token": 0.0,
+                "output_cost_per_token": 0.0,
+                "mode": null,
+                "tpm": null,
+                "rpm": null,
+                "supports_parallel_function_calling": false,
+                "supports_vision": false,
+                "supports_function_calling": false,
+                "supported_openai_params": [
+                    "stream",
+                    "temperature",
+                    "max_tokens",
+                    "logit_bias",
+                    "top_p",
+                    "frequency_penalty",
+                    "presence_penalty",
+                    "stop",
+                    "n",
+                    "extra_headers"
+                ]
+            },
+            {
+                "model_group": "gpt-3.5-turbo",
+                "providers": [
+                    "openai"
+                ],
+                "max_input_tokens": 16385.0,
+                "max_output_tokens": 4096.0,
+                "input_cost_per_token": 1.5e-06,
+                "output_cost_per_token": 2e-06,
+                "mode": "chat",
+                "tpm": null,
+                "rpm": null,
+                "supports_parallel_function_calling": false,
+                "supports_vision": false,
+                "supports_function_calling": true,
+                "supported_openai_params": [
+                    "frequency_penalty",
+                    "logit_bias",
+                    "logprobs",
+                    "top_logprobs",
+                    "max_tokens",
+                    "max_completion_tokens",
+                    "n",
+                    "presence_penalty",
+                    "seed",
+                    "stop",
+                    "stream",
+                    "stream_options",
+                    "temperature",
+                    "top_p",
+                    "tools",
+                    "tool_choice",
+                    "function_call",
+                    "functions",
+                    "max_retries",
+                    "extra_headers",
+                    "parallel_tool_calls",
+                    "response_format"
+                ]
+            },
+            {
+                "model_group": "llava-hf",
+                "providers": [
+                    "openai"
+                ],
+                "max_input_tokens": null,
+                "max_output_tokens": null,
+                "input_cost_per_token": 0.0,
+                "output_cost_per_token": 0.0,
+                "mode": null,
+                "tpm": null,
+                "rpm": null,
+                "supports_parallel_function_calling": false,
+                "supports_vision": true,
+                "supports_function_calling": false,
+                "supported_openai_params": [
+                    "frequency_penalty",
+                    "logit_bias",
+                    "logprobs",
+                    "top_logprobs",
+                    "max_tokens",
+                    "max_completion_tokens",
+                    "n",
+                    "presence_penalty",
+                    "seed",
+                    "stop",
+                    "stream",
+                    "stream_options",
+                    "temperature",
+                    "top_p",
+                    "tools",
+                    "tool_choice",
+                    "function_call",
+                    "functions",
+                    "max_retries",
+                    "extra_headers",
+                    "parallel_tool_calls",
+                    "response_format"
+                ]
+            }
+        ]
+    }
+    ```
     """
     global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router

View file

@@ -0,0 +1,94 @@
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+import httpx
+import pytest
+from respx import MockRouter
+
+import litellm
+from litellm import Choices, Message, ModelResponse
+
+
+@pytest.mark.asyncio()
+@pytest.mark.respx
+async def test_vision_with_custom_model(respx_mock: MockRouter):
+    """
+    Tests that an OpenAI compatible endpoint when sent an image will receive the image in the request
+    """
+    import base64
+    import requests
+
+    litellm.set_verbose = True
+    api_base = "https://my-custom.api.openai.com"
+
+    # Fetch and encode a test image
+    url = "https://dummyimage.com/100/100/fff&text=Test+image"
+    response = requests.get(url)
+    file_data = response.content
+
+    encoded_file = base64.b64encode(file_data).decode("utf-8")
+    base64_image = f"data:image/png;base64,{encoded_file}"
+
+    mock_response = ModelResponse(
+        id="cmpl-mock",
+        choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
+        created=int(datetime.now().timestamp()),
+        model="my-custom-model",
+    )
+
+    mock_request = respx_mock.post(f"{api_base}/chat/completions").mock(
+        return_value=httpx.Response(200, json=mock_response.dict())
+    )
+
+    response = await litellm.acompletion(
+        model="openai/my-custom-model",
+        max_tokens=10,
+        api_base=api_base,  # use the mock api
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What's in this image?"},
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": base64_image},
+                    },
+                ],
+            }
+        ],
+    )
+
+    assert mock_request.called
+    request_body = json.loads(mock_request.calls[0].request.content)
+
+    print("request_body: ", request_body)
+
+    assert request_body == {
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What's in this image?"},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkBAMAAACCzIhnAAAAG1BMVEURAAD///+ln5/h39/Dv79qX18uHx+If39MPz9oMSdmAAAACXBIWXMAAA7EAAAOxAGVKw4bAAABB0lEQVRYhe2SzWrEIBCAh2A0jxEs4j6GLDS9hqWmV5Flt0cJS+lRwv742DXpEjY1kOZW6HwHFZnPmVEBEARBEARB/jd0KYA/bcUYbPrRLh6amXHJ/K+ypMoyUaGthILzw0l+xI0jsO7ZcmCcm4ILd+QuVYgpHOmDmz6jBeJImdcUCmeBqQpuqRIbVmQsLCrAalrGpfoEqEogqbLTWuXCPCo+Ki1XGqgQ+jVVuhB8bOaHkvmYuzm/b0KYLWwoK58oFqi6XfxQ4Uz7d6WeKpna6ytUs5e8betMcqAv5YPC5EZB2Lm9FIn0/VP6R58+/GEY1X1egVoZ/3bt/EqF6malgSAIgiDIH+QL41409QMY0LMAAAAASUVORK5CYII="
+                        },
+                    },
+                ],
+            }
+        ],
+        "model": "my-custom-model",
+        "max_tokens": 10,
+    }
+
+    print(f"response: {response}")
+
+    assert isinstance(response, ModelResponse)
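
Outside the mocked test above, the same request shape works against a real OpenAI-compatible vision endpoint. A minimal sketch, assuming a locally hosted llava server at `http://localhost:8000` matching the `llava-hf` entry in this commit's example configs:

```python
import litellm

# Direct SDK call to a custom OpenAI-compatible endpoint (no proxy), matching
# the llava-hf entry added to the example configs in this commit.
response = litellm.completion(
    model="openai/llava-hf/llava-v1.6-vicuna-7b-hf",
    api_base="http://localhost:8000",  # assumed local OpenAI-compatible vision server
    api_key="fake-key",
    max_tokens=10,
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    },
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)
```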

View file

@@ -70,6 +70,7 @@ async def generate_key(session, guardrails):
 @pytest.mark.asyncio
+@pytest.mark.skip(reason="Aporia account disabled")
 async def test_llm_guard_triggered_safe_request():
     """
     - Tests a request where no content mod is triggered
@@ -99,6 +100,7 @@ async def test_llm_guard_triggered_safe_request():
 @pytest.mark.asyncio
+@pytest.mark.skip(reason="Aporia account disabled")
 async def test_llm_guard_triggered():
     """
     - Tests a request where no content mod is triggered
@@ -146,6 +148,7 @@ async def test_no_llm_guard_triggered():
 @pytest.mark.asyncio
+@pytest.mark.skip(reason="Aporia account disabled")
 async def test_guardrails_with_api_key_controls():
     """
     - Make two API Keys

View file

@@ -0,0 +1,28 @@
+"""
+/model/info test
+"""
+
+import httpx
+import pytest
+
+
+@pytest.mark.asyncio()
+async def test_custom_model_supports_vision():
+    async with httpx.AsyncClient() as client:
+        response = await client.get(
+            "http://localhost:4000/model/info",
+            headers={"Authorization": "Bearer sk-1234"},
+        )
+        assert response.status_code == 200
+
+        data = response.json()["data"]
+
+        print("response from /model/info", data)
+
+        llava_model = next(
+            (model for model in data if model["model_name"] == "llava-hf"), None
+        )
+
+        assert llava_model is not None, "llava-hf model not found in response"
+        assert (
+            llava_model["model_info"]["supports_vision"] == True
+        ), "llava-hf model should support vision"

View file

@@ -18,6 +18,7 @@ async def chat_completion(
         "Authorization": f"Bearer {key}",
         "Content-Type": "application/json",
     }
+    print("headers=", headers)
     data = {
         "model": model,
         "messages": [
@@ -96,16 +97,21 @@ async def test_team_tag_routing():
     async with aiohttp.ClientSession() as session:
         key = LITELLM_MASTER_KEY
         team_a_data = await create_team_with_tags(session, key, ["teamA"])
+        print("team_a_data=", team_a_data)
         team_a_id = team_a_data["team_id"]
         team_b_data = await create_team_with_tags(session, key, ["teamB"])
+        print("team_b_data=", team_b_data)
         team_b_id = team_b_data["team_id"]
         key_with_team_a = await create_key_with_team(session, key, team_a_id)
-        print(key_with_team_a)
+        print("key_with_team_a=", key_with_team_a)
         _key_with_team_a = key_with_team_a["key"]
         for _ in range(5):
-            response_a, headers = await chat_completion(session, _key_with_team_a)
+            response_a, headers = await chat_completion(
+                session=session, key=_key_with_team_a
+            )
             headers = dict(headers)
             print(response_a)
             print(headers)