Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
(Bug Fix) Using LiteLLM Python SDK with model=litellm_proxy/ for embedding, image_generation, transcription, speech, rerank (#8815)

* test_litellm_gateway_from_sdk
* fix embedding check for openai
* test litellm proxy provider
* fix image generation openai compatible models
* fix litellm.transcription
* test_litellm_gateway_from_sdk_rerank
* docs litellm python sdk
* docs litellm python sdk with proxy
* test_litellm_gateway_from_sdk_rerank
* ci/cd run again
* test_litellm_gateway_from_sdk_image_generation
* test_litellm_gateway_from_sdk_embedding
* test_litellm_gateway_from_sdk_embedding
This commit is contained in:

parent ef22209a15
commit f9cee4c46b

6 changed files with 466 additions and 83 deletions
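In short, non-chat SDK calls now honor the `litellm_proxy/` prefix. A minimal sketch of the kind of call this commit fixes; the proxy URL and key below are assumed placeholders, not values from the commit:

```python
import litellm

# Before this commit, only chat completions reliably routed through the
# litellm_proxy/ prefix; embedding/image/audio/rerank calls fell through
# to other provider branches. URL and key are assumed placeholders.
response = litellm.embedding(
    model="litellm_proxy/my-embedding-model",
    input="Hello world",
    api_base="http://localhost:4000",  # assumed local proxy address
    api_key="sk-my-proxy-key",         # assumed proxy virtual key
)
```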
@@ -3,13 +3,15 @@ import TabItem from '@theme/TabItem';

# LiteLLM Proxy (LLM Gateway)

:::tip

[LiteLLM provides a **self hosted** proxy server (AI Gateway)](../simple_proxy) to call all the LLMs in the OpenAI format

| Property | Details |
|-------|-------|
| Description | LiteLLM Proxy is an OpenAI-compatible gateway that allows you to interact with multiple LLM providers through a unified API. Simply use the `litellm_proxy/` prefix before the model name to route your requests through the proxy. |
| Provider Route on LiteLLM | `litellm_proxy/` (add this prefix to the model name to route any request to litellm_proxy - e.g. `litellm_proxy/your-model-name`) |
| Setup LiteLLM Gateway | [LiteLLM Gateway ↗](../simple_proxy) |
| Supported Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/audio/speech`, `/audio/transcriptions`, `/images`, `/rerank` |

:::

**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible** - you just need the `litellm_proxy/` prefix before the model name

## Required Variables
@@ -83,7 +85,76 @@ for chunk in response:

```python
for chunk in response:
    print(chunk)
```

## Embeddings

```python
import litellm

response = litellm.embedding(
    model="litellm_proxy/your-embedding-model",
    input="Hello world",
    api_base="your-litellm-proxy-url",
    api_key="your-litellm-proxy-api-key"
)
```

## Image Generation

```python
import litellm

response = litellm.image_generation(
    model="litellm_proxy/dall-e-3",
    prompt="A beautiful sunset over mountains",
    api_base="your-litellm-proxy-url",
    api_key="your-litellm-proxy-api-key"
)
```

## Audio Transcription

```python
import litellm

response = litellm.transcription(
    model="litellm_proxy/whisper-1",
    file="your-audio-file",
    api_base="your-litellm-proxy-url",
    api_key="your-litellm-proxy-api-key"
)
```

## Text to Speech

```python
import litellm

response = litellm.speech(
    model="litellm_proxy/tts-1",
    input="Hello world",
    api_base="your-litellm-proxy-url",
    api_key="your-litellm-proxy-api-key"
)
```

## Rerank

```python
import litellm

response = litellm.rerank(
    model="litellm_proxy/rerank-english-v2.0",
    query="What is machine learning?",
    documents=[
        "Machine learning is a field of study in artificial intelligence",
        "Biology is the study of living organisms"
    ],
    api_base="your-litellm-proxy-url",
    api_key="your-litellm-proxy-api-key"
)
```
## **Usage with Langchain, LlamaIndex, OpenAI JS, Anthropic SDK, Instructor**

#### [Follow this doc to see how to use litellm proxy with langchain, llamaindex, anthropic etc](../proxy/user_keys)
```diff
@@ -112,6 +112,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
             api_base=api_base,
             timeout=timeout,
             max_retries=max_retries,
+            client=client,
         )

         ## LOGGING
```
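The added `client=client` threads a caller-supplied OpenAI client into the transcription handler instead of constructing a fresh one. A hedged sketch of what that enables at the SDK surface; the base URL and file name here are assumptions, not values from the commit:

```python
import litellm
from openai import OpenAI

# A pre-configured client (custom base URL, headers, retries) can now be
# reused by litellm.transcription rather than litellm building its own.
# The base_url value is an assumed placeholder.
openai_client = OpenAI(api_key="sk-my-proxy-key", base_url="http://localhost:4000")

response = litellm.transcription(
    model="litellm_proxy/whisper-1",
    file=open("sample.wav", "rb"),  # assumed local audio file
    client=openai_client,
)
```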
```diff
@@ -3409,6 +3409,7 @@ def embedding(  # noqa: PLR0915
         or custom_llm_provider == "openai"
         or custom_llm_provider == "together_ai"
         or custom_llm_provider == "nvidia_nim"
+        or custom_llm_provider == "litellm_proxy"
     ):
         api_base = (
             api_base
```
```diff
@@ -3485,7 +3486,8 @@ def embedding(  # noqa: PLR0915
         # set API KEY
         if api_key is None:
             api_key = (
-                litellm.api_key
+                api_key
+                or litellm.api_key
                 or litellm.openai_like_key
                 or get_secret_str("OPENAI_LIKE_API_KEY")
             )
```
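With `litellm_proxy` in the OpenAI-compatible branch, embedding calls resolve `api_base`/`api_key` like any other OpenAI-style provider, sync or async. A minimal async sketch under assumed placeholder values:

```python
import asyncio

import litellm


async def main() -> None:
    # aembedding follows the same provider branch as embedding;
    # URL and key are assumed placeholders.
    response = await litellm.aembedding(
        model="litellm_proxy/my-embedding-model",
        input=["Hello world", "How are you?"],
        api_base="http://localhost:4000",  # assumed proxy URL
        api_key="sk-my-proxy-key",         # assumed proxy key
    )
    print(len(response.data))  # one embedding per input


asyncio.run(main())
```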
```diff
@@ -4596,7 +4598,10 @@ def image_generation(  # noqa: PLR0915
             client=client,
             headers=headers,
         )
-    elif custom_llm_provider == "openai":
+    elif (
+        custom_llm_provider == "openai"
+        or custom_llm_provider in litellm.openai_compatible_providers
+    ):
         model_response = openai_chat_completions.image_generation(
             model=model,
             prompt=prompt,
```
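Since `litellm_proxy` is a member of `litellm.openai_compatible_providers`, image generation with the prefix now takes the OpenAI handler path. A hedged usage sketch with placeholder URL and key:

```python
import litellm

# Image generation with the litellm_proxy/ prefix now reaches the
# OpenAI-compatible handler. URL and key are assumed placeholders.
response = litellm.image_generation(
    model="litellm_proxy/dall-e-3",
    prompt="A beautiful sunset over mountains",
    api_base="http://localhost:4000",
    api_key="sk-my-proxy-key",
)
print(response.data)  # list of generated image objects
```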
```diff
@@ -5042,8 +5047,7 @@ def transcription(
         )
     elif (
         custom_llm_provider == "openai"
-        or custom_llm_provider == "groq"
-        or custom_llm_provider == "fireworks_ai"
+        or custom_llm_provider in litellm.openai_compatible_providers
     ):
         api_base = (
             api_base
```
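The collapsed branch relies on `groq` and `fireworks_ai` already being members of `litellm.openai_compatible_providers`, so removing their explicit checks changes nothing for them while admitting `litellm_proxy`. A quick check one could run to confirm that invariant:

```python
import litellm

# Sketch of the invariant the refactor depends on: the providers whose
# explicit checks were deleted must already be in the compatibility list.
for provider in ("groq", "fireworks_ai", "litellm_proxy"):
    assert provider in litellm.openai_compatible_providers, provider
```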
```diff
@@ -5201,7 +5205,10 @@ def speech(
         custom_llm_provider=custom_llm_provider,
     )
     response: Optional[HttpxBinaryResponseContent] = None
-    if custom_llm_provider == "openai":
+    if (
+        custom_llm_provider == "openai"
+        or custom_llm_provider in litellm.openai_compatible_providers
+    ):
         if voice is None or not (isinstance(voice, str)):
             raise litellm.BadRequestError(
                 message="'voice' is required to be passed as a string for OpenAI TTS",
```
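The widened branch still enforces a string `voice`, so proxy-routed TTS calls must pass one. A hedged sketch; the URL, key, and output path are assumptions:

```python
import litellm

# voice must be a string on this code path, or BadRequestError is raised.
response = litellm.speech(
    model="litellm_proxy/tts-1",
    input="Hello world",
    voice="alloy",                     # required string on the OpenAI path
    api_base="http://localhost:4000",  # assumed proxy URL
    api_key="sk-my-proxy-key",         # assumed proxy key
)
response.stream_to_file("speech.mp3")  # HttpxBinaryResponseContent helper
```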
```diff
@@ -75,7 +75,7 @@ def rerank(  # noqa: PLR0915
     query: str,
     documents: List[Union[str, Dict[str, Any]]],
     custom_llm_provider: Optional[
-        Literal["cohere", "together_ai", "azure_ai", "infinity"]
+        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
     ] = None,
     top_n: Optional[int] = None,
     rank_fields: Optional[List[str]] = None,
```
```diff
@@ -162,7 +162,7 @@ def rerank(  # noqa: PLR0915
     )

     # Implement rerank logic here based on the custom_llm_provider
-    if _custom_llm_provider == "cohere":
+    if _custom_llm_provider == "cohere" or _custom_llm_provider == "litellm_proxy":
         # Implement Cohere rerank logic
         api_key: Optional[str] = (
             dynamic_api_key or optional_params.api_key or litellm.api_key
```
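Reusing the Cohere branch for `litellm_proxy` works because the proxy exposes a Cohere-shaped `/v1/rerank` endpoint, which the new test below asserts. A hedged usage sketch with placeholder URL and key:

```python
import litellm

# The proxy's rerank endpoint speaks the Cohere rerank wire format, so the
# existing Cohere handler is reused. URL and key are assumed placeholders.
response = litellm.rerank(
    model="litellm_proxy/rerank-english-v2.0",
    query="What is machine learning?",
    documents=[
        "Machine learning is a field of study in artificial intelligence",
        "Biology is the study of living organisms",
    ],
    api_base="http://localhost:4000",
    api_key="sk-my-proxy-key",
)
print(response.results)  # ranked documents with relevance scores
```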
tests/llm_translation/test_litellm_proxy_provider.py (new file, 376 lines)

@@ -0,0 +1,376 @@

```python
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system-path

import litellm
from litellm import completion, embedding
import pytest
from unittest.mock import MagicMock, patch
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
import pytest_asyncio
from openai import AsyncOpenAI


@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk():
    litellm.set_verbose = True
    messages = [
        {
            "role": "user",
            "content": "Hello world",
        }
    ]
    from openai import OpenAI

    openai_client = OpenAI(api_key="fake-key")

    with patch.object(
        openai_client.chat.completions, "create", new=MagicMock()
    ) as mock_call:
        try:
            completion(
                model="litellm_proxy/my-vllm-model",
                messages=messages,
                response_format={"type": "json_object"},
                client=openai_client,
                api_base="my-custom-api-base",
                hello="world",
            )
        except Exception as e:
            print(e)

        mock_call.assert_called_once()

        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))

        assert "hello" in mock_call.call_args.kwargs["extra_body"]


@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_structured_output():
    from pydantic import BaseModel

    class Result(BaseModel):
        answer: str

    litellm.set_verbose = True
    from openai import OpenAI

    openai_client = OpenAI(api_key="fake-key")

    with patch.object(
        openai_client.chat.completions, "create", new=MagicMock()
    ) as mock_call:
        try:
            litellm.completion(
                model="litellm_proxy/openai/gpt-4o",
                messages=[
                    {"role": "user", "content": "What is the capital of France?"}
                ],
                api_key="my-test-api-key",
                user="test",
                response_format=Result,
                base_url="https://litellm.ml-serving-internal.scale.com",
                client=openai_client,
            )
        except Exception as e:
            print(e)

        mock_call.assert_called_once()

        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
        json_schema = mock_call.call_args.kwargs["response_format"]
        assert "json_schema" in json_schema


@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_embedding(is_async):
    litellm.set_verbose = True
    litellm._turn_on_debug()

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.embeddings.create
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.embeddings.create

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                await litellm.aembedding(
                    model="litellm_proxy/my-vllm-model",
                    input="Hello world",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                litellm.embedding(
                    model="litellm_proxy/my-vllm-model",
                    input="Hello world",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        mock_method.assert_called_once()

        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))

        assert "Hello world" == mock_method.call_args.kwargs["input"]
        assert "my-vllm-model" == mock_method.call_args.kwargs["model"]


@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_image_generation(is_async):
    litellm._turn_on_debug()

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.images.generate
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.images.generate

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                response = await litellm.aimage_generation(
                    model="litellm_proxy/dall-e-3",
                    prompt="A beautiful sunset over mountains",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                response = litellm.image_generation(
                    model="litellm_proxy/dall-e-3",
                    prompt="A beautiful sunset over mountains",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            print("response=", response)
        except Exception as e:
            print("got error", e)

        mock_method.assert_called_once()

        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))

        assert (
            "A beautiful sunset over mountains"
            == mock_method.call_args.kwargs["prompt"]
        )
        assert "dall-e-3" == mock_method.call_args.kwargs["model"]


@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_transcription(is_async):
    litellm.set_verbose = True
    litellm._turn_on_debug()

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.audio.transcriptions.create
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.audio.transcriptions.create

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                await litellm.atranscription(
                    model="litellm_proxy/whisper-1",
                    file=b"sample_audio",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                litellm.transcription(
                    model="litellm_proxy/whisper-1",
                    file=b"sample_audio",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        mock_method.assert_called_once()

        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))

        assert "whisper-1" == mock_method.call_args.kwargs["model"]


@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_speech(is_async):
    litellm.set_verbose = True

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.audio.speech.create
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.audio.speech.create

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                await litellm.aspeech(
                    model="litellm_proxy/tts-1",
                    input="Hello, this is a test of text to speech",
                    voice="alloy",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                litellm.speech(
                    model="litellm_proxy/tts-1",
                    input="Hello, this is a test of text to speech",
                    voice="alloy",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        mock_method.assert_called_once()

        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))

        assert (
            "Hello, this is a test of text to speech"
            == mock_method.call_args.kwargs["input"]
        )
        assert "tts-1" == mock_method.call_args.kwargs["model"]
        assert "alloy" == mock_method.call_args.kwargs["voice"]


@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_rerank(is_async):
    litellm.set_verbose = True
    litellm._turn_on_debug()

    if is_async:
        client = AsyncHTTPHandler()
        mock_method = AsyncMock()
        patch_target = client.post
    else:
        client = HTTPHandler()
        mock_method = MagicMock()
        patch_target = client.post

    with patch.object(client, "post", new=mock_method):
        mock_response = MagicMock()

        # Create a mock response similar to OpenAI's rerank response
        mock_response.text = json.dumps(
            {
                "id": "rerank-123456",
                "object": "reranking",
                "results": [
                    {
                        "index": 0,
                        "relevance_score": 0.9,
                        "document": {
                            "id": "0",
                            "text": "Machine learning is a field of study in artificial intelligence",
                        },
                    },
                    {
                        "index": 1,
                        "relevance_score": 0.2,
                        "document": {
                            "id": "1",
                            "text": "Biology is the study of living organisms",
                        },
                    },
                ],
                "model": "rerank-english-v2.0",
                "usage": {"prompt_tokens": 10, "total_tokens": 10},
            }
        )

        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json = lambda: json.loads(mock_response.text)

        # the same mocked response serves both the sync and async handlers
        mock_method.return_value = mock_response

        try:
            if is_async:
                response = await litellm.arerank(
                    model="litellm_proxy/rerank-english-v2.0",
                    query="What is machine learning?",
                    documents=[
                        "Machine learning is a field of study in artificial intelligence",
                        "Biology is the study of living organisms",
                    ],
                    client=client,
                    api_base="my-custom-api-base",
                )
            else:
                response = litellm.rerank(
                    model="litellm_proxy/rerank-english-v2.0",
                    query="What is machine learning?",
                    documents=[
                        "Machine learning is a field of study in artificial intelligence",
                        "Biology is the study of living organisms",
                    ],
                    client=client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        # Verify the request
        mock_method.assert_called_once()
        call_args = mock_method.call_args
        print("call_args=", call_args)

        # Check that the URL is correct
        assert "my-custom-api-base/v1/rerank" == call_args.kwargs["url"]

        # Check that the request body contains the expected data
        request_body = json.loads(call_args.kwargs["data"])
        assert request_body["query"] == "What is machine learning?"
        assert request_body["model"] == "rerank-english-v2.0"
        assert len(request_body["documents"]) == 2
```
The same two SDK tests are removed from the existing completion test module (they now live in the new file above):

```diff
@@ -1819,78 +1819,6 @@ def test_lm_studio_completion(monkeypatch):
         print(e)


-@pytest.mark.asyncio
-async def test_litellm_gateway_from_sdk():
-    litellm.set_verbose = True
-    messages = [
-        {
-            "role": "user",
-            "content": "Hello world",
-        }
-    ]
-    from openai import OpenAI
-
-    openai_client = OpenAI(api_key="fake-key")
-
-    with patch.object(
-        openai_client.chat.completions, "create", new=MagicMock()
-    ) as mock_call:
-        try:
-            completion(
-                model="litellm_proxy/my-vllm-model",
-                messages=messages,
-                response_format={"type": "json_object"},
-                client=openai_client,
-                api_base="my-custom-api-base",
-                hello="world",
-            )
-        except Exception as e:
-            print(e)
-
-        mock_call.assert_called_once()
-
-        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
-
-        assert "hello" in mock_call.call_args.kwargs["extra_body"]
-
-
-@pytest.mark.asyncio
-async def test_litellm_gateway_from_sdk_structured_output():
-    from pydantic import BaseModel
-
-    class Result(BaseModel):
-        answer: str
-
-    litellm.set_verbose = True
-    from openai import OpenAI
-
-    openai_client = OpenAI(api_key="fake-key")
-
-    with patch.object(
-        openai_client.chat.completions, "create", new=MagicMock()
-    ) as mock_call:
-        try:
-            litellm.completion(
-                model="litellm_proxy/openai/gpt-4o",
-                messages=[
-                    {"role": "user", "content": "What is the capital of France?"}
-                ],
-                api_key="my-test-api-key",
-                user="test",
-                response_format=Result,
-                base_url="https://litellm.ml-serving-internal.scale.com",
-                client=openai_client,
-            )
-        except Exception as e:
-            print(e)
-
-        mock_call.assert_called_once()
-
-        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
-        json_schema = mock_call.call_args.kwargs["response_format"]
-        assert "json_schema" in json_schema
-
-
 # ################### Hugging Face Conversational models ########################
 # def hf_test_completion_conv():
 #     try:
```