fix(tests): fixing response objects for testing

Krrish Dholakia 2023-11-13 14:38:41 -08:00
parent 9776126c8d
commit 8a3b771e50
6 changed files with 188 additions and 104 deletions

View file

@@ -125,11 +125,18 @@ def completion(
data=json.dumps(data),
stream=optional_params["stream"],
)
if response.status_code != 200:
raise AnthropicError(status_code=response.status_code, message=response.text)
return response.iter_lines()
else:
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
if response.status_code != 200:
raise AnthropicError(status_code=response.status_code, message=response.text)
## LOGGING
logging_obj.post_call(
input=prompt,
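Both the streaming and non-streaming branches now raise AnthropicError on a non-200 response before the response is consumed or logged. A hedged sketch of a test that exercises this path through the public completion() API; litellm maps provider errors onto its own exception types, so the exact exception class is deliberately not assumed here:

import litellm
import pytest

def test_claude_bad_key_raises():
    # an invalid key should surface as a raised exception rather than a silent bad response
    with pytest.raises(Exception):
        litellm.completion(
            model="claude-2",
            messages=[{"role": "user", "content": "Hi"}],
            api_key="bad-key",
        )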

View file

@@ -56,7 +56,7 @@ from .llms.azure import AzureChatCompletion
from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
import tiktoken
from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, List, Optional, Dict, Union
+from typing import Callable, List, Optional, Dict, Union, Mapping
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
@@ -79,6 +79,35 @@ openai_text_completions = OpenAITextCompletion()
azure_chat_completions = AzureChatCompletion()
####### COMPLETION ENDPOINTS ################
class LiteLLM:
def __init__(self, *,
api_key=None,
organization: str | None = None,
base_url: str = None,
timeout: Union[float, None] = 600,
max_retries: int | None = litellm.num_retries,
default_headers: Mapping[str, str] | None = None,):
self.params = locals()
self.chat = Chat(self.params)
class Chat():
def __init__(self, params):
self.params = params
self.completions = Completions(self.params)
class Completions():
def __init__(self, params):
self.params = params
def create(self, model, messages, **kwargs):
for k, v in kwargs.items():
self.params[k] = v
response = completion(model=model, messages=messages, **self.params)
return response
async def acompletion(*args, **kwargs):
"""
Asynchronously executes a litellm.completion() call for any of litellm's supported LLMs (e.g. gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
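The new LiteLLM class (with the Chat and Completions helpers behind it) gives litellm an OpenAI-SDK-style client surface: client.chat.completions.create(...) simply forwards its arguments to completion(). A minimal usage sketch, assuming OPENAI_API_KEY is set in the environment; the model and message are illustrative:

from litellm import LiteLLM

# constructor also accepts api_key, organization, base_url, timeout, max_retries, default_headers
client = LiteLLM()
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello, what model are you?"}],
)
print(response)

This is the same surface the commented-out instructor test further down relies on via instructor.patch(LiteLLM()).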
@@ -240,7 +269,7 @@ def completion(
deployment_id = None,
# set api_base, api_version, api_key
api_base: Optional[str] = None,
base_url: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
model_list: Optional[list] = None, # pass in a list of api_base, keys, etc.
@@ -288,6 +317,7 @@ def completion(
"""
######### unpacking kwargs #####################
args = locals()
api_base = kwargs.get('api_base', None)
return_async = kwargs.get('return_async', False)
mock_response = kwargs.get('mock_response', None)
force_timeout= kwargs.get('force_timeout', 600)
@@ -299,7 +329,8 @@ def completion(
metadata = kwargs.get('metadata', None)
fallbacks = kwargs.get('fallbacks', None)
headers = kwargs.get("headers", None)
-num_retries = kwargs.get("num_retries", None)
+num_retries = kwargs.get("num_retries", None) ## deprecated
max_retries = kwargs.get("max_retries", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
### CUSTOM PROMPT TEMPLATE ###
initial_prompt_value = kwargs.get("intial_prompt_value", None)
@@ -309,13 +340,17 @@ def completion(
eos_token = kwargs.get("eos_token", None)
acompletion = kwargs.get("acompletion", False)
######## end of unpacking kwargs ###########
-openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id"]
-litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]
+openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout"]
+litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "max_retries"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
return mock_completion(model, messages, stream=stream, mock_response=mock_response)
try:
if base_url:
api_base = base_url
if max_retries:
num_retries = max_retries
logging = litellm_logging_obj
fallbacks = (
fallbacks
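Because base_url and max_retries are now accepted and aliased onto api_base and num_retries (and both appear in the recognized parameter lists above, so they are not forwarded as model-specific kwargs), completion() can be called with OpenAI-SDK-style argument names. A short sketch with illustrative values:

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey"}],
    base_url="https://api.openai.com/v1",  # aliased to api_base
    max_retries=2,                          # aliased to num_retries
)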
@@ -648,8 +683,11 @@ def completion(
response = model_response
elif custom_llm_provider=="anthropic":
-anthropic_key = (
-api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY") or litellm.api_key
+api_key = (
+api_key
+or litellm.anthropic_key
+or litellm.api_key
+or os.environ.get("ANTHROPIC_API_KEY")
)
api_base = (
api_base
@@ -672,7 +710,7 @@ def completion(
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
-api_key=anthropic_key,
+api_key=api_key,
logging_obj=logging,
)
if "stream" in optional_params and optional_params["stream"] == True:
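For Anthropic, the key now resolves directly into api_key, in order: the explicit api_key argument, then litellm.anthropic_key, then litellm.api_key, then the ANTHROPIC_API_KEY environment variable. A sketch of the two usual ways to supply it; the key values are placeholders:

import litellm

# per-call
litellm.completion(
    model="claude-2",
    messages=[{"role": "user", "content": "Hi"}],
    api_key="my-anthropic-key",
)

# or module-level
litellm.anthropic_key = "my-anthropic-key"
litellm.completion(model="claude-2", messages=[{"role": "user", "content": "Hi"}])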

View file

@@ -0,0 +1,34 @@
# #### What this tests ####
# # This tests the LiteLLM Class
# import sys, os
# import traceback
# import pytest
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import litellm
# litellm.set_verbose = True
# from litellm import LiteLLM
# import instructor
# from pydantic import BaseModel
# # This enables response_model keyword
# # from client.chat.completions.create
# client = instructor.patch(LiteLLM())
# class UserDetail(BaseModel):
# name: str
# age: int
# user = client.chat.completions.create(
# model="gpt-3.5-turbo",
# response_model=UserDetail,
# messages=[
# {"role": "user", "content": "Extract Jason is 25 years old"},
# ]
# )
# assert isinstance(user, UserDetail)
# assert user.name == "Jason"
# assert user.age == 25

View file

@@ -1,107 +1,107 @@
import os
import sys, os
import traceback
from dotenv import load_dotenv
# import os
# import sys, os
# import traceback
# from dotenv import load_dotenv
load_dotenv()
import os, io
# load_dotenv()
# import os, io
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, text_completion, completion_cost
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import pytest
# import litellm
# from litellm import embedding, completion, text_completion, completion_cost
from langchain.chat_models import ChatLiteLLM
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import AIMessage, HumanMessage, SystemMessage
# from langchain.chat_models import ChatLiteLLM
# from langchain.prompts.chat import (
# ChatPromptTemplate,
# SystemMessagePromptTemplate,
# AIMessagePromptTemplate,
# HumanMessagePromptTemplate,
# )
# from langchain.schema import AIMessage, HumanMessage, SystemMessage
def test_chat_gpt():
try:
chat = ChatLiteLLM(model="gpt-3.5-turbo", max_tokens=10)
messages = [
HumanMessage(
content="what model are you"
)
]
resp = chat(messages)
print(resp)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_chat_gpt()
def test_claude():
try:
chat = ChatLiteLLM(model="claude-2", max_tokens=10)
messages = [
HumanMessage(
content="what model are you"
)
]
resp = chat(messages)
print(resp)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_claude()
def test_palm():
try:
chat = ChatLiteLLM(model="palm/chat-bison", max_tokens=10)
messages = [
HumanMessage(
content="what model are you"
)
]
resp = chat(messages)
print(resp)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_palm()
# def test_openai_with_params():
# def test_chat_gpt():
# try:
# api_key = os.environ["OPENAI_API_KEY"]
# os.environ.pop("OPENAI_API_KEY")
# print("testing openai with params")
# llm = ChatLiteLLM(
# model="gpt-3.5-turbo",
# openai_api_key=api_key,
# # Prefer using None which is the default value, endpoint could be empty string
# openai_api_base= None,
# max_tokens=20,
# temperature=0.5,
# request_timeout=10,
# model_kwargs={
# "frequency_penalty": 0,
# "presence_penalty": 0,
# },
# verbose=True,
# max_retries=0,
# )
# chat = ChatLiteLLM(model="gpt-3.5-turbo", max_tokens=10)
# messages = [
# HumanMessage(
# content="what model are you"
# )
# ]
# resp = llm(messages)
# resp = chat(messages)
# print(resp)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_openai_with_params()
# # test_chat_gpt()
# def test_claude():
# try:
# chat = ChatLiteLLM(model="claude-2", max_tokens=10)
# messages = [
# HumanMessage(
# content="what model are you"
# )
# ]
# resp = chat(messages)
# print(resp)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# # test_claude()
# def test_palm():
# try:
# chat = ChatLiteLLM(model="palm/chat-bison", max_tokens=10)
# messages = [
# HumanMessage(
# content="what model are you"
# )
# ]
# resp = chat(messages)
# print(resp)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# # test_palm()
# # def test_openai_with_params():
# # try:
# # api_key = os.environ["OPENAI_API_KEY"]
# # os.environ.pop("OPENAI_API_KEY")
# # print("testing openai with params")
# # llm = ChatLiteLLM(
# # model="gpt-3.5-turbo",
# # openai_api_key=api_key,
# # # Prefer using None which is the default value, endpoint could be empty string
# # openai_api_base= None,
# # max_tokens=20,
# # temperature=0.5,
# # request_timeout=10,
# # model_kwargs={
# # "frequency_penalty": 0,
# # "presence_penalty": 0,
# # },
# # verbose=True,
# # max_retries=0,
# # )
# # messages = [
# # HumanMessage(
# # content="what model are you"
# # )
# # ]
# # resp = llm(messages)
# # print(resp)
# # except Exception as e:
# # pytest.fail(f"Error occurred: {e}")
# # test_openai_with_params()

View file

@@ -166,7 +166,7 @@ def test_completion_cohere_stream():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
-test_completion_cohere_stream()
+# test_completion_cohere_stream()
def test_completion_cohere_stream_bad_key():
try:
@@ -464,6 +464,7 @@ def test_completion_palm_stream():
def test_completion_claude_stream_bad_key():
try:
litellm.cache = None
litellm.set_verbose = True
api_key = "bad-key"
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -492,7 +493,7 @@ def test_completion_claude_stream_bad_key():
pytest.fail(f"Error occurred: {e}")
-# test_completion_claude_stream_bad_key()
+test_completion_claude_stream_bad_key()
# test_completion_replicate_stream()
# def test_completion_vertexai_stream():

View file

@@ -53,7 +53,7 @@ from .exceptions import (
APIError,
BudgetExceededError
)
-from typing import cast, List, Dict, Union, Optional, Literal
+from typing import cast, List, Dict, Union, Optional, Literal, TypedDict, Required
from .caching import Cache
####### ENVIRONMENT VARIABLES ####################
@@ -118,6 +118,10 @@ def map_finish_reason(finish_reason: str): # openai supports 5 stop sequences -
return "stop"
return finish_reason
class FunctionCall(OpenAIObject):
arguments: str
name: str
class Message(OpenAIObject):
def __init__(self, content="default", role="assistant", logprobs=None, function_call=None, **params):
super(Message, self).__init__(**params)
@@ -125,7 +129,7 @@ class Message(OpenAIObject):
self.role = role
self._logprobs = logprobs
if function_call:
-self.function_call = function_call
+self.function_call = FunctionCall(**function_call)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
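With function_call wrapped in the new FunctionCall type, name and arguments become typed fields that can be read as attributes instead of dict keys. A minimal sketch, assuming Message is importable from litellm.utils; the function name and arguments are illustrative:

from litellm.utils import Message

msg = Message(
    content=None,
    role="assistant",
    function_call={"name": "get_current_weather", "arguments": '{"location": "Boston"}'},
)
assert msg.function_call.name == "get_current_weather"
assert isinstance(msg.function_call.arguments, str)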
@@ -922,7 +926,7 @@ class Logging:
callback.log_failure_event(
start_time=start_time,
end_time=end_time,
-messages=self.messages,
+response_obj=result,
kwargs=self.model_call_details,
)
except Exception as e: