Forked from phoenix/litellm-mirror

fix(tests): fixing response objects for testing

Parent: 9776126c8d
Commit: 8a3b771e50

6 changed files with 188 additions and 104 deletions
@@ -125,11 +125,18 @@ def completion(
            data=json.dumps(data),
            stream=optional_params["stream"],
        )

        if response.status_code != 200:
            raise AnthropicError(status_code=response.status_code, message=response.text)

        return response.iter_lines()
    else:
        response = requests.post(
            api_base, headers=headers, data=json.dumps(data)
        )
        if response.status_code != 200:
            raise AnthropicError(status_code=response.status_code, message=response.text)

        ## LOGGING
        logging_obj.post_call(
            input=prompt,
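Note: in the streaming branch above, the handler returns `response.iter_lines()` directly instead of a parsed model response, so the caller is expected to iterate the raw event lines. A minimal sketch of consuming such an iterator, assuming an SSE-style `data: {...}` payload with a `completion` field (the parsing below is illustrative, not litellm's own stream wrapper):

```python
import json

def read_stream(lines):
    # `lines` is what the streaming branch returns: response.iter_lines()
    for raw in lines:
        if not raw:
            continue  # skip keep-alive blank lines
        text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        if text.startswith("data:"):
            try:
                event = json.loads(text[len("data:"):].strip())
            except json.JSONDecodeError:
                continue  # ignore non-JSON control lines
            yield event.get("completion", "")
```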
@@ -56,7 +56,7 @@ from .llms.azure import AzureChatCompletion
from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
import tiktoken
from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, List, Optional, Dict, Union
+from typing import Callable, List, Optional, Dict, Union, Mapping

encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
@@ -79,6 +79,35 @@ openai_text_completions = OpenAITextCompletion()
azure_chat_completions = AzureChatCompletion()
####### COMPLETION ENDPOINTS ################

+class LiteLLM:
+
+    def __init__(self, *,
+                 api_key=None,
+                 organization: str | None = None,
+                 base_url: str = None,
+                 timeout: Union[float, None] = 600,
+                 max_retries: int | None = litellm.num_retries,
+                 default_headers: Mapping[str, str] | None = None,):
+        self.params = locals()
+        self.chat = Chat(self.params)
+
+class Chat():
+
+    def __init__(self, params):
+        self.params = params
+        self.completions = Completions(self.params)
+
+class Completions():
+
+    def __init__(self, params):
+        self.params = params
+
+    def create(self, model, messages, **kwargs):
+        for k, v in kwargs.items():
+            self.params[k] = v
+        response = completion(model=model, messages=messages, **self.params)
+        return response
+
async def acompletion(*args, **kwargs):
    """
    Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
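The new `LiteLLM` / `Chat` / `Completions` wrapper mirrors an OpenAI-style client around `litellm.completion()`: constructor kwargs are stored in `self.params` and forwarded by `Completions.create()`. A minimal usage sketch, assuming `litellm` is importable and an API key such as `OPENAI_API_KEY` is set in the environment (model and message content are illustrative):

```python
from litellm import LiteLLM

# Client-style wrapper; stored params are merged with per-call kwargs
# and passed straight through to litellm.completion().
client = LiteLLM(timeout=600)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello, which model are you?"}],
    max_tokens=10,  # extra kwargs are added to the stored params
)
print(response)
```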
@@ -240,7 +269,7 @@ def completion(
    deployment_id = None,

    # set api_base, api_version, api_key
-    api_base: Optional[str] = None,
+    base_url: Optional[str] = None,
    api_version: Optional[str] = None,
    api_key: Optional[str] = None,
    model_list: Optional[list] = None, # pass in a list of api_base,keys, etc.
@@ -288,6 +317,7 @@ def completion(
    """
    ######### unpacking kwargs #####################
    args = locals()
+    api_base = kwargs.get('api_base', None)
    return_async = kwargs.get('return_async', False)
    mock_response = kwargs.get('mock_response', None)
    force_timeout= kwargs.get('force_timeout', 600)
@@ -299,7 +329,8 @@ def completion(
    metadata = kwargs.get('metadata', None)
    fallbacks = kwargs.get('fallbacks', None)
    headers = kwargs.get("headers", None)
-    num_retries = kwargs.get("num_retries", None)
+    num_retries = kwargs.get("num_retries", None) ## deprecated
+    max_retries = kwargs.get("max_retries", None)
    context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
    ### CUSTOM PROMPT TEMPLATE ###
    initial_prompt_value = kwargs.get("intial_prompt_value", None)
@@ -309,13 +340,17 @@ def completion(
    eos_token = kwargs.get("eos_token", None)
    acompletion = kwargs.get("acompletion", False)
    ######## end of unpacking kwargs ###########
-    openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]
+    openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "max_retries"]
    default_params = openai_params + litellm_params
    non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
    if mock_response:
        return mock_completion(model, messages, stream=stream, mock_response=mock_response)
    try:
+        if base_url:
+            api_base = base_url
+        if max_retries:
+            num_retries = max_retries
        logging = litellm_logging_obj
        fallbacks = (
            fallbacks
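The `default_params` lists above act as an allow-list: any kwarg that is neither a recognised OpenAI parameter nor a litellm control parameter is treated as provider-specific and forwarded untouched. A small standalone sketch of that filtering step (the lists are shortened here for readability):

```python
# Illustrative only: trimmed-down version of the allow-list filtering
# done inside completion(); the real lists are much longer.
openai_params = ["temperature", "top_p", "max_tokens", "stream"]
litellm_params = ["metadata", "mock_response", "num_retries", "max_retries"]
default_params = openai_params + litellm_params

kwargs = {"temperature": 0.2, "metadata": {"run": 1}, "top_k": 40}

# Anything not recognised is assumed to be a model/provider-specific param.
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
assert non_default_params == {"top_k": 40}
```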
@@ -648,8 +683,11 @@ def completion(
            response = model_response

        elif custom_llm_provider=="anthropic":
-            anthropic_key = (
-                api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY") or litellm.api_key
+            api_key = (
+                api_key
+                or litellm.anthropic_key
+                or litellm.api_key
+                or os.environ.get("ANTHROPIC_API_KEY")
            )
            api_base = (
                api_base
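The rewritten block also changes the lookup order: the per-call `api_key` wins, then `litellm.anthropic_key`, then the module-level `litellm.api_key`, and only then the `ANTHROPIC_API_KEY` environment variable. A small standalone sketch of that precedence (the helper itself is illustrative; only the `or` chain mirrors the diff):

```python
import os

def resolve_anthropic_key(api_key=None, anthropic_key=None, global_key=None):
    # First truthy value wins, mirroring the `or` chain in the diff.
    return (
        api_key
        or anthropic_key
        or global_key
        or os.environ.get("ANTHROPIC_API_KEY")
    )

# The per-call key overrides module-level and environment values.
assert resolve_anthropic_key(api_key="sk-call", anthropic_key="sk-mod") == "sk-call"
```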
@@ -672,7 +710,7 @@ def completion(
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                encoding=encoding, # for calculating input/output tokens
-                api_key=anthropic_key,
+                api_key=api_key,
                logging_obj=logging,
            )
            if "stream" in optional_params and optional_params["stream"] == True:
litellm/tests/test_class.py (new file, 34 lines)
@@ -0,0 +1,34 @@
# #### What this tests ####
# # This tests the LiteLLM Class

# import sys, os
# import traceback
# import pytest
# sys.path.insert(
#     0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import litellm
# litellm.set_verbose = True
# from litellm import LiteLLM
# import instructor
# from pydantic import BaseModel

# # This enables response_model keyword
# # from client.chat.completions.create
# client = instructor.patch(LiteLLM())

# class UserDetail(BaseModel):
#     name: str
#     age: int

# user = client.chat.completions.create(
#     model="gpt-3.5-turbo",
#     response_model=UserDetail,
#     messages=[
#         {"role": "user", "content": "Extract Jason is 25 years old"},
#     ]
# )

# assert isinstance(user, UserDetail)
# assert user.name == "Jason"
# assert user.age == 25
@@ -1,107 +1,107 @@
import os
import sys, os
import traceback
from dotenv import load_dotenv
# import os
# import sys, os
# import traceback
# from dotenv import load_dotenv

load_dotenv()
import os, io
# load_dotenv()
# import os, io

sys.path.insert(
    0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, text_completion, completion_cost
# sys.path.insert(
#     0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import pytest
# import litellm
# from litellm import embedding, completion, text_completion, completion_cost

from langchain.chat_models import ChatLiteLLM
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import AIMessage, HumanMessage, SystemMessage
# from langchain.chat_models import ChatLiteLLM
# from langchain.prompts.chat import (
#     ChatPromptTemplate,
#     SystemMessagePromptTemplate,
#     AIMessagePromptTemplate,
#     HumanMessagePromptTemplate,
# )
# from langchain.schema import AIMessage, HumanMessage, SystemMessage

def test_chat_gpt():
    try:
        chat = ChatLiteLLM(model="gpt-3.5-turbo", max_tokens=10)
        messages = [
            HumanMessage(
                content="what model are you"
            )
        ]
        resp = chat(messages)

        print(resp)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

# test_chat_gpt()


def test_claude():
    try:
        chat = ChatLiteLLM(model="claude-2", max_tokens=10)
        messages = [
            HumanMessage(
                content="what model are you"
            )
        ]
        resp = chat(messages)

        print(resp)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

# test_claude()

def test_palm():
    try:
        chat = ChatLiteLLM(model="palm/chat-bison", max_tokens=10)
        messages = [
            HumanMessage(
                content="what model are you"
            )
        ]
        resp = chat(messages)

        print(resp)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

# test_palm()


# def test_openai_with_params():
# def test_chat_gpt():
#     try:
#         api_key = os.environ["OPENAI_API_KEY"]
#         os.environ.pop("OPENAI_API_KEY")
#         print("testing openai with params")
#         llm = ChatLiteLLM(
#             model="gpt-3.5-turbo",
#             openai_api_key=api_key,
#             # Prefer using None which is the default value, endpoint could be empty string
#             openai_api_base= None,
#             max_tokens=20,
#             temperature=0.5,
#             request_timeout=10,
#             model_kwargs={
#                 "frequency_penalty": 0,
#                 "presence_penalty": 0,
#             },
#             verbose=True,
#             max_retries=0,
#         )
#         chat = ChatLiteLLM(model="gpt-3.5-turbo", max_tokens=10)
#         messages = [
#             HumanMessage(
#                 content="what model are you"
#             )
#         ]
#         resp = llm(messages)
#         resp = chat(messages)

#         print(resp)
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")

# test_openai_with_params()
# # test_chat_gpt()


# def test_claude():
#     try:
#         chat = ChatLiteLLM(model="claude-2", max_tokens=10)
#         messages = [
#             HumanMessage(
#                 content="what model are you"
#             )
#         ]
#         resp = chat(messages)

#         print(resp)
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")

# # test_claude()

# def test_palm():
#     try:
#         chat = ChatLiteLLM(model="palm/chat-bison", max_tokens=10)
#         messages = [
#             HumanMessage(
#                 content="what model are you"
#             )
#         ]
#         resp = chat(messages)

#         print(resp)
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")

# # test_palm()


# # def test_openai_with_params():
# #     try:
# #         api_key = os.environ["OPENAI_API_KEY"]
# #         os.environ.pop("OPENAI_API_KEY")
# #         print("testing openai with params")
# #         llm = ChatLiteLLM(
# #             model="gpt-3.5-turbo",
# #             openai_api_key=api_key,
# #             # Prefer using None which is the default value, endpoint could be empty string
# #             openai_api_base= None,
# #             max_tokens=20,
# #             temperature=0.5,
# #             request_timeout=10,
# #             model_kwargs={
# #                 "frequency_penalty": 0,
# #                 "presence_penalty": 0,
# #             },
# #             verbose=True,
# #             max_retries=0,
# #         )
# #         messages = [
# #             HumanMessage(
# #                 content="what model are you"
# #             )
# #         ]
# #         resp = llm(messages)

# #         print(resp)
# #     except Exception as e:
# #         pytest.fail(f"Error occurred: {e}")

# # test_openai_with_params()
@@ -166,7 +166,7 @@ def test_completion_cohere_stream():
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

-test_completion_cohere_stream()
+# test_completion_cohere_stream()

def test_completion_cohere_stream_bad_key():
    try:
@@ -464,6 +464,7 @@ def test_completion_palm_stream():
def test_completion_claude_stream_bad_key():
    try:
        litellm.cache = None
        litellm.set_verbose = True
        api_key = "bad-key"
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
@@ -492,7 +493,7 @@ def test_completion_claude_stream_bad_key():
        pytest.fail(f"Error occurred: {e}")


-# test_completion_claude_stream_bad_key()
+test_completion_claude_stream_bad_key()
# test_completion_replicate_stream()

# def test_completion_vertexai_stream():
@@ -53,7 +53,7 @@ from .exceptions import (
    APIError,
    BudgetExceededError
)
-from typing import cast, List, Dict, Union, Optional, Literal
+from typing import cast, List, Dict, Union, Optional, Literal, TypedDict, Required
from .caching import Cache

####### ENVIRONMENT VARIABLES ####################
@@ -118,6 +118,10 @@ def map_finish_reason(finish_reason: str): # openai supports 5 stop sequences -
        return "stop"
    return finish_reason

+class FunctionCall(OpenAIObject):
+    arguments: str
+    name: str
+
class Message(OpenAIObject):
    def __init__(self, content="default", role="assistant", logprobs=None, function_call=None, **params):
        super(Message, self).__init__(**params)
@@ -125,7 +129,7 @@ class Message(OpenAIObject):
        self.role = role
        self._logprobs = logprobs
        if function_call:
-            self.function_call = function_call
+            self.function_call = FunctionCall(**function_call)

    def get(self, key, default=None):
        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
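With the new `FunctionCall` model, a raw `function_call` dict on a message is parsed into a typed object rather than stored as a plain dict, so tests can assert on `.name` and `.arguments`. A minimal standalone sketch of the same idea, using pydantic directly since litellm's `OpenAIObject` base is assumed here to behave like a pydantic model:

```python
from pydantic import BaseModel

class FunctionCall(BaseModel):
    arguments: str
    name: str

# Previously the raw dict was attached as-is; parsing it into a model
# gives attribute access and fails loudly on malformed payloads.
raw = {"name": "get_current_weather", "arguments": '{"location": "Boston"}'}
fc = FunctionCall(**raw)
assert fc.name == "get_current_weather"
```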
@@ -922,7 +926,7 @@ class Logging:
                    callback.log_failure_event(
                        start_time=start_time,
                        end_time=end_time,
                        messages=self.messages,
                        response_obj=result,
                        kwargs=self.model_call_details,
                    )
                except Exception as e: