Merge branch 'main' into litellm_redis_team_object
Commit c2086300b7
34 changed files with 1182 additions and 232 deletions
@@ -2573,21 +2573,17 @@ def test_completion_azure_extra_headers():
    http_client = Client()

    with patch.object(http_client, "send", new=MagicMock()) as mock_client:
        client = AzureOpenAI(
            azure_endpoint=os.getenv("AZURE_API_BASE"),
            api_version=litellm.AZURE_DEFAULT_API_VERSION,
            api_key=os.getenv("AZURE_API_KEY"),
            http_client=http_client,
        )
        litellm.client_session = http_client
        try:
            response = completion(
                model="azure/chatgpt-v-2",
                messages=messages,
                client=client,
                api_base=os.getenv("AZURE_API_BASE"),
                api_version="2023-07-01-preview",
                api_key=os.getenv("AZURE_API_KEY"),
                extra_headers={
                    "Authorization": "my-bad-key",
                    "Ocp-Apim-Subscription-Key": "hello-world-testing",
                    "api-key": "my-bad-key",
                },
            )
            print(response)
@@ -2603,8 +2599,10 @@ def test_completion_azure_extra_headers():
        print(request.url)  # This will print the full URL
        print(request.headers)  # This will print the request headers
        auth_header = request.headers.get("Authorization")
        apim_key = request.headers.get("Ocp-Apim-Subscription-Key")
        print(auth_header)
        assert auth_header == "my-bad-key"
        assert apim_key == "hello-world-testing"


def test_completion_azure_ad_token():
@@ -2613,18 +2611,37 @@ def test_completion_azure_ad_token():
    # If you want to remove it, speak to Ishaan!
    # Ishaan will be very disappointed if this test is removed -> this is a standard way to pass api_key + the router + proxy use this
    from httpx import Client
    from openai import AzureOpenAI

    from litellm import completion
    from litellm.llms.custom_httpx.httpx_handler import HTTPHandler

    response = completion(
        model="azure/chatgpt-v-2",
        messages=messages,
        # api_key="my-fake-ad-token",
        azure_ad_token=os.getenv("AZURE_API_KEY"),
    )
    print(response)
    litellm.set_verbose = True

    old_key = os.environ["AZURE_API_KEY"]
    os.environ.pop("AZURE_API_KEY", None)

    http_client = Client()

    with patch.object(http_client, "send", new=MagicMock()) as mock_client:
        litellm.client_session = http_client
        try:
            response = completion(
                model="azure/chatgpt-v-2",
                messages=messages,
                azure_ad_token="my-special-token",
            )
            print(response)
        except Exception as e:
            pass
        finally:
            os.environ["AZURE_API_KEY"] = old_key

        mock_client.assert_called_once()
        request = mock_client.call_args[0][0]
        print(request.method)  # This will print 'POST'
        print(request.url)  # This will print the full URL
        print(request.headers)  # This will print the request headers
        auth_header = request.headers.get("Authorization")
        assert auth_header == "Bearer my-special-token"


def test_completion_azure_key_completion_arg():
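Context note (not part of the diff): the pattern the two hunks above exercise is "patch httpx.Client.send, run completion(), then inspect the captured request". A minimal, hedged sketch of that pattern follows; the deployment name and env vars are placeholders taken from the tests.

# Illustrative sketch only: assert that a header passed via extra_headers
# reaches the outgoing HTTP request captured from the mocked transport.
import os
from unittest.mock import MagicMock, patch

from httpx import Client

import litellm
from litellm import completion

http_client = Client()
with patch.object(http_client, "send", new=MagicMock()) as mock_send:
    litellm.client_session = http_client
    try:
        completion(
            model="azure/chatgpt-v-2",  # placeholder deployment, as in the tests
            messages=[{"role": "user", "content": "hi"}],
            api_base=os.getenv("AZURE_API_BASE"),
            api_key=os.getenv("AZURE_API_KEY"),
            extra_headers={"Ocp-Apim-Subscription-Key": "hello-world-testing"},
        )
    except Exception:
        pass  # the mocked transport never returns a real response
    request = mock_send.call_args[0][0]  # first positional arg of send() is the Request
    assert request.headers.get("Ocp-Apim-Subscription-Key") == "hello-world-testing"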
302  litellm/tests/test_custom_llm.py  Normal file

@@ -0,0 +1,302 @@
# What is this?
## Unit tests for the CustomLLM class


import asyncio
import os
import sys
import time
import traceback

import openai
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    Coroutine,
    Iterator,
    Optional,
    Union,
)
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
from dotenv import load_dotenv

import litellm
from litellm import (
    ChatCompletionDeltaChunk,
    ChatCompletionUsageBlock,
    CustomLLM,
    GenericStreamingChunk,
    ModelResponse,
    acompletion,
    completion,
    get_llm_provider,
)
from litellm.utils import ModelResponseIterator

class CustomModelResponseIterator:
    def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
        self.streaming_response = streaming_response

    def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
        return GenericStreamingChunk(
            text="hello world",
            tool_use=None,
            is_finished=True,
            finish_reason="stop",
            usage=ChatCompletionUsageBlock(
                prompt_tokens=10, completion_tokens=20, total_tokens=30
            ),
            index=0,
        )

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self) -> GenericStreamingChunk:
        try:
            chunk: Any = self.streaming_response.__next__()  # type: ignore
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.chunk_parser(chunk=chunk)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
        return self.streaming_response

    async def __anext__(self) -> GenericStreamingChunk:
        try:
            chunk = await self.async_response_iterator.__anext__()
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.chunk_parser(chunk=chunk)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

class MyCustomLLM(CustomLLM):
    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.HTTPHandler] = None,
    ) -> ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    async def acompletion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.AsyncHTTPHandler] = None,
    ) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    def streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.HTTPHandler] = None,
    ) -> Iterator[GenericStreamingChunk]:
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }

        completion_stream = ModelResponseIterator(
            model_response=generic_streaming_chunk  # type: ignore
        )
        custom_iterator = CustomModelResponseIterator(
            streaming_response=completion_stream
        )
        return custom_iterator

    async def astreaming(  # type: ignore
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.AsyncHTTPHandler] = None,
    ) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }

        yield generic_streaming_chunk  # type: ignore

def test_get_llm_provider():
    """"""
    from litellm.utils import custom_llm_setup

    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]

    custom_llm_setup()

    model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")

    assert provider == "custom_llm"


def test_simple_completion():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = completion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
    )

    assert resp.choices[0].message.content == "Hi!"


@pytest.mark.asyncio
async def test_simple_acompletion():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = await acompletion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
    )

    assert resp.choices[0].message.content == "Hi!"


def test_simple_completion_streaming():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = completion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,
    )

    for chunk in resp:
        print(chunk)
        if chunk.choices[0].finish_reason is None:
            assert isinstance(chunk.choices[0].delta.content, str)
        else:
            assert chunk.choices[0].finish_reason == "stop"


@pytest.mark.asyncio
async def test_simple_completion_async_streaming():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = await litellm.acompletion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,
    )

    async for chunk in resp:
        print(chunk)
        if chunk.choices[0].finish_reason is None:
            assert isinstance(chunk.choices[0].delta.content, str)
        else:
            assert chunk.choices[0].finish_reason == "stop"
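The new test file above reduces to a small registration pattern. A condensed, hedged sketch of that flow follows; it is not part of the diff. The subclass name MyTinyLLM and its *args/**kwargs signature are simplifications for illustration, while the "custom_llm" provider key, custom_provider_map, custom_llm_setup(), and the mock_response trick come straight from the file.

# Condensed sketch of the custom-provider flow exercised above; not part of the diff.
import litellm
from litellm import CustomLLM, completion, get_llm_provider
from litellm.utils import custom_llm_setup


class MyTinyLLM(CustomLLM):  # hypothetical name, simplified signature
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore


litellm.custom_provider_map = [
    {"provider": "custom_llm", "custom_handler": MyTinyLLM()}
]
custom_llm_setup()  # registers the handler, as in test_get_llm_provider above

# Route resolution, then a mocked completion through the registered handler.
_, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")
assert provider == "custom_llm"

resp = completion(
    model="custom_llm/my-fake-model",
    messages=[{"role": "user", "content": "Hello world!"}],
)
assert resp.choices[0].message.content == "Hi!"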
@@ -206,6 +206,9 @@ def test_openai_azure_embedding_with_oidc_and_cf():
    os.environ["AZURE_TENANT_ID"] = "17c0a27a-1246-4aa1-a3b6-d294e80e783c"
    os.environ["AZURE_CLIENT_ID"] = "4faf5422-b2bd-45e8-a6d7-46543a38acd0"

    old_key = os.environ["AZURE_API_KEY"]
    os.environ.pop("AZURE_API_KEY", None)

    try:
        response = embedding(
            model="azure/text-embedding-ada-002",

@@ -218,6 +221,8 @@ def test_openai_azure_embedding_with_oidc_and_cf():

    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
    finally:
        os.environ["AZURE_API_KEY"] = old_key


def test_openai_azure_embedding_optional_arg(mocker):
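The two hunks above add a save/pop/restore guard around AZURE_API_KEY so the OIDC path is tested without the key in the environment. Stripped to its skeleton (illustrative only, the embedding call is elided):

# Skeleton of the env-var guard added above; not part of the diff.
import os

old_key = os.environ["AZURE_API_KEY"]  # remember the real key
os.environ.pop("AZURE_API_KEY", None)  # run the test without it in the env
try:
    ...  # call embedding(...) relying on OIDC instead of the key
finally:
    os.environ["AZURE_API_KEY"] = old_key  # always restore, even on failure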
@@ -673,17 +678,3 @@ async def test_databricks_embeddings(sync_mode):
    # print(response)


# local_proxy_embeddings()


def test_embedding_azure_ad_token():
    # this tests if we can pass api_key to completion, when it's not in the env.
    # DO NOT REMOVE THIS TEST. No MATTER WHAT Happens!
    # If you want to remove it, speak to Ishaan!
    # Ishaan will be very disappointed if this test is removed -> this is a standard way to pass api_key + the router + proxy use this

    response = embedding(
        model="azure/azure-embedding-model",
        input=["good morning from litellm"],
        azure_ad_token=os.getenv("AZURE_API_KEY"),
    )
    print(response)
@@ -1,14 +1,18 @@
import sys, os
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()
import os, io
import io
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest

import litellm
@@ -21,6 +25,12 @@ def test_get_llm_provider():
# test_get_llm_provider()


def test_get_llm_provider_gpt_instruct():
    _, response, _, _ = litellm.get_llm_provider(model="gpt-3.5-turbo-instruct-0914")

    assert response == "text-completion-openai"


def test_get_llm_provider_mistral_custom_api_base():
    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
        model="mistral/mistral-large-fr",
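As the hunk above shows, litellm.get_llm_provider returns a 4-tuple of (model, custom_llm_provider, dynamic_api_key, api_base). A minimal sketch, with the expected provider value taken from the new test:

# Sketch only; not part of the diff.
import litellm

model, provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="gpt-3.5-turbo-instruct-0914"
)
assert provider == "text-completion-openai"  # expectation copied from the test above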
@@ -3840,7 +3840,26 @@ def test_completion_chatgpt_prompt():
    try:
        print("\n gpt3.5 test\n")
        response = text_completion(
            model="gpt-3.5-turbo", prompt="What's the weather in SF?"
            model="openai/gpt-3.5-turbo", prompt="What's the weather in SF?"
        )
        print(response)
        response_str = response["choices"][0]["text"]
        print("\n", response.choices)
        print("\n", response.choices[0])
        # print(response.choices[0].text)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_completion_chatgpt_prompt()


def test_completion_gpt_instruct():
    try:
        response = text_completion(
            model="gpt-3.5-turbo-instruct-0914",
            prompt="What's the weather in SF?",
            custom_llm_provider="openai",
        )
        print(response)
        response_str = response["choices"][0]["text"]
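The edit above swaps the bare model name for a provider-prefixed one ("openai/gpt-3.5-turbo"), making the routing explicit. A minimal sketch of the resulting call, not part of the diff (assumes OPENAI_API_KEY is set in the environment):

# Sketch only; not part of the diff. Assumes OPENAI_API_KEY is set.
from litellm import text_completion

response = text_completion(
    model="openai/gpt-3.5-turbo", prompt="What's the weather in SF?"
)
print(response["choices"][0]["text"])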