fix(tests): fixing response objects for testing

This commit is contained in:
Krrish Dholakia 2023-11-13 14:38:41 -08:00
parent 9776126c8d
commit 8a3b771e50
6 changed files with 188 additions and 104 deletions

@@ -56,7 +56,7 @@ from .llms.azure import AzureChatCompletion
from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict, Union
from typing import Callable, List, Optional, Dict, Union, Mapping
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
@@ -79,6 +79,35 @@ openai_text_completions = OpenAITextCompletion()
azure_chat_completions = AzureChatCompletion()
####### COMPLETION ENDPOINTS ################
class LiteLLM:
    """OpenAI-client-style wrapper around litellm.completion()."""
    def __init__(self, *,
                 api_key=None,
                 organization: str | None = None,
                 base_url: str | None = None,
                 timeout: Union[float, None] = 600,
                 max_retries: int | None = litellm.num_retries,
                 default_headers: Mapping[str, str] | None = None):
        self.params = locals()  # captured as defaults for every request; includes "self", which completion() now filters out via litellm_params
        self.chat = Chat(self.params)

class Chat():
    def __init__(self, params):
        self.params = params
        self.completions = Completions(self.params)

class Completions():
    def __init__(self, params):
        self.params = params

    def create(self, model, messages, **kwargs):
        # per-call kwargs override the client-level defaults
        for k, v in kwargs.items():
            self.params[k] = v
        response = completion(model=model, messages=messages, **self.params)
        return response
async def acompletion(*args, **kwargs):
"""
Asynchronously executes a litellm.completion() call for any of the LLMs litellm supports (e.g. gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
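For context, a minimal usage sketch of the client-style wrapper added above. It assumes the class ends up importable as litellm.LiteLLM; the key, model, and messages are placeholder values, not part of this commit:

    import litellm

    # Assumed import path; key/model/messages are placeholders.
    client = litellm.LiteLLM(api_key="sk-placeholder", max_retries=2)
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
    )
    # create() merges per-call kwargs into the constructor params and forwards
    # them all to litellm.completion(), so this returns a normal litellm response object.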
@@ -240,7 +269,7 @@ def completion(
deployment_id = None,
# set api_base, api_version, api_key
api_base: Optional[str] = None,
base_url: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
model_list: Optional[list] = None, # pass in a list of api_base, keys, etc.
@@ -288,6 +317,7 @@ def completion(
"""
######### unpacking kwargs #####################
args = locals()
api_base = kwargs.get('api_base', None)
return_async = kwargs.get('return_async', False)
mock_response = kwargs.get('mock_response', None)
force_timeout= kwargs.get('force_timeout', 600)
@@ -299,7 +329,8 @@ def completion(
metadata = kwargs.get('metadata', None)
fallbacks = kwargs.get('fallbacks', None)
headers = kwargs.get("headers", None)
num_retries = kwargs.get("num_retries", None)
num_retries = kwargs.get("num_retries", None) ## deprecated
max_retries = kwargs.get("max_retries", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
### CUSTOM PROMPT TEMPLATE ###
initial_prompt_value = kwargs.get("intial_prompt_value", None)
@@ -309,13 +340,17 @@ def completion(
eos_token = kwargs.get("eos_token", None)
acompletion = kwargs.get("acompletion", False)
######## end of unpacking kwargs ###########
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "max_retries"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
    return mock_completion(model, messages, stream=stream, mock_response=mock_response)
try:
if base_url:
    api_base = base_url
if max_retries:
    num_retries = max_retries
logging = litellm_logging_obj
fallbacks = (
fallbacks
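In plain terms, completion() now also accepts the OpenAI-SDK-style argument names: base_url is copied into api_base and max_retries into num_retries, and the new client parameters (organization, base_url, default_headers, timeout, max_retries, self) are registered in the default-params lists so they are not forwarded to the provider as model-specific kwargs. A rough sketch of the equivalence; the URL and retry count are illustrative only:

    import litellm

    msgs = [{"role": "user", "content": "Hello"}]
    # With this change the two calls below should be handled identically inside completion():
    litellm.completion(model="gpt-3.5-turbo", messages=msgs,
                       base_url="http://localhost:8000/v1", max_retries=2)
    litellm.completion(model="gpt-3.5-turbo", messages=msgs,
                       api_base="http://localhost:8000/v1", num_retries=2)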
@@ -648,8 +683,11 @@ def completion(
response = model_response
elif custom_llm_provider=="anthropic":
anthropic_key = (
api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY") or litellm.api_key
api_key = (
    api_key
    or litellm.anthropic_key
    or litellm.api_key
    or os.environ.get("ANTHROPIC_API_KEY")
)
api_base = (
api_base
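The net effect of the Anthropic change is that the global litellm.api_key now takes precedence over the ANTHROPIC_API_KEY environment variable, and the resolved value is passed along as api_key rather than a separate anthropic_key variable. A small illustration of the new precedence; the key values are placeholders and litellm.anthropic_key is assumed to be left at its default of None:

    import os
    import litellm

    os.environ["ANTHROPIC_API_KEY"] = "env-key"
    litellm.api_key = "global-key"
    api_key = None  # nothing passed explicitly to completion()

    # Mirrors the resolution expression in the diff above.
    resolved = (
        api_key
        or litellm.anthropic_key
        or litellm.api_key
        or os.environ.get("ANTHROPIC_API_KEY")
    )
    assert resolved == "global-key"  # the env var is now only the last fallback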
@@ -672,7 +710,7 @@ def completion(
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=anthropic_key,
api_key=api_key,
logging_obj=logging,
)
if "stream" in optional_params and optional_params["stream"] == True: