forked from phoenix/litellm-mirror
fix(tests): fixing response objects for testing
parent 9776126c8d
commit 8a3b771e50
6 changed files with 188 additions and 104 deletions
@@ -56,7 +56,7 @@ from .llms.azure import AzureChatCompletion
 from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, List, Optional, Dict, Union
+from typing import Callable, List, Optional, Dict, Union, Mapping

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
@@ -79,6 +79,35 @@ openai_text_completions = OpenAITextCompletion()
 azure_chat_completions = AzureChatCompletion()
 ####### COMPLETION ENDPOINTS ################

+class LiteLLM:
+
+    def __init__(self, *,
+                 api_key=None,
+                 organization: str | None = None,
+                 base_url: str = None,
+                 timeout: Union[float, None] = 600,
+                 max_retries: int | None = litellm.num_retries,
+                 default_headers: Mapping[str, str] | None = None,):
+        self.params = locals()
+        self.chat = Chat(self.params)
+
+class Chat():
+
+    def __init__(self, params):
+        self.params = params
+        self.completions = Completions(self.params)
+
+class Completions():
+
+    def __init__(self, params):
+        self.params = params
+
+    def create(self, model, messages, **kwargs):
+        for k, v in kwargs.items():
+            self.params[k] = v
+        response = completion(model=model, messages=messages, **self.params)
+        return response
+
 async def acompletion(*args, **kwargs):
     """
     Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
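For orientation, the hunk above adds an OpenAI-v1-style client wrapper (LiteLLM -> Chat -> Completions) whose create() forwards everything to litellm.completion(). A minimal usage sketch, assuming the class is re-exported at the package level and that an OPENAI_API_KEY is set in the environment; the model name and message are placeholders:

import litellm

# Keyword-only constructor args are stored in self.params and later
# splatted into litellm.completion() by Completions.create().
client = litellm.LiteLLM(max_retries=2, timeout=60)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # any litellm-supported model string
    messages=[{"role": "user", "content": "Hello"}],
)
print(response)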
@@ -240,7 +269,7 @@ def completion(
     deployment_id = None,

     # set api_base, api_version, api_key
-    api_base: Optional[str] = None,
+    base_url: Optional[str] = None,
     api_version: Optional[str] = None,
     api_key: Optional[str] = None,
     model_list: Optional[list] = None, # pass in a list of api_base,keys, etc.
@@ -288,6 +317,7 @@ def completion(
     """
     ######### unpacking kwargs #####################
     args = locals()
+    api_base = kwargs.get('api_base', None)
     return_async = kwargs.get('return_async', False)
     mock_response = kwargs.get('mock_response', None)
     force_timeout= kwargs.get('force_timeout', 600)
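One of the kwargs unpacked above, mock_response, is what the tests referenced in the commit message lean on: when it is set, completion() returns a locally built mock response instead of calling a provider (see the "if mock_response:" short-circuit further down). A hedged sketch of that usage; the attribute access assumes litellm's usual ModelResponse shape:

import litellm

# No network call is made: mock_response short-circuits into mock_completion().
resp = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
    mock_response="pong",
)
print(resp.choices[0].message.content)  # expected: "pong"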
@@ -299,7 +329,8 @@ def completion(
     metadata = kwargs.get('metadata', None)
     fallbacks = kwargs.get('fallbacks', None)
     headers = kwargs.get("headers", None)
-    num_retries = kwargs.get("num_retries", None)
+    num_retries = kwargs.get("num_retries", None) ## deprecated
+    max_retries = kwargs.get("max_retries", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
     ### CUSTOM PROMPT TEMPLATE ###
     initial_prompt_value = kwargs.get("intial_prompt_value", None)
@@ -309,13 +340,17 @@ def completion(
     eos_token = kwargs.get("eos_token", None)
     acompletion = kwargs.get("acompletion", False)
     ######## end of unpacking kwargs ###########
-    openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]
+    openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "max_retries"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
         return mock_completion(model, messages, stream=stream, mock_response=mock_response)
     try:
+        if base_url:
+            api_base = base_url
+        if max_retries:
+            num_retries = max_retries
         logging = litellm_logging_obj
         fallbacks = (
             fallbacks
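The updated openai_params/litellm_params lists and the new base_url/max_retries handling above implement a simple split: known control parameters are consumed by litellm, the OpenAI-v1 argument names are mapped back onto the older litellm names, and anything unrecognized is passed straight through to the provider. A standalone sketch of the same idea (simplified, not litellm's actual implementation):

# Simplified illustration of the default_params filtering and aliasing above.
openai_params = ["temperature", "max_tokens", "api_key", "api_base", "base_url", "timeout"]
litellm_params = ["mock_response", "num_retries", "max_retries", "metadata"]
default_params = openai_params + litellm_params

def split_kwargs(**kwargs):
    # Provider-specific params (e.g. Anthropic's top_k) fall through untouched.
    non_default = {k: v for k, v in kwargs.items() if k not in default_params}
    # OpenAI-v1 style names are aliased back to the older litellm names.
    api_base = kwargs.get("base_url") or kwargs.get("api_base")
    num_retries = kwargs.get("max_retries") or kwargs.get("num_retries")
    return non_default, api_base, num_retries

non_default, api_base, num_retries = split_kwargs(
    temperature=0.2, top_k=40, base_url="http://localhost:8000", max_retries=2
)
print(non_default)   # {'top_k': 40}
print(api_base)      # http://localhost:8000
print(num_retries)   # 2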
@@ -648,8 +683,11 @@ def completion(
            response = model_response

        elif custom_llm_provider=="anthropic":
-            anthropic_key = (
-                api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY") or litellm.api_key
+            api_key = (
+                api_key
+                or litellm.anthropic_key
+                or litellm.api_key
+                or os.environ.get("ANTHROPIC_API_KEY")
             )
             api_base = (
                 api_base
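The rewrite above also changes the key-resolution order for Anthropic: an explicit api_key argument wins, then litellm.anthropic_key, then the module-level litellm.api_key, and only then the ANTHROPIC_API_KEY environment variable (previously litellm.api_key came last). A small sketch of that or-chain pattern in isolation; the function and argument names here are placeholders, not litellm internals:

import os

def resolve_anthropic_key(explicit_key=None, provider_key=None, global_key=None):
    # First truthy value wins, mirroring the or-chain in the diff above.
    return (
        explicit_key
        or provider_key
        or global_key
        or os.environ.get("ANTHROPIC_API_KEY")
    )

print(resolve_anthropic_key(explicit_key="sk-ant-explicit"))   # explicit argument wins
print(resolve_anthropic_key(provider_key="sk-ant-provider"))   # falls back to provider-level key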
@@ -672,7 +710,7 @@ def completion(
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                encoding=encoding, # for calculating input/output tokens
-                api_key=anthropic_key,
+                api_key=api_key,
                logging_obj=logging,
            )
            if "stream" in optional_params and optional_params["stream"] == True: