(Bug Fix) Using LiteLLM Python SDK with model=litellm_proxy/ for embedding, image_generation, transcription, speech, rerank (#8815)

* test_litellm_gateway_from_sdk

* fix embedding check for openai

* test litellm proxy provider

* fix image generation openai compatible models

* fix litellm.transcription

* test_litellm_gateway_from_sdk_rerank

* docs litellm python sdk

* docs litellm python sdk with proxy

* test_litellm_gateway_from_sdk_rerank

* ci/cd run again

* test_litellm_gateway_from_sdk_image_generation

* test_litellm_gateway_from_sdk_embedding

* test_litellm_gateway_from_sdk_embedding

Ishaan Jaff, 2025-02-25 16:22:37 -08:00 (committed by GitHub)
parent ef22209a15
commit f9cee4c46b
6 changed files with 466 additions and 83 deletions


@@ -3,13 +3,15 @@ import TabItem from '@theme/TabItem';
# LiteLLM Proxy (LLM Gateway)
-:::tip
-[LiteLLM Providers a **self hosted** proxy server (AI Gateway)](../simple_proxy) to call all the LLMs in the OpenAI format
+| Property | Details |
+|-------|-------|
+| Description | LiteLLM Proxy is an OpenAI-compatible gateway that allows you to interact with multiple LLM providers through a unified API. Simply use the `litellm_proxy/` prefix before the model name to route your requests through the proxy. |
+| Provider Route on LiteLLM | `litellm_proxy/` (add this prefix to the model name to route any request to litellm_proxy, e.g. `litellm_proxy/your-model-name`) |
+| Setup LiteLLM Gateway | [LiteLLM Gateway ↗](../simple_proxy) |
+| Supported Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/audio/speech`, `/audio/transcriptions`, `/images`, `/rerank` |
-:::
+**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**; you just need the `litellm_proxy/` prefix before the model name.
## Required Variables
@@ -83,7 +85,76 @@ for chunk in response:
    print(chunk)
```
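
For reference, the non-streaming form of the same call, as a minimal sketch; the model name, URL, and key are placeholders matching the sections below:

```python
import litellm

# Minimal non-streaming sketch; URL, key, and model name are placeholders.
response = litellm.completion(
    model="litellm_proxy/your-model-name",
    messages=[{"role": "user", "content": "Hello world"}],
    api_base="your-litellm-proxy-url",
    api_key="your-litellm-proxy-api-key"
)
print(response.choices[0].message.content)
```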
## Embeddings

```python
import litellm

response = litellm.embedding(
    model="litellm_proxy/your-embedding-model",
    input="Hello world",
    api_base="your-litellm-proxy-url",
    api_key="your-litellm-proxy-api-key"
)
```
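
The async variant follows the same shape; a sketch, assuming `aembedding` mirrors `embedding`'s signature (as exercised by this PR's tests) and an OpenAI-style response:

```python
import asyncio
import litellm

async def main():
    # Async sketch: same placeholder URL and key as above.
    response = await litellm.aembedding(
        model="litellm_proxy/your-embedding-model",
        input="Hello world",
        api_base="your-litellm-proxy-url",
        api_key="your-litellm-proxy-api-key"
    )
    # Assumes the OpenAI-style embedding response shape.
    print(response.data[0]["embedding"][:5])

asyncio.run(main())
```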
## Image Generation

```python
import litellm

response = litellm.image_generation(
    model="litellm_proxy/dall-e-3",
    prompt="A beautiful sunset over mountains",
    api_base="your-litellm-proxy-url",
    api_key="your-litellm-proxy-api-key"
)
```
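
The response follows OpenAI's images format, so (assuming a URL-returning backend) the generated image is one attribute away:

```python
# Continuing from the block above; assumes the backend returns URLs
# in the OpenAI images response format.
print(response.data[0].url)
```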
## Audio Transcription

```python
import litellm

response = litellm.transcription(
    model="litellm_proxy/whisper-1",
    file=open("your-audio-file.mp3", "rb"),  # pass a file handle, not a path string
    api_base="your-litellm-proxy-url",
    api_key="your-litellm-proxy-api-key"
)
```
## Text to Speech

```python
import litellm

response = litellm.speech(
    model="litellm_proxy/tts-1",
    input="Hello world",
    api_base="your-litellm-proxy-url",
    api_key="your-litellm-proxy-api-key"
)
```
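
`litellm.speech` returns an `HttpxBinaryResponseContent` (see the `speech()` hunk below), so the audio can be written straight to disk:

```python
# Save the synthesized audio; "speech.mp3" is an arbitrary output path.
response.stream_to_file("speech.mp3")
```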
## Rerank

```python
import litellm

response = litellm.rerank(
    model="litellm_proxy/rerank-english-v2.0",
    query="What is machine learning?",
    documents=[
        "Machine learning is a field of study in artificial intelligence",
        "Biology is the study of living organisms"
    ],
    api_base="your-litellm-proxy-url",
    api_key="your-litellm-proxy-api-key"
)
```
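
Results come back with an original-document index and a relevance score per entry; a sketch, with field names matching the mock response in this PR's new tests:

```python
# Iterate ranked results; each entry carries the original document index
# and a relevance_score, per the response shape used in the new tests.
for result in response.results:
    print(result["index"], result["relevance_score"])
```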
## **Usage with Langchain, LlamaIndex, OpenAI JS, Anthropic SDK, Instructor**

#### [Follow this doc to see how to use litellm proxy with langchain, llamaindex, anthropic etc](../proxy/user_keys)


@@ -112,6 +112,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
+            client=client,
        )
        ## LOGGING
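
That one-line fix threads a caller-supplied OpenAI client through to the transcription call instead of building a fresh one; the new tests rely on this. A hedged sketch of the calling pattern:

```python
from openai import OpenAI

import litellm

# Sketch: with the client= passthrough, litellm issues the request on this
# exact client object (the new tests patch its methods to verify that).
openai_client = OpenAI(api_key="fake-key")
response = litellm.transcription(
    model="litellm_proxy/whisper-1",
    file=open("audio.mp3", "rb"),  # placeholder audio file
    client=openai_client,
    api_base="my-custom-api-base",
)
```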


@@ -3409,6 +3409,7 @@ def embedding(  # noqa: PLR0915
        or custom_llm_provider == "openai"
        or custom_llm_provider == "together_ai"
        or custom_llm_provider == "nvidia_nim"
+        or custom_llm_provider == "litellm_proxy"
    ):
        api_base = (
            api_base

@@ -3485,7 +3486,8 @@ def embedding(  # noqa: PLR0915
        # set API KEY
        if api_key is None:
            api_key = (
-                litellm.api_key
+                api_key
+                or litellm.api_key
                or litellm.openai_like_key
                or get_secret_str("OPENAI_LIKE_API_KEY")
            )
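
The second hunk extends the usual fallback-chain idiom for key resolution. A self-contained sketch of the idiom (names are illustrative, not litellm's actual globals):

```python
import os

def resolve_api_key(call_key=None, module_key=None):
    # Each `or` falls through to the next source when the previous
    # one is None or empty; the env var is the last resort.
    return call_key or module_key or os.getenv("OPENAI_LIKE_API_KEY")

print(resolve_api_key(call_key="sk-call"))   # the per-call key wins
print(resolve_api_key(module_key="sk-mod"))  # falls through to module key
```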
@@ -4596,7 +4598,10 @@ def image_generation(  # noqa: PLR0915
            client=client,
            headers=headers,
        )
-    elif custom_llm_provider == "openai":
+    elif (
+        custom_llm_provider == "openai"
+        or custom_llm_provider in litellm.openai_compatible_providers
+    ):
        model_response = openai_chat_completions.image_generation(
            model=model,
            prompt=prompt,
@@ -5042,8 +5047,7 @@ def transcription(
        )
    elif (
        custom_llm_provider == "openai"
-        or custom_llm_provider == "groq"
-        or custom_llm_provider == "fireworks_ai"
+        or custom_llm_provider in litellm.openai_compatible_providers
    ):
        api_base = (
            api_base
@@ -5201,7 +5205,10 @@ def speech(
        custom_llm_provider=custom_llm_provider,
    )
    response: Optional[HttpxBinaryResponseContent] = None
-    if custom_llm_provider == "openai":
+    if (
+        custom_llm_provider == "openai"
+        or custom_llm_provider in litellm.openai_compatible_providers
+    ):
        if voice is None or not (isinstance(voice, str)):
            raise litellm.BadRequestError(
                message="'voice' is required to be passed as a string for OpenAI TTS",
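
The `image_generation`, `transcription`, and `speech` hunks above all apply the same change: replace hard-coded provider names with a membership test against `litellm.openai_compatible_providers`, which is what lets `litellm_proxy` ride the OpenAI code path. A self-contained sketch of the pattern (the provider list here is an invented subset, not litellm's real one):

```python
# Invented subset for illustration; litellm keeps the real list in
# litellm.openai_compatible_providers.
OPENAI_COMPATIBLE_PROVIDERS = ["litellm_proxy", "groq", "fireworks_ai"]

def uses_openai_handler(custom_llm_provider: str) -> bool:
    return (
        custom_llm_provider == "openai"
        or custom_llm_provider in OPENAI_COMPATIBLE_PROVIDERS
    )

assert uses_openai_handler("litellm_proxy")
assert not uses_openai_handler("cohere")
```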


@@ -75,7 +75,7 @@ def rerank(  # noqa: PLR0915
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
-        Literal["cohere", "together_ai", "azure_ai", "infinity"]
+        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,

@@ -162,7 +162,7 @@ def rerank(  # noqa: PLR0915
    )
    # Implement rerank logic here based on the custom_llm_provider
-    if _custom_llm_provider == "cohere":
+    if _custom_llm_provider == "cohere" or _custom_llm_provider == "litellm_proxy":
        # Implement Cohere rerank logic
        api_key: Optional[str] = (
            dynamic_api_key or optional_params.api_key or litellm.api_key


@@ -0,0 +1,376 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system-path

import litellm
from litellm import completion, embedding
import pytest
from unittest.mock import MagicMock, patch
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
import pytest_asyncio
from openai import AsyncOpenAI

@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk():
    litellm.set_verbose = True
    messages = [
        {
            "role": "user",
            "content": "Hello world",
        }
    ]
    from openai import OpenAI

    openai_client = OpenAI(api_key="fake-key")

    with patch.object(
        openai_client.chat.completions, "create", new=MagicMock()
    ) as mock_call:
        try:
            completion(
                model="litellm_proxy/my-vllm-model",
                messages=messages,
                response_format={"type": "json_object"},
                client=openai_client,
                api_base="my-custom-api-base",
                hello="world",
            )
        except Exception as e:
            print(e)

        mock_call.assert_called_once()
        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
        assert "hello" in mock_call.call_args.kwargs["extra_body"]

@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_structured_output():
    from pydantic import BaseModel

    class Result(BaseModel):
        answer: str

    litellm.set_verbose = True
    from openai import OpenAI

    openai_client = OpenAI(api_key="fake-key")

    with patch.object(
        openai_client.chat.completions, "create", new=MagicMock()
    ) as mock_call:
        try:
            litellm.completion(
                model="litellm_proxy/openai/gpt-4o",
                messages=[
                    {"role": "user", "content": "What is the capital of France?"}
                ],
                api_key="my-test-api-key",
                user="test",
                response_format=Result,
                base_url="https://litellm.ml-serving-internal.scale.com",
                client=openai_client,
            )
        except Exception as e:
            print(e)

        mock_call.assert_called_once()
        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
        json_schema = mock_call.call_args.kwargs["response_format"]
        assert "json_schema" in json_schema

@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_embedding(is_async):
    litellm.set_verbose = True
    litellm._turn_on_debug()

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.embeddings.create
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.embeddings.create

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                await litellm.aembedding(
                    model="litellm_proxy/my-vllm-model",
                    input="Hello world",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                litellm.embedding(
                    model="litellm_proxy/my-vllm-model",
                    input="Hello world",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        mock_method.assert_called_once()
        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))
        assert "Hello world" == mock_method.call_args.kwargs["input"]
        assert "my-vllm-model" == mock_method.call_args.kwargs["model"]

@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_image_generation(is_async):
    litellm._turn_on_debug()

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.images.generate
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.images.generate

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                response = await litellm.aimage_generation(
                    model="litellm_proxy/dall-e-3",
                    prompt="A beautiful sunset over mountains",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                response = litellm.image_generation(
                    model="litellm_proxy/dall-e-3",
                    prompt="A beautiful sunset over mountains",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            print("response=", response)
        except Exception as e:
            print("got error", e)

        mock_method.assert_called_once()
        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))
        assert (
            "A beautiful sunset over mountains"
            == mock_method.call_args.kwargs["prompt"]
        )
        assert "dall-e-3" == mock_method.call_args.kwargs["model"]

@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_transcription(is_async):
    litellm.set_verbose = True
    litellm._turn_on_debug()

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.audio.transcriptions.create
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.audio.transcriptions.create

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                await litellm.atranscription(
                    model="litellm_proxy/whisper-1",
                    file=b"sample_audio",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                litellm.transcription(
                    model="litellm_proxy/whisper-1",
                    file=b"sample_audio",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        mock_method.assert_called_once()
        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))
        assert "whisper-1" == mock_method.call_args.kwargs["model"]

@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_speech(is_async):
    litellm.set_verbose = True

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.audio.speech.create
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.audio.speech.create

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                await litellm.aspeech(
                    model="litellm_proxy/tts-1",
                    input="Hello, this is a test of text to speech",
                    voice="alloy",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                litellm.speech(
                    model="litellm_proxy/tts-1",
                    input="Hello, this is a test of text to speech",
                    voice="alloy",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        mock_method.assert_called_once()
        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))
        assert (
            "Hello, this is a test of text to speech"
            == mock_method.call_args.kwargs["input"]
        )
        assert "tts-1" == mock_method.call_args.kwargs["model"]
        assert "alloy" == mock_method.call_args.kwargs["voice"]

@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_rerank(is_async):
    litellm.set_verbose = True
    litellm._turn_on_debug()

    if is_async:
        client = AsyncHTTPHandler()
        mock_method = AsyncMock()
        patch_target = client.post
    else:
        client = HTTPHandler()
        mock_method = MagicMock()
        patch_target = client.post

    with patch.object(client, "post", new=mock_method):
        mock_response = MagicMock()

        # Create a mock response similar to OpenAI's rerank response
        mock_response.text = json.dumps(
            {
                "id": "rerank-123456",
                "object": "reranking",
                "results": [
                    {
                        "index": 0,
                        "relevance_score": 0.9,
                        "document": {
                            "id": "0",
                            "text": "Machine learning is a field of study in artificial intelligence",
                        },
                    },
                    {
                        "index": 1,
                        "relevance_score": 0.2,
                        "document": {
                            "id": "1",
                            "text": "Biology is the study of living organisms",
                        },
                    },
                ],
                "model": "rerank-english-v2.0",
                "usage": {"prompt_tokens": 10, "total_tokens": 10},
            }
        )
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json = lambda: json.loads(mock_response.text)

        # The same mock works for both sync and async paths (AsyncMock awaits it).
        mock_method.return_value = mock_response

        try:
            if is_async:
                response = await litellm.arerank(
                    model="litellm_proxy/rerank-english-v2.0",
                    query="What is machine learning?",
                    documents=[
                        "Machine learning is a field of study in artificial intelligence",
                        "Biology is the study of living organisms",
                    ],
                    client=client,
                    api_base="my-custom-api-base",
                )
            else:
                response = litellm.rerank(
                    model="litellm_proxy/rerank-english-v2.0",
                    query="What is machine learning?",
                    documents=[
                        "Machine learning is a field of study in artificial intelligence",
                        "Biology is the study of living organisms",
                    ],
                    client=client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        # Verify the request
        mock_method.assert_called_once()
        call_args = mock_method.call_args
        print("call_args=", call_args)

        # Check that the URL is correct
        assert "my-custom-api-base/v1/rerank" == call_args.kwargs["url"]

        # Check that the request body contains the expected data
        request_body = json.loads(call_args.kwargs["data"])
        assert request_body["query"] == "What is machine learning?"
        assert request_body["model"] == "rerank-english-v2.0"
        assert len(request_body["documents"]) == 2


@@ -1819,78 +1819,6 @@ def test_lm_studio_completion(monkeypatch):
        print(e)

-@pytest.mark.asyncio
-async def test_litellm_gateway_from_sdk():
-    litellm.set_verbose = True
-    messages = [
-        {
-            "role": "user",
-            "content": "Hello world",
-        }
-    ]
-    from openai import OpenAI
-
-    openai_client = OpenAI(api_key="fake-key")
-
-    with patch.object(
-        openai_client.chat.completions, "create", new=MagicMock()
-    ) as mock_call:
-        try:
-            completion(
-                model="litellm_proxy/my-vllm-model",
-                messages=messages,
-                response_format={"type": "json_object"},
-                client=openai_client,
-                api_base="my-custom-api-base",
-                hello="world",
-            )
-        except Exception as e:
-            print(e)
-
-        mock_call.assert_called_once()
-        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
-        assert "hello" in mock_call.call_args.kwargs["extra_body"]
-
-@pytest.mark.asyncio
-async def test_litellm_gateway_from_sdk_structured_output():
-    from pydantic import BaseModel
-
-    class Result(BaseModel):
-        answer: str
-
-    litellm.set_verbose = True
-    from openai import OpenAI
-
-    openai_client = OpenAI(api_key="fake-key")
-
-    with patch.object(
-        openai_client.chat.completions, "create", new=MagicMock()
-    ) as mock_call:
-        try:
-            litellm.completion(
-                model="litellm_proxy/openai/gpt-4o",
-                messages=[
-                    {"role": "user", "content": "What is the capital of France?"}
-                ],
-                api_key="my-test-api-key",
-                user="test",
-                response_format=Result,
-                base_url="https://litellm.ml-serving-internal.scale.com",
-                client=openai_client,
-            )
-        except Exception as e:
-            print(e)
-
-        mock_call.assert_called_once()
-        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
-        json_schema = mock_call.call_args.kwargs["response_format"]
-        assert "json_schema" in json_schema
-
# ################### Hugging Face Conversational models ########################
# def hf_test_completion_conv():
#     try: