Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

Merge branch 'main' into empower-functions-v1
Commit cb025a7f26: 526 changed files with 258624 additions and 22958 deletions
@@ -1,21 +1,29 @@
-import sys, os, json
+import json
+import os
+import sys
 import traceback
+
 from dotenv import load_dotenv

 load_dotenv()
-import os, io
+import io
+import os

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import pytest
-import litellm
-from litellm import embedding, completion, completion_cost, Timeout
-from litellm import RateLimitError
-from litellm.llms.prompt_templates.factory import anthropic_messages_pt
-from unittest.mock import patch, MagicMock
+import os
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+import litellm
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.prompt_templates.factory import anthropic_messages_pt

-# litellm.num_retries=3
+# litellm.num_retries = 3
 litellm.cache = None
 litellm.success_callback = []
 user_message = "Write a short poem about the sky"
@@ -38,7 +46,7 @@ def reset_callbacks():
 @pytest.mark.skip(reason="Local test")
 def test_response_model_none():
     """
-    Addresses:https://github.com/BerriAI/litellm/issues/2972
+    Addresses: https://github.com/BerriAI/litellm/issues/2972
     """
     x = completion(
         model="mymodel",
@@ -113,6 +121,27 @@ def test_null_role_response():
     assert response.choices[0].message.role == "assistant"


+def test_completion_azure_ai_command_r():
+    try:
+        import os
+
+        litellm.set_verbose = True
+
+        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
+        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
+
+        response: litellm.ModelResponse = completion(
+            model="azure_ai/command-r-plus",
+            messages=[{"role": "user", "content": "What is the meaning of life?"}],
+        )  # type: ignore
+
+        assert "azure_ai" in response.model
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_azure_command_r():
     try:
         litellm.set_verbose = True
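The new azure_ai test above reads the endpoint and key from AZURE_COHERE_* environment variables. Elsewhere in this diff the same credentials are passed per call instead; a minimal sketch of that per-call form (the endpoint URL and key below are placeholders, not real values):

# Minimal sketch: pass the Azure AI endpoint and key per call instead of via
# environment variables. The api_base and api_key values are placeholders.
import litellm

response = litellm.completion(
    model="azure_ai/command-r-plus",
    api_base="https://my-endpoint.westus2.inference.ai.azure.com",  # placeholder
    api_key="my-azure-ai-key",  # placeholder
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
)
print(response.model)  # expected to contain "azure_ai"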
@@ -152,29 +181,63 @@ async def test_completion_databricks(sync_mode):
         response_format_tests(response=response)


-# @pytest.mark.skip(reason="local test")
-@pytest.mark.parametrize("sync_mode", [True, False])
+def predibase_mock_post(url, data=None, json=None, headers=None):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "generated_text": " Is it to find happiness, to achieve success,",
+        "details": {
+            "finish_reason": "length",
+            "prompt_tokens": 8,
+            "generated_tokens": 10,
+            "seed": None,
+            "prefill": [],
+            "tokens": [
+                {"id": 2209, "text": " Is", "logprob": -1.7568359, "special": False},
+                {"id": 433, "text": " it", "logprob": -0.2220459, "special": False},
+                {"id": 311, "text": " to", "logprob": -0.6928711, "special": False},
+                {"id": 1505, "text": " find", "logprob": -0.6425781, "special": False},
+                {
+                    "id": 23871,
+                    "text": " happiness",
+                    "logprob": -0.07519531,
+                    "special": False,
+                },
+                {"id": 11, "text": ",", "logprob": -0.07110596, "special": False},
+                {"id": 311, "text": " to", "logprob": -0.79296875, "special": False},
+                {
+                    "id": 11322,
+                    "text": " achieve",
+                    "logprob": -0.7602539,
+                    "special": False,
+                },
+                {
+                    "id": 2450,
+                    "text": " success",
+                    "logprob": -0.03656006,
+                    "special": False,
+                },
+                {"id": 11, "text": ",", "logprob": -0.0011510849, "special": False},
+            ],
+        },
+    }
+    return mock_response
+
+
+# @pytest.mark.skip(reason="local only test")
 @pytest.mark.asyncio
-async def test_completion_predibase(sync_mode):
+async def test_completion_predibase():
     try:
         litellm.set_verbose = True

-        if sync_mode:
+        with patch("requests.post", side_effect=predibase_mock_post):
             response = completion(
                 model="predibase/llama-3-8b-instruct",
                 tenant_id="c4768f95",
                 api_key=os.getenv("PREDIBASE_API_KEY"),
                 messages=[{"role": "user", "content": "What is the meaning of life?"}],
             )

             print(response)
-        else:
-            response = await litellm.acompletion(
-                model="predibase/llama-3-8b-instruct",
-                tenant_id="c4768f95",
-                api_base="https://serving.app.predibase.com",
-                api_key=os.getenv("PREDIBASE_API_KEY"),
-                messages=[{"role": "user", "content": "What is the meaning of life?"}],
-                max_tokens=10,
-            )
-
-            print(response)
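predibase_mock_post above stands in for requests.post so the Predibase test can run without network access. A minimal sketch of the same patching technique, independent of LiteLLM:

# Minimal sketch of the patching technique used above: replace requests.post
# with a MagicMock-backed fake so no real HTTP call is made.
from unittest.mock import MagicMock, patch

import requests


def fake_post(url, data=None, json=None, headers=None):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {"generated_text": " Is it to find happiness,"}
    return mock_response


with patch("requests.post", side_effect=fake_post):
    resp = requests.post("https://serving.app.predibase.com/generate", json={})
    assert resp.status_code == 200
    print(resp.json()["generated_text"])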
@@ -293,7 +356,11 @@ def test_completion_claude_3():
         pytest.fail(f"Error occurred: {e}")


-def test_completion_claude_3_function_call():
+@pytest.mark.parametrize(
+    "model",
+    ["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"],
+)
+def test_completion_claude_3_function_call(model):
     litellm.set_verbose = True
     tools = [
         {
@@ -324,16 +391,17 @@ def test_completion_claude_3_function_call():
     try:
         # test without max tokens
         response = completion(
-            model="anthropic/claude-3-opus-20240229",
+            model=model,
             messages=messages,
             tools=tools,
             tool_choice={
                 "type": "function",
                 "function": {"name": "get_current_weather"},
             },
+            drop_params=True,
         )

-        # Add any assertions, here to check response args
+        # Add any assertions here to check response args
         print(response)
         assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
         assert isinstance(
@@ -357,16 +425,114 @@ def test_completion_claude_3_function_call():
         )
         # In the second response, Claude should deduce answer from tool results
         second_response = completion(
-            model="anthropic/claude-3-opus-20240229",
+            model=model,
             messages=messages,
             tools=tools,
             tool_choice="auto",
+            drop_params=True,
         )
         print(second_response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.parametrize(
+    "model, api_key, api_base",
+    [
+        ("gpt-3.5-turbo", None, None),
+        ("claude-3-opus-20240229", None, None),
+        ("command-r", None, None),
+        ("anthropic.claude-3-sonnet-20240229-v1:0", None, None),
+        (
+            "azure_ai/command-r-plus",
+            os.getenv("AZURE_COHERE_API_KEY"),
+            os.getenv("AZURE_COHERE_API_BASE"),
+        ),
+    ],
+)
+@pytest.mark.asyncio
+async def test_model_function_invoke(model, sync_mode, api_key, api_base):
+    try:
+        litellm.set_verbose = True
+
+        messages = [
+            {
+                "role": "system",
+                "content": "Your name is Litellm Bot, you are a helpful assistant",
+            },
+            # User asks for their name and weather in San Francisco
+            {
+                "role": "user",
+                "content": "Hello, what is your name and can you tell me the weather?",
+            },
+            # Assistant replies with a tool call
+            {
+                "role": "assistant",
+                "content": "",
+                "tool_calls": [
+                    {
+                        "id": "call_123",
+                        "type": "function",
+                        "index": 0,
+                        "function": {
+                            "name": "get_weather",
+                            "arguments": '{"location": "San Francisco, CA"}',
+                        },
+                    }
+                ],
+            },
+            # The result of the tool call is added to the history
+            {
+                "role": "tool",
+                "tool_call_id": "call_123",
+                "content": "27 degrees celsius and clear in San Francisco, CA",
+            },
+            # Now the assistant can reply with the result of the tool call.
+        ]
+
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            }
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+
+        data = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+            "api_key": api_key,
+            "api_base": api_base,
+        }
+        if sync_mode:
+            response = litellm.completion(**data)
+        else:
+            response = await litellm.acompletion(**data)
+
+        print(f"response: {response}")
+    except litellm.RateLimitError as e:
+        pass
+    except Exception as e:
+        if "429 Quota exceeded" in str(e):
+            pass
+        else:
+            pytest.fail("An unexpected exception occurred - {}".format(str(e)))
+
+
 @pytest.mark.asyncio
 async def test_anthropic_no_content_error():
     """
@@ -437,6 +603,7 @@ async def test_anthropic_no_content_error():
 def test_gemini_completion_call_error():
     try:
         print("test completion + streaming")
+        litellm.num_retries = 3
         litellm.set_verbose = True
         messages = [{"role": "user", "content": "what is the capital of congo?"}]
         response = completion(
@@ -516,6 +683,7 @@ def test_completion_cohere_command_r_plus_function_call():
             messages=messages,
             tools=tools,
             tool_choice="auto",
+            force_single_step=True,
         )
         print(second_response)
     except Exception as e:
@@ -651,8 +819,10 @@ def test_completion_claude_3_base64():
         pytest.fail(f"An exception occurred - {str(e)}")


-@pytest.mark.skip(reason="issue getting wikipedia images in ci/cd")
-def test_completion_claude_3_function_plus_image():
+@pytest.mark.parametrize(
+    "model", ["gemini/gemini-1.5-flash"]  # "claude-3-sonnet-20240229",
+)
+def test_completion_function_plus_image(model):
     litellm.set_verbose = True

     image_content = [
@@ -660,7 +830,7 @@ def test_completion_claude_3_function_plus_image():
         {
             "type": "image_url",
             "image_url": {
-                "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+                "url": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
             },
         },
     ]
@@ -676,7 +846,7 @@ def test_completion_claude_3_function_plus_image():
                     "type": "object",
                     "properties": {
                         "location": {
-                            "type": "text",
+                            "type": "string",
                             "description": "The city and state, e.g. San Francisco, CA",
                         },
                         "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
@@ -696,7 +866,7 @@ def test_completion_claude_3_function_plus_image():
     ]

     response = completion(
-        model="claude-3-sonnet-20240229",
+        model=model,
         messages=[image_message],
         tool_choice=tool_choice,
         tools=tools,
@@ -706,7 +876,11 @@ def test_completion_claude_3_function_plus_image():
     print(response)


-def test_completion_azure_mistral_large_function_calling():
+@pytest.mark.parametrize(
+    "provider",
+    ["azure", "azure_ai"],
+)
+def test_completion_azure_mistral_large_function_calling(provider):
     """
     This primarily tests if the 'Function()' pydantic object correctly handles argument param passed in as a dict vs. string
     """
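The docstring above says the test checks that a 'Function()' pydantic object accepts tool-call arguments passed either as a dict or as a JSON string. A rough illustration of that coercion pattern with plain pydantic follows; this is not LiteLLM's actual Function class, only the idea the docstring describes:

# Illustrative only: a pydantic model that accepts `arguments` as either a
# dict or a JSON string and normalizes it to a string, similar in spirit to
# what the test above exercises.
import json
from typing import Union

from pydantic import BaseModel, field_validator


class Function(BaseModel):
    name: str
    arguments: str

    @field_validator("arguments", mode="before")
    @classmethod
    def coerce_arguments(cls, v: Union[str, dict]) -> str:
        return json.dumps(v) if isinstance(v, dict) else v


print(Function(name="get_current_weather", arguments={"location": "Boston, MA"}))
print(Function(name="get_current_weather", arguments='{"location": "Boston, MA"}'))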
@@ -737,8 +911,9 @@ def test_completion_azure_mistral_large_function_calling():
             "content": "What's the weather like in Boston today in Fahrenheit?",
         }
     ]
+
     response = completion(
-        model="azure/mistral-large-latest",
+        model="{}/mistral-large-latest".format(provider),
         api_base=os.getenv("AZURE_MISTRAL_API_BASE"),
         api_key=os.getenv("AZURE_MISTRAL_API_KEY"),
         messages=messages,
@@ -776,6 +951,34 @@ def test_completion_mistral_api():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio
+async def test_completion_codestral_chat_api():
+    try:
+        litellm.set_verbose = True
+        response = await litellm.acompletion(
+            model="codestral/codestral-latest",
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Hey, how's it going?",
+                }
+            ],
+            temperature=0.0,
+            top_p=1,
+            max_tokens=10,
+            safe_prompt=False,
+            seed=12,
+        )
+        # Add any assertions here to check the response
+        print(response)
+
+        # cost = litellm.completion_cost(completion_response=response)
+        # print("cost to make mistral completion=", cost)
+        # assert cost > 0.0
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_mistral_api_mistral_large_function_call():
     litellm.set_verbose = True
     tools = [
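The codestral test above is an asyncio coroutine driven by pytest.mark.asyncio; outside pytest the same call can be driven with asyncio.run. A minimal sketch, assuming the relevant provider API key is already set in the environment:

# Minimal sketch: run an async litellm.acompletion call outside pytest with
# asyncio.run (assumes the provider API key is configured in the env).
import asyncio

import litellm


async def main():
    response = await litellm.acompletion(
        model="codestral/codestral-latest",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        max_tokens=10,
    )
    print(response)


asyncio.run(main())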
@@ -1348,7 +1551,7 @@ def test_hf_test_completion_tgi():
             messages=[{"content": "Hello, how are you?", "role": "user"}],
             max_tokens=10,
         )
         # Add any assertions here to check the response
         print(response)
     except litellm.ServiceUnavailableError as e:
         pass
@@ -1358,6 +1561,43 @@ def test_hf_test_completion_tgi():

 # hf_test_completion_tgi()


+@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.asyncio
+async def test_openai_compatible_custom_api_base(provider):
+    litellm.set_verbose = True
+    messages = [
+        {
+            "role": "user",
+            "content": "Hello world",
+        }
+    ]
+    from openai import OpenAI
+
+    openai_client = OpenAI(api_key="fake-key")
+
+    with patch.object(
+        openai_client.chat.completions, "create", new=MagicMock()
+    ) as mock_call:
+        try:
+            response = completion(
+                model="openai/my-vllm-model",
+                messages=messages,
+                response_format={"type": "json_object"},
+                client=openai_client,
+                api_base="my-custom-api-base",
+                hello="world",
+            )
+        except Exception as e:
+            pass
+
+        mock_call.assert_called_once()
+
+        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
+
+        assert "hello" in mock_call.call_args.kwargs["extra_body"]
+
+
 # ################### Hugging Face Conversational models ########################
 # def hf_test_completion_conv():
 #     try:
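The assertion in the hunk above checks that an unknown kwarg (hello="world") reaches the patched OpenAI client inside extra_body. The same call-inspection technique, sketched without LiteLLM:

# Minimal sketch of the call-inspection technique used above: patch a client
# method with MagicMock, invoke the code under test, then inspect
# mock_call.call_args.kwargs for the forwarded parameters.
from unittest.mock import MagicMock, patch


class FakeCompletions:
    def create(self, **kwargs):
        raise NotImplementedError


client = FakeCompletions()

with patch.object(client, "create", new=MagicMock()) as mock_call:
    client.create(model="my-vllm-model", extra_body={"hello": "world"})

mock_call.assert_called_once()
print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
assert "hello" in mock_call.call_args.kwargs["extra_body"]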
@@ -1390,7 +1630,6 @@ def test_hf_test_completion_tgi():


 def mock_post(url, data=None, json=None, headers=None):
-
     print(f"url={url}")
     if "text-classification" in url:
         raise Exception("Model not found")
@@ -1432,7 +1671,9 @@ def test_ollama_image():
     data is untouched.
     """

-    import io, base64
+    import base64
+    import io
+
     from PIL import Image

     def mock_post(url, **kwargs):
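test_ollama_image's docstring notes that already base64-encoded image data should pass through untouched; the hunk above only re-sorts its base64/io/PIL imports. A minimal sketch of producing that kind of payload:

# Minimal sketch: build a tiny in-memory PNG with PIL and base64-encode it,
# the kind of image payload the ollama test expects to pass through unchanged.
import base64
import io

from PIL import Image

img = Image.new("RGB", (8, 8), color="red")
buffer = io.BytesIO()
img.save(buffer, format="PNG")
image_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
print(image_b64[:32], "...")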
@@ -2152,8 +2393,10 @@ def test_completion_azure_key_completion_arg():
             model="azure/chatgpt-v-2",
             messages=messages,
             api_key=old_key,
+            logprobs=True,
+            max_tokens=10,
         )

         print(f"response: {response}")

         print("Hidden Params", response._hidden_params)
@@ -2212,12 +2455,6 @@ async def test_re_use_azure_async_client():
         pytest.fail("got Exception", e)


-# import asyncio
-# asyncio.run(
-#     test_re_use_azure_async_client()
-# )
-
-
 def test_re_use_openaiClient():
     try:
         print("gpt-3.5 with client test\n\n")
@@ -2237,9 +2474,6 @@ def test_re_use_openaiClient():
         pytest.fail("got Exception", e)


-# test_re_use_openaiClient()
-
-
 def test_completion_azure():
     try:
         print("azure gpt-3.5 test\n\n")
@@ -2251,7 +2485,7 @@ def test_completion_azure():
             api_key="os.environ/AZURE_API_KEY",
         )
         print(f"response: {response}")
-        ## Test azure flag for backwards compat
+        ## Test azure flag for backwards-compat
         # response = completion(
         #     model="chatgpt-v-2",
         #     messages=messages,
@@ -2471,6 +2705,8 @@ async def test_completion_replicate_llama3(sync_mode):
         # Add any assertions here to check the response
         assert isinstance(response, litellm.ModelResponse)
         response_format_tests(response=response)
+    except litellm.APIError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -2535,6 +2771,7 @@ def test_replicate_custom_prompt_dict():
                 "content": "what is yc write 1 paragraph",
             }
         ],
+        mock_response="Hello world",
         repetition_penalty=0.1,
         num_retries=3,
     )
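The replicate hunk above passes mock_response="Hello world", which returns a canned response instead of calling the provider. A minimal sketch of the same idea, assuming the mock_response kwarg behaves as it does in the test above:

# Minimal sketch: mock_response short-circuits the provider call and returns
# a canned response, handy for offline tests.
import litellm

resp = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what is yc write 1 paragraph"}],
    mock_response="Hello world",
)
print(resp.choices[0].message.content)  # "Hello world"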
@@ -3266,6 +3503,7 @@ def test_mistral_anyscale_stream():


 #### Test A121 ###################
+@pytest.mark.skip(reason="Local test")
 def test_completion_ai21():
     print("running ai21 j2light test")
     litellm.set_verbose = True
@@ -3283,17 +3521,54 @@ def test_completion_ai21():
 # test_completion_ai21()
 # test_completion_ai21()
 ## test deep infra
-def test_completion_deep_infra():
+@pytest.mark.parametrize("drop_params", [True, False])
+def test_completion_deep_infra(drop_params):
     litellm.set_verbose = False
     model_name = "deepinfra/meta-llama/Llama-2-70b-chat-hf"
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [
+        {
+            "role": "user",
+            "content": "What's the weather like in Boston today in Fahrenheit?",
+        }
+    ]
     try:
         response = completion(
-            model=model_name, messages=messages, temperature=0, max_tokens=10
+            model=model_name,
+            messages=messages,
+            temperature=0,
+            max_tokens=10,
+            tools=tools,
+            tool_choice={
+                "type": "function",
+                "function": {"name": "get_current_weather"},
+            },
+            drop_params=drop_params,
         )
         # Add any assertions here to check the response
         print(response)
     except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+        if drop_params is True:
+            pytest.fail(f"Error occurred: {e}")


 # test_completion_deep_infra()
@@ -3320,13 +3595,93 @@ def test_completion_deep_infra_mistral():
 # test_completion_deep_infra_mistral()


-# Gemini tests
-def test_completion_gemini():
+@pytest.mark.skip(reason="Local test - don't have a volcengine account as yet")
+def test_completion_volcengine():
     litellm.set_verbose = True
-    model_name = "gemini/gemini-1.5-pro-latest"
-    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    model_name = "volcengine/<OUR_ENDPOINT_ID>"
     try:
-        response = completion(model=model_name, messages=messages)
+        response = completion(
+            model=model_name,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What's the weather like in Boston today in Fahrenheit?",
+                }
+            ],
+            api_key="<OUR_API_KEY>",
+        )
+        # Add any assertions here to check the response
+        print(response)
+
+    except litellm.exceptions.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_completion_nvidia_nim():
+    model_name = "nvidia_nim/databricks/dbrx-instruct"
+    try:
+        response = completion(
+            model=model_name,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What's the weather like in Boston today in Fahrenheit?",
+                }
+            ],
+            presence_penalty=0.5,
+            frequency_penalty=0.1,
+        )
+        # Add any assertions here to check the response
+        print(response)
+        assert response.choices[0].message.content is not None
+        assert len(response.choices[0].message.content) > 0
+    except litellm.exceptions.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+# Gemini tests
+@pytest.mark.parametrize(
+    "model",
+    [
+        # "gemini-1.0-pro",
+        "gemini-1.5-pro",
+        # "gemini-1.5-flash",
+    ],
+)
+def test_completion_gemini(model):
+    litellm.set_verbose = True
+    model_name = "gemini/{}".format(model)
+    messages = [
+        {"role": "system", "content": "Be a good bot!"},
+        {"role": "user", "content": "Hey, how's it going?"},
+    ]
+    try:
+        response = completion(
+            model=model_name,
+            messages=messages,
+            safety_settings=[
+                {
+                    "category": "HARM_CATEGORY_HARASSMENT",
+                    "threshold": "BLOCK_NONE",
+                },
+                {
+                    "category": "HARM_CATEGORY_HATE_SPEECH",
+                    "threshold": "BLOCK_NONE",
+                },
+                {
+                    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                    "threshold": "BLOCK_NONE",
+                },
+                {
+                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                    "threshold": "BLOCK_NONE",
+                },
+            ],
+        )
         # Add any assertions here to check the response
         print(response)
         assert response.choices[0]["index"] == 0
@@ -3416,6 +3771,7 @@ def test_completion_palm_stream():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.skip(reason="Account deleted by IBM.")
 def test_completion_watsonx():
     litellm.set_verbose = True
     model_name = "watsonx/ibm/granite-13b-chat-v2"
@@ -3436,6 +3792,7 @@ def test_completion_watsonx():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.skip(reason="Skip test. account deleted.")
 def test_completion_stream_watsonx():
     litellm.set_verbose = True
     model_name = "watsonx/ibm/granite-13b-chat-v2"
@@ -3503,6 +3860,7 @@ def test_unified_auth_params(provider, model, project, region_name, token):
         assert value in translated_optional_params


+@pytest.mark.skip(reason="Local test")
 @pytest.mark.asyncio
 async def test_acompletion_watsonx():
     litellm.set_verbose = True
@@ -3523,6 +3881,7 @@ async def test_acompletion_watsonx():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.skip(reason="Local test")
 @pytest.mark.asyncio
 async def test_acompletion_stream_watsonx():
     litellm.set_verbose = True