fix(router.py): deepcopy initial model list, don't mutate it

Krrish Dholakia 2023-12-12 09:53:35 -08:00
parent 5e9286ed41
commit 0cf0c2d6dd
6 changed files with 280 additions and 102 deletions


@@ -1,3 +1,4 @@
+ from tkinter import N
from typing import Optional, Union, Any
import types, time, json
import httpx
@@ -195,23 +196,23 @@ class OpenAIChatCompletion(BaseLLM):
        **optional_params
    }
+ ## LOGGING
+ logging_obj.pre_call(
+     input=messages,
+     api_key=api_key,
+     additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
+ )
    try:
        max_retries = data.pop("max_retries", 2)
        if acompletion is True:
            if optional_params.get("stream", False):
- return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
+ return self.async_streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
            else:
- return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
+ return self.acompletion(data=data, headers=headers, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
        elif optional_params.get("stream", False):
- return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
+ return self.streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
        else:
- ## LOGGING
- logging_obj.pre_call(
-     input=messages,
-     api_key=api_key,
-     additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
- )
            if not isinstance(max_retries, int):
                raise OpenAIError(status_code=422, message="max retries must be an int")
            if client is None:
@@ -260,6 +261,8 @@ class OpenAIChatCompletion(BaseLLM):
    api_base: Optional[str]=None,
    client=None,
    max_retries=None,
+ logging_obj=None,
+ headers=None
):
    response = None
    try:
@@ -267,8 +270,21 @@ class OpenAIChatCompletion(BaseLLM):
        openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
    else:
        openai_aclient = client
+ ## LOGGING
+ logging_obj.pre_call(
+     input=data['messages'],
+     api_key=api_key,
+     additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data},
+ )
    response = await openai_aclient.chat.completions.create(**data)
- return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+ stringified_response = response.model_dump_json()
+ logging_obj.post_call(
+     input=data['messages'],
+     api_key=api_key,
+     original_response=stringified_response,
+     additional_args={"complete_input_dict": data},
+ )
+ return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except Exception as e:
    if response and hasattr(response, "text"):
        raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
@@ -286,12 +302,19 @@ class OpenAIChatCompletion(BaseLLM):
    api_key: Optional[str]=None,
    api_base: Optional[str]=None,
    client = None,
- max_retries=None
+ max_retries=None,
+ headers=None
):
    if client is None:
        openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
    else:
        openai_client = client
+ ## LOGGING
+ logging_obj.pre_call(
+     input=data['messages'],
+     api_key=api_key,
+     additional_args={"headers": headers, "api_base": api_base, "acompletion": False, "complete_input_dict": data},
+ )
    response = openai_client.chat.completions.create(**data)
    streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
    return streamwrapper
@@ -305,6 +328,7 @@ class OpenAIChatCompletion(BaseLLM):
    api_base: Optional[str]=None,
    client=None,
    max_retries=None,
+ headers=None
):
    response = None
    try:
@@ -312,6 +336,13 @@ class OpenAIChatCompletion(BaseLLM):
        openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
    else:
        openai_aclient = client
+ ## LOGGING
+ logging_obj.pre_call(
+     input=data['messages'],
+     api_key=api_key,
+     additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data},
+ )
    response = await openai_aclient.chat.completions.create(**data)
    streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
    async for transformed_chunk in streamwrapper:
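
Context on the openai.py changes above: the pre-call logging hook now fires before the sync/async/streaming dispatch, and logging_obj/headers are threaded into the async helpers, which gain their own pre/post-call logging. Below is a rough sketch of that control flow; DemoLogger, fake_acreate, and the function names are hypothetical stand-ins, not litellm's actual Logging class or OpenAI client.

import asyncio
import json

class DemoLogger:
    """Hypothetical stand-in for the logging_obj used in the diff above."""
    def pre_call(self, input, api_key, additional_args=None):
        print("pre_call | acompletion:", (additional_args or {}).get("acompletion"))
    def post_call(self, input, api_key, original_response, additional_args=None):
        print("post_call |", original_response[:60])

async def fake_acreate(**data):
    # Stand-in for openai_aclient.chat.completions.create(**data)
    return {"choices": [{"message": {"role": "assistant", "content": "hi"}}]}

def completion(messages, logging_obj, acompletion=False):
    data = {"messages": messages}
    # Logging fires *before* the sync/async dispatch, so async callers
    # also get a pre-call entry (previously only the sync path did).
    logging_obj.pre_call(input=messages, api_key="sk-demo",
                         additional_args={"acompletion": acompletion, "complete_input_dict": data})
    if acompletion:
        # Returning the coroutine mirrors how a sync entrypoint hands off async work.
        return _acompletion(data, logging_obj)
    return {"choices": [{"message": {"role": "assistant", "content": "hi (sync)"}}]}

async def _acompletion(data, logging_obj):
    response = await fake_acreate(**data)
    stringified_response = json.dumps(response)
    # Post-call logging lives inside the async helper, as in the diff.
    logging_obj.post_call(input=data["messages"], api_key="sk-demo",
                          original_response=stringified_response,
                          additional_args={"complete_input_dict": data})
    return response

if __name__ == "__main__":
    asyncio.run(completion([{"role": "user", "content": "hey"}], DemoLogger(), acompletion=True))

As far as the diff shows, the intended effect is that acompletion and streaming requests produce the same pre-call entries (and, for non-streaming async, a post-call entry) as the plain sync path.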


@@ -607,7 +607,7 @@ def completion(
    )
    raise e
- if optional_params.get("stream", False) or acompletion == True:
+ if optional_params.get("stream", False):
        ## LOGGING
        logging.post_call(
            input=messages,


@@ -7,6 +7,7 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
+ import copy
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback, uuid
@@ -879,7 +880,7 @@ class Router:
        return chosen_item
    def set_model_list(self, model_list: list):
- self.model_list = model_list
+ self.model_list = copy.deepcopy(model_list)
        # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
        import os
        for model in self.model_list:
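
This hunk is the heart of the commit: set_model_list previously stored the caller's list and then wrote api_base/api_key into each deployment dict in place. A minimal sketch of the difference, using a hypothetical MiniRouter rather than litellm's actual Router:

import copy
import os

class MiniRouter:
    """Hypothetical, stripped-down illustration of Router.set_model_list()."""
    def __init__(self, model_list: list, deep_copy: bool = True):
        # deep_copy=False mimics the old behavior: the credential injection
        # below writes into the caller's own dicts.
        self.model_list = copy.deepcopy(model_list) if deep_copy else model_list
        for model in self.model_list:
            params = model["litellm_params"]
            if params.get("api_key") is None:
                params["api_key"] = os.getenv("OPENAI_API_KEY", "sk-demo")

caller_list = [{
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {"model": "gpt-3.5-turbo", "api_key": None},
}]

MiniRouter(caller_list, deep_copy=False)
print(caller_list[0]["litellm_params"])  # api_key leaked into the caller's dict

caller_list[0]["litellm_params"]["api_key"] = None  # reset for comparison
MiniRouter(caller_list, deep_copy=True)
print(caller_list[0]["litellm_params"])  # caller's dict left untouched

Note that a shallow copy (list(model_list)) would not be enough here, since the nested litellm_params dicts would still be shared; hence copy.deepcopy.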


@@ -1,5 +1,5 @@
Task exception was never retrieved
- future: <Task finished name='Task-334' coro=<QueryEngine.aclose() done, defined at /opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py:110> exception=RuntimeError('Event loop is closed')>
+ future: <Task finished name='Task-336' coro=<QueryEngine.aclose() done, defined at /opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py:110> exception=RuntimeError('Event loop is closed')>
Traceback (most recent call last):
  File "/opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py", line 112, in aclose
    await self._close_session()


@@ -61,3 +61,9 @@ model_list:
      description: this is a test openai model
      id: 34339b1e-e030-4bcc-a531-c48559f10ce4
  model_name: test_openai_models
+ - litellm_params:
+     model: gpt-3.5-turbo
+   model_info:
+     description: this is a test openai model
+     id: f6f74e14-ac64-4403-9365-319e584dcdc5
+   model_name: test_openai_models
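
For orientation, an entry like the one added above corresponds roughly to a Router model_list element of the shape used in the tests further down; this mapping is inferred from the config layout, not something shown in the diff itself.

# Hypothetical dict form of the YAML entry above.
entry = {
    "model_name": "test_openai_models",
    "litellm_params": {"model": "gpt-3.5-turbo"},
    "model_info": {
        "description": "this is a test openai model",
        "id": "f6f74e14-ac64-4403-9365-319e584dcdc5",
    },
}
# Router(model_list=[entry, ...]) load-balances across entries sharing a model_name.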


@@ -21,10 +21,14 @@ class MyCustomHandler(CustomLogger):
        print(f"Pre-API Call")
    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
- print(f"Post-API Call")
+ print(f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}")
    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Stream")
+ def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
+     print(f"On Stream")
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}")
@@ -41,67 +45,65 @@ class MyCustomHandler(CustomLogger):
    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Failure")
model_list = [
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
},
{
"model_name": "gpt-3.5-turbo-16k", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
}
]
kwargs = {"model": "azure/gpt-3.5-turbo", "messages": [{"role": "user", "content":"Hey, how's it going?"}]}
def test_sync_fallbacks():
try:
- print("Test router_fallbacks: test_sync_fallbacks()")
model_list = [
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
},
{
"model_name": "gpt-3.5-turbo-16k", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
}
]
litellm.set_verbose = True
customHandler = MyCustomHandler()
litellm.callbacks = [customHandler]
@@ -123,6 +125,60 @@
@pytest.mark.asyncio
async def test_async_fallbacks():
litellm.set_verbose = False
model_list = [
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
},
{
"model_name": "gpt-3.5-turbo-16k", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
}
]
router = Router(model_list=model_list,
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
@@ -146,30 +202,6 @@
# test_async_fallbacks()
## COMMENTING OUT as the context size exceeds both gpt-3.5-turbo and gpt-3.5-turbo-16k, need a better message here
# def test_sync_context_window_fallbacks():
# try:
# customHandler = MyCustomHandler()
# litellm.callbacks = [customHandler]
# sample_text = "Say error 50 times" * 10000
# kwargs["model"] = "azure/gpt-3.5-turbo-context-fallback"
# kwargs["messages"] = [{"role": "user", "content": sample_text}]
# router = Router(model_list=model_list,
# fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
# context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
# set_verbose=False)
# response = router.completion(**kwargs)
# print(f"response: {response}")
# time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
# assert customHandler.previous_models == 1 # 0 retries, 1 fallback
# router.reset()
# except Exception as e:
# print(f"An exception occurred - {e}")
# finally:
# router.reset()
# test_sync_context_window_fallbacks()
def test_dynamic_fallbacks_sync():
"""
Allow setting the fallback in the router.completion() call.
@@ -177,6 +209,60 @@ def test_dynamic_fallbacks_sync():
try:
customHandler = MyCustomHandler()
litellm.callbacks = [customHandler]
model_list = [
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
},
{
"model_name": "gpt-3.5-turbo-16k", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
}
]
router = Router(model_list=model_list, set_verbose=True)
kwargs = {}
kwargs["model"] = "azure/gpt-3.5-turbo"
@@ -198,11 +284,65 @@ async def test_dynamic_fallbacks_async():
Allow setting the fallback in the router.completion() call.
"""
try:
- print("Router - test_dynamic_fallbacks_async")
- print("Callbacks in test_dynamic_fallbacks_async: ", litellm.callbacks)
- print("Success callbacks in test_dynamic_fallbacks_async: ", litellm.success_callback)
- print("Async Success callbacks in test_dynamic_fallbacks_async: ", litellm._async_success_callback)
- litellm.set_verbose=True
model_list = [
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
},
{
"model_name": "gpt-3.5-turbo-16k", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
}
]
print()
print()
print()
print()
print(f"STARTING DYNAMIC ASYNC")
customHandler = MyCustomHandler()
litellm.callbacks = [customHandler]
router = Router(model_list=model_list, set_verbose=True)
@@ -217,4 +357,4 @@ async def test_dynamic_fallbacks_async():
router.reset()
except Exception as e:
pytest.fail(f"An exception occurred - {e}")
- # test_dynamic_fallbacks_async()
+ # asyncio.run(test_dynamic_fallbacks_async())
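
With set_model_list now deepcopying its input, a regression check along these lines should hold. This is a hedged sketch, not a test from this commit; it assumes the router only injects credentials into its own copy of the list.

import os
from litellm import Router

def test_model_list_not_mutated():
    # Caller-owned list; after the fix the router should not write
    # api_key/api_base back into it.
    original = [{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        "tpm": 1000000,
        "rpm": 9000,
    }]
    snapshot = [dict(m, litellm_params=dict(m["litellm_params"])) for m in original]
    router = Router(model_list=original, set_verbose=False)
    assert original == snapshot  # the router worked on its own deepcopy
    router.reset()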