diff --git a/docs/my-website/docs/providers/litellm_proxy.md b/docs/my-website/docs/providers/litellm_proxy.md
index 69377b27f1..e204caba0a 100644
--- a/docs/my-website/docs/providers/litellm_proxy.md
+++ b/docs/my-website/docs/providers/litellm_proxy.md
@@ -3,13 +3,15 @@ import TabItem from '@theme/TabItem';
 
 # LiteLLM Proxy (LLM Gateway)
 
-:::tip
-
-[LiteLLM Providers a **self hosted** proxy server (AI Gateway)](../simple_proxy) to call all the LLMs in the OpenAI format
+| Property | Details |
+|-------|-------|
+| Description | LiteLLM Proxy is an OpenAI-compatible gateway that allows you to interact with multiple LLM providers through a unified API. Simply use the `litellm_proxy/` prefix before the model name to route your requests through the proxy. |
+| Provider Route on LiteLLM | `litellm_proxy/` (add this prefix to the model name to route any request to litellm_proxy - e.g. `litellm_proxy/your-model-name`) |
+| Setup LiteLLM Gateway | [LiteLLM Gateway ↗](../simple_proxy) |
+| Supported Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/audio/speech`, `/audio/transcriptions`, `/images`, `/rerank` |
+
+
 
-:::
-
-**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**, you just need the `litellm_proxy/` prefix before the model
 
 
 ## Required Variables
@@ -83,7 +85,99 @@ for chunk in response:
     print(chunk)
 ```
 
+## Embeddings
+
+```python
+import litellm
+
+response = litellm.embedding(
+    model="litellm_proxy/your-embedding-model",
+    input="Hello world",
+    api_base="your-litellm-proxy-url",
+    api_key="your-litellm-proxy-api-key"
+)
+```
+
+## Image Generation
+
+```python
+import litellm
+
+response = litellm.image_generation(
+    model="litellm_proxy/dall-e-3",
+    prompt="A beautiful sunset over mountains",
+    api_base="your-litellm-proxy-url",
+    api_key="your-litellm-proxy-api-key"
+)
+```
+
+## Audio Transcription
+
+```python
+import litellm
+
+response = litellm.transcription(
+    model="litellm_proxy/whisper-1",
+    file="your-audio-file",
+    api_base="your-litellm-proxy-url",
+    api_key="your-litellm-proxy-api-key"
+)
+```
+
+## Text to Speech
+
+```python
+import litellm
+
+response = litellm.speech(
+    model="litellm_proxy/tts-1",
+    input="Hello world",
+    api_base="your-litellm-proxy-url",
+    api_key="your-litellm-proxy-api-key"
+)
+```
+
+## Rerank
+
+```python
+import litellm
+
+response = litellm.rerank(
+    model="litellm_proxy/rerank-english-v2.0",
+    query="What is machine learning?",
+    documents=[
+        "Machine learning is a field of study in artificial intelligence",
+        "Biology is the study of living organisms"
+    ],
+    api_base="your-litellm-proxy-url",
+    api_key="your-litellm-proxy-api-key"
+)
+```
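+
+## Async Usage
+
+Each endpoint above also has an async variant: `litellm.aembedding`, `litellm.aimage_generation`, `litellm.atranscription`, `litellm.aspeech`, and `litellm.arerank`. A minimal sketch of the pattern, assuming a proxy reachable at `http://localhost:4000` with `your-embedding-model` configured on the gateway:
+
+```python
+import asyncio
+
+import litellm
+
+
+async def main():
+    # Same arguments as the sync call - just await the `a`-prefixed variant
+    response = await litellm.aembedding(
+        model="litellm_proxy/your-embedding-model",
+        input="Hello world",
+        api_base="http://localhost:4000",
+        api_key="your-litellm-proxy-api-key",
+    )
+    print(response)
+
+
+asyncio.run(main())
+```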
"together_ai" or custom_llm_provider == "nvidia_nim" + or custom_llm_provider == "litellm_proxy" ): api_base = ( api_base @@ -3485,7 +3486,8 @@ def embedding( # noqa: PLR0915 # set API KEY if api_key is None: api_key = ( - litellm.api_key + api_key + or litellm.api_key or litellm.openai_like_key or get_secret_str("OPENAI_LIKE_API_KEY") ) @@ -4596,7 +4598,10 @@ def image_generation( # noqa: PLR0915 client=client, headers=headers, ) - elif custom_llm_provider == "openai": + elif ( + custom_llm_provider == "openai" + or custom_llm_provider in litellm.openai_compatible_providers + ): model_response = openai_chat_completions.image_generation( model=model, prompt=prompt, @@ -5042,8 +5047,7 @@ def transcription( ) elif ( custom_llm_provider == "openai" - or custom_llm_provider == "groq" - or custom_llm_provider == "fireworks_ai" + or custom_llm_provider in litellm.openai_compatible_providers ): api_base = ( api_base @@ -5201,7 +5205,10 @@ def speech( custom_llm_provider=custom_llm_provider, ) response: Optional[HttpxBinaryResponseContent] = None - if custom_llm_provider == "openai": + if ( + custom_llm_provider == "openai" + or custom_llm_provider in litellm.openai_compatible_providers + ): if voice is None or not (isinstance(voice, str)): raise litellm.BadRequestError( message="'voice' is required to be passed as a string for OpenAI TTS", diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index bd9d3df030..6015f533c0 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -75,7 +75,7 @@ def rerank( # noqa: PLR0915 query: str, documents: List[Union[str, Dict[str, Any]]], custom_llm_provider: Optional[ - Literal["cohere", "together_ai", "azure_ai", "infinity"] + Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"] ] = None, top_n: Optional[int] = None, rank_fields: Optional[List[str]] = None, @@ -162,7 +162,7 @@ def rerank( # noqa: PLR0915 ) # Implement rerank logic here based on the custom_llm_provider - if _custom_llm_provider == "cohere": + if _custom_llm_provider == "cohere" or _custom_llm_provider == "litellm_proxy": # Implement Cohere rerank logic api_key: Optional[str] = ( dynamic_api_key or optional_params.api_key or litellm.api_key diff --git a/tests/llm_translation/test_litellm_proxy_provider.py b/tests/llm_translation/test_litellm_proxy_provider.py new file mode 100644 index 0000000000..8484a66dad --- /dev/null +++ b/tests/llm_translation/test_litellm_proxy_provider.py @@ -0,0 +1,376 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system-path + +import litellm +from litellm import completion, embedding +import pytest +from unittest.mock import MagicMock, patch +from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler +import pytest_asyncio +from openai import AsyncOpenAI + + +@pytest.mark.asyncio +async def test_litellm_gateway_from_sdk(): + litellm.set_verbose = True + messages = [ + { + "role": "user", + "content": "Hello world", + } + ] + from openai import OpenAI + + openai_client = OpenAI(api_key="fake-key") + + with patch.object( + openai_client.chat.completions, "create", new=MagicMock() + ) as mock_call: + try: + completion( + model="litellm_proxy/my-vllm-model", + messages=messages, + response_format={"type": "json_object"}, + client=openai_client, + api_base="my-custom-api-base", + hello="world", + ) + except Exception as e: + print(e) + + 
+        mock_call.assert_called_once()
+
+        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
+
+        assert "hello" in mock_call.call_args.kwargs["extra_body"]
+
+
+@pytest.mark.asyncio
+async def test_litellm_gateway_from_sdk_structured_output():
+    from pydantic import BaseModel
+
+    class Result(BaseModel):
+        answer: str
+
+    litellm.set_verbose = True
+    from openai import OpenAI
+
+    openai_client = OpenAI(api_key="fake-key")
+
+    with patch.object(
+        openai_client.chat.completions, "create", new=MagicMock()
+    ) as mock_call:
+        try:
+            litellm.completion(
+                model="litellm_proxy/openai/gpt-4o",
+                messages=[
+                    {"role": "user", "content": "What is the capital of France?"}
+                ],
+                api_key="my-test-api-key",
+                user="test",
+                response_format=Result,
+                base_url="https://litellm.ml-serving-internal.scale.com",
+                client=openai_client,
+            )
+        except Exception as e:
+            print(e)
+
+        mock_call.assert_called_once()
+
+        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
+        json_schema = mock_call.call_args.kwargs["response_format"]
+        assert "json_schema" in json_schema
+
+
+@pytest.mark.parametrize("is_async", [False, True])
+@pytest.mark.asyncio
+async def test_litellm_gateway_from_sdk_embedding(is_async):
+    litellm.set_verbose = True
+    litellm._turn_on_debug()
+
+    if is_async:
+        from openai import AsyncOpenAI
+
+        openai_client = AsyncOpenAI(api_key="fake-key")
+        mock_method = AsyncMock()
+        patch_target = openai_client.embeddings.create
+    else:
+        from openai import OpenAI
+
+        openai_client = OpenAI(api_key="fake-key")
+        mock_method = MagicMock()
+        patch_target = openai_client.embeddings.create
+
+    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
+        try:
+            if is_async:
+                await litellm.aembedding(
+                    model="litellm_proxy/my-vllm-model",
+                    input="Hello world",
+                    client=openai_client,
+                    api_base="my-custom-api-base",
+                )
+            else:
+                litellm.embedding(
+                    model="litellm_proxy/my-vllm-model",
+                    input="Hello world",
+                    client=openai_client,
+                    api_base="my-custom-api-base",
+                )
+        except Exception as e:
+            print(e)
+
+        mock_method.assert_called_once()
+
+        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))
+
+        assert "Hello world" == mock_method.call_args.kwargs["input"]
+        assert "my-vllm-model" == mock_method.call_args.kwargs["model"]
+
+
+@pytest.mark.parametrize("is_async", [False, True])
+@pytest.mark.asyncio
+async def test_litellm_gateway_from_sdk_image_generation(is_async):
+    litellm._turn_on_debug()
+
+    if is_async:
+        from openai import AsyncOpenAI
+
+        openai_client = AsyncOpenAI(api_key="fake-key")
+        mock_method = AsyncMock()
+        patch_target = openai_client.images.generate
+    else:
+        from openai import OpenAI
+
+        openai_client = OpenAI(api_key="fake-key")
+        mock_method = MagicMock()
+        patch_target = openai_client.images.generate
+
+    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
+        try:
+            if is_async:
+                response = await litellm.aimage_generation(
+                    model="litellm_proxy/dall-e-3",
+                    prompt="A beautiful sunset over mountains",
+                    client=openai_client,
+                    api_base="my-custom-api-base",
+                )
+            else:
+                response = litellm.image_generation(
+                    model="litellm_proxy/dall-e-3",
+                    prompt="A beautiful sunset over mountains",
+                    client=openai_client,
+                    api_base="my-custom-api-base",
+                )
+            print("response=", response)
+        except Exception as e:
+            print("got error", e)
+
+        mock_method.assert_called_once()
+
+        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))
+
+        assert (
+            "A beautiful sunset over mountains"
mock_method.call_args.kwargs["prompt"] + ) + assert "dall-e-3" == mock_method.call_args.kwargs["model"] + + +@pytest.mark.parametrize("is_async", [False, True]) +@pytest.mark.asyncio +async def test_litellm_gateway_from_sdk_transcription(is_async): + litellm.set_verbose = True + litellm._turn_on_debug() + + if is_async: + from openai import AsyncOpenAI + + openai_client = AsyncOpenAI(api_key="fake-key") + mock_method = AsyncMock() + patch_target = openai_client.audio.transcriptions.create + else: + from openai import OpenAI + + openai_client = OpenAI(api_key="fake-key") + mock_method = MagicMock() + patch_target = openai_client.audio.transcriptions.create + + with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method): + try: + if is_async: + await litellm.atranscription( + model="litellm_proxy/whisper-1", + file=b"sample_audio", + client=openai_client, + api_base="my-custom-api-base", + ) + else: + litellm.transcription( + model="litellm_proxy/whisper-1", + file=b"sample_audio", + client=openai_client, + api_base="my-custom-api-base", + ) + except Exception as e: + print(e) + + mock_method.assert_called_once() + + print("Call KWARGS - {}".format(mock_method.call_args.kwargs)) + + assert "whisper-1" == mock_method.call_args.kwargs["model"] + + +@pytest.mark.parametrize("is_async", [False, True]) +@pytest.mark.asyncio +async def test_litellm_gateway_from_sdk_speech(is_async): + litellm.set_verbose = True + + if is_async: + from openai import AsyncOpenAI + + openai_client = AsyncOpenAI(api_key="fake-key") + mock_method = AsyncMock() + patch_target = openai_client.audio.speech.create + else: + from openai import OpenAI + + openai_client = OpenAI(api_key="fake-key") + mock_method = MagicMock() + patch_target = openai_client.audio.speech.create + + with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method): + try: + if is_async: + await litellm.aspeech( + model="litellm_proxy/tts-1", + input="Hello, this is a test of text to speech", + voice="alloy", + client=openai_client, + api_base="my-custom-api-base", + ) + else: + litellm.speech( + model="litellm_proxy/tts-1", + input="Hello, this is a test of text to speech", + voice="alloy", + client=openai_client, + api_base="my-custom-api-base", + ) + except Exception as e: + print(e) + + mock_method.assert_called_once() + + print("Call KWARGS - {}".format(mock_method.call_args.kwargs)) + + assert ( + "Hello, this is a test of text to speech" + == mock_method.call_args.kwargs["input"] + ) + assert "tts-1" == mock_method.call_args.kwargs["model"] + assert "alloy" == mock_method.call_args.kwargs["voice"] + + +@pytest.mark.parametrize("is_async", [False, True]) +@pytest.mark.asyncio +async def test_litellm_gateway_from_sdk_rerank(is_async): + litellm.set_verbose = True + litellm._turn_on_debug() + + if is_async: + client = AsyncHTTPHandler() + mock_method = AsyncMock() + patch_target = client.post + else: + client = HTTPHandler() + mock_method = MagicMock() + patch_target = client.post + + with patch.object(client, "post", new=mock_method): + mock_response = MagicMock() + + # Create a mock response similar to OpenAI's rerank response + mock_response.text = json.dumps( + { + "id": "rerank-123456", + "object": "reranking", + "results": [ + { + "index": 0, + "relevance_score": 0.9, + "document": { + "id": "0", + "text": "Machine learning is a field of study in artificial intelligence", + }, + }, + { + "index": 1, + "relevance_score": 0.2, + "document": { + "id": "1", + "text": "Biology is the study of living 
organisms", + }, + }, + ], + "model": "rerank-english-v2.0", + "usage": {"prompt_tokens": 10, "total_tokens": 10}, + } + ) + + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json = lambda: json.loads(mock_response.text) + + if is_async: + mock_method.return_value = mock_response + else: + mock_method.return_value = mock_response + + try: + if is_async: + response = await litellm.arerank( + model="litellm_proxy/rerank-english-v2.0", + query="What is machine learning?", + documents=[ + "Machine learning is a field of study in artificial intelligence", + "Biology is the study of living organisms", + ], + client=client, + api_base="my-custom-api-base", + ) + else: + response = litellm.rerank( + model="litellm_proxy/rerank-english-v2.0", + query="What is machine learning?", + documents=[ + "Machine learning is a field of study in artificial intelligence", + "Biology is the study of living organisms", + ], + client=client, + api_base="my-custom-api-base", + ) + except Exception as e: + print(e) + + # Verify the request + mock_method.assert_called_once() + call_args = mock_method.call_args + print("call_args=", call_args) + + # Check that the URL is correct + assert "my-custom-api-base/v1/rerank" == call_args.kwargs["url"] + + # Check that the request body contains the expected data + request_body = json.loads(call_args.kwargs["data"]) + assert request_body["query"] == "What is machine learning?" + assert request_body["model"] == "rerank-english-v2.0" + assert len(request_body["documents"]) == 2 diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index b1995a6b23..b19232ed67 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -1819,78 +1819,6 @@ def test_lm_studio_completion(monkeypatch): print(e) -@pytest.mark.asyncio -async def test_litellm_gateway_from_sdk(): - litellm.set_verbose = True - messages = [ - { - "role": "user", - "content": "Hello world", - } - ] - from openai import OpenAI - - openai_client = OpenAI(api_key="fake-key") - - with patch.object( - openai_client.chat.completions, "create", new=MagicMock() - ) as mock_call: - try: - completion( - model="litellm_proxy/my-vllm-model", - messages=messages, - response_format={"type": "json_object"}, - client=openai_client, - api_base="my-custom-api-base", - hello="world", - ) - except Exception as e: - print(e) - - mock_call.assert_called_once() - - print("Call KWARGS - {}".format(mock_call.call_args.kwargs)) - - assert "hello" in mock_call.call_args.kwargs["extra_body"] - - -@pytest.mark.asyncio -async def test_litellm_gateway_from_sdk_structured_output(): - from pydantic import BaseModel - - class Result(BaseModel): - answer: str - - litellm.set_verbose = True - from openai import OpenAI - - openai_client = OpenAI(api_key="fake-key") - - with patch.object( - openai_client.chat.completions, "create", new=MagicMock() - ) as mock_call: - try: - litellm.completion( - model="litellm_proxy/openai/gpt-4o", - messages=[ - {"role": "user", "content": "What is the capital of France?"} - ], - api_key="my-test-api-key", - user="test", - response_format=Result, - base_url="https://litellm.ml-serving-internal.scale.com", - client=openai_client, - ) - except Exception as e: - print(e) - - mock_call.assert_called_once() - - print("Call KWARGS - {}".format(mock_call.call_args.kwargs)) - json_schema = mock_call.call_args.kwargs["response_format"] - assert "json_schema" in json_schema - - # 
 # ################### Hugging Face Conversational models ########################
 # def hf_test_completion_conv():
 #     try: