fix(utils.py): fix cached responses - translate dict to objects

commit a4c9e6bd46
parent 84460b8222
Author: Krrish Dholakia
Date:   2023-11-10 10:38:20 -08:00

4 changed files with 108 additions and 21 deletions

litellm/main.py

@@ -144,7 +144,11 @@ async def acompletion(*args, **kwargs):
             response = completion(*args, **kwargs)
         else:
             # Await normally
-            response = await completion(*args, **kwargs)
+            init_response = completion(*args, **kwargs)
+            if isinstance(init_response, dict):
+                response = init_response
+            else:
+                response = await init_response
     else:
         # Call the synchronous function using run_in_executor
         response = await loop.run_in_executor(None, func_with_context)

litellm/router.py

@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Literal
 import random, threading, time
 import litellm
 import logging

@@ -29,12 +29,16 @@ class Router:
                  redis_host: Optional[str] = None,
                  redis_port: Optional[int] = None,
                  redis_password: Optional[str] = None,
-                 cache_responses: bool = False) -> None:
+                 cache_responses: bool = False,
+                 routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
         if model_list:
             self.set_model_list(model_list)
-            self.healthy_deployments: List = []
+            self.healthy_deployments: List = self.model_list
+        self.routing_strategy = routing_strategy

-        ### HEALTH CHECK THREAD ### - commenting out as further testing required
-        self._start_health_check_thread()
+        ### HEALTH CHECK THREAD ###
+        if self.routing_strategy == "least-busy":
+            self._start_health_check_thread()

         ### CACHING ###
         if redis_host is not None and redis_port is not None and redis_password is not None:
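
For context (not part of the commit): a minimal usage sketch of the new routing_strategy parameter; the model list and API key are placeholders:

    import os
    from litellm import Router

    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo-0613",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
    ]

    # "simple-shuffle" (the default) picks randomly among matching deployments;
    # "least-busy" also starts the health-check thread and routes by queue depth.
    router = Router(model_list=model_list, routing_strategy="simple-shuffle")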
@@ -104,13 +108,15 @@ class Router:
         """
         Returns the deployment with the shortest queue
         """
-        ### COMMENTING OUT AS IT NEEDS FURTHER TESTING
         logging.debug(f"self.healthy_deployments: {self.healthy_deployments}")
-        if len(self.healthy_deployments) > 0:
-            for item in self.healthy_deployments:
-                if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
-                    return item[0]
-        else:
+        if self.routing_strategy == "least-busy":
+            if len(self.healthy_deployments) > 0:
+                for item in self.healthy_deployments:
+                    if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
+                        return item[0]
+            else:
+                raise ValueError("No models available.")
+        elif self.routing_strategy == "simple-shuffle":
             potential_deployments = []
             for item in self.model_list:
                 if item["model_name"] == model:

litellm/tests/test_router.py

@@ -134,7 +134,7 @@ Who among the mentioned figures from Ancient Greece contributed to the domain of
     print(results)

-test_multiple_deployments()
+# test_multiple_deployments()

 ### FUNCTION CALLING

 def test_function_calling():
@@ -228,6 +228,7 @@ def test_function_calling():

 def test_acompletion_on_router():
     try:
+        litellm.set_verbose = True
         model_list = [
             {
                 "model_name": "gpt-3.5-turbo",
@@ -245,16 +246,69 @@ def test_acompletion_on_router():
         ]

         async def get_response():
-            router = Router(model_list=model_list)
-            response = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
-            return response
-        response = asyncio.run(get_response())
-
-        assert isinstance(response['choices'][0]['message']['content'], str)
+            router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True)
+            response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            print(f"response1: {response1}")
+            response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            print(f"response2: {response2}")
+            assert response1["choices"][0]["message"]["content"] == response2["choices"][0]["message"]["content"]
+        asyncio.run(get_response())
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")

+test_acompletion_on_router()
+
+def test_function_calling_on_router():
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ]
+        function1 = [
+            {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            }
+        ]
+        router = Router(
+            model_list=model_list,
+            redis_host=os.getenv("REDIS_HOST"),
+            redis_password=os.getenv("REDIS_PASSWORD"),
+            redis_port=os.getenv("REDIS_PORT")
+        )
+        async def get_response():
+            messages = [
+                {
+                    "role": "user",
+                    "content": "what's the weather in boston"
+                }
+            ]
+            response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, functions=function1)
+            print(f"response1: {response1}")
+            return response1
+        response = asyncio.run(get_response())
+        assert isinstance(response["choices"][0]["message"]["function_call"]["name"], str)
+    except Exception as e:
+        print(f"An exception occurred: {e}")
+
+# test_function_calling_on_router()

 def test_aembedding_on_router():
     try:

litellm/utils.py

@@ -907,6 +907,29 @@ def client(original_function):
             # [Non-Blocking Error]
             pass

+    def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None):
+        try:
+            if response_object is None or model_response_object is None:
+                raise OpenAIError(status_code=500, message="Error in response object format")
+            choice_list = []
+            for idx, choice in enumerate(response_object["choices"]):
+                message = Message(content=choice["message"]["content"], role=choice["message"]["role"])
+                choice = Choices(finish_reason=choice["finish_reason"], index=idx, message=message)
+                choice_list.append(choice)
+            model_response_object.choices = choice_list
+            if "usage" in response_object:
+                model_response_object.usage = response_object["usage"]
+            if "id" in response_object:
+                model_response_object.id = response_object["id"]
+            if "model" in response_object:
+                model_response_object.model = response_object["model"]
+            return model_response_object
+        except Exception:
+            raise OpenAIError(status_code=500, message="Invalid response object.")

     def wrapper(*args, **kwargs):
         start_time = datetime.datetime.now()
         result = None
@@ -932,7 +955,7 @@ def client(original_function):
             # [OPTIONAL] CHECK CACHE
             # remove this after deprecating litellm.caching
-            print_verbose(f"litellm.caching: {litellm.caching}; litellm.caching_with_models: {litellm.caching_with_models}")
+            print_verbose(f"litellm.caching: {litellm.caching}; litellm.caching_with_models: {litellm.caching_with_models}; litellm.cache: {litellm.cache}")
             if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
                 litellm.cache = Cache()
@@ -945,7 +968,7 @@ def client(original_function):
                 cached_result = litellm.cache.get_cache(*args, **kwargs)
                 if cached_result != None:
                     print_verbose(f"Cache Hit!")
-                    return cached_result
+                    return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())

             # MODEL CALL
             result = original_function(*args, **kwargs)
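
For context (not part of the diff): values read back from the cache are plain dicts, while callers expect ModelResponse objects with attribute access, which is why the cache-hit path now goes through convert_to_model_response_object. A standalone sketch of the same dict-to-object translation, using simplified stand-in classes rather than litellm's actual ModelResponse/Choices/Message:

    class Message:
        def __init__(self, content, role):
            self.content, self.role = content, role

    class Choices:
        def __init__(self, finish_reason, index, message):
            self.finish_reason, self.index, self.message = finish_reason, index, message

    class ModelResponse:
        def __init__(self):
            self.id, self.model, self.choices, self.usage = None, None, [], None

    def convert(response_object: dict, model_response_object: ModelResponse) -> ModelResponse:
        # Rebuild each cached choice dict as an object so attribute access works again.
        model_response_object.choices = [
            Choices(
                finish_reason=c["finish_reason"],
                index=idx,
                message=Message(content=c["message"]["content"], role=c["message"]["role"]),
            )
            for idx, c in enumerate(response_object["choices"])
        ]
        # Carry over top-level metadata when the cache entry has it.
        for field in ("usage", "id", "model"):
            if field in response_object:
                setattr(model_response_object, field, response_object[field])
        return model_response_object

    cached = {
        "id": "chatcmpl-123",
        "model": "gpt-3.5-turbo",
        "choices": [{"finish_reason": "stop",
                     "message": {"role": "assistant", "content": "hi"}}],
    }
    print(convert(cached, ModelResponse()).choices[0].message.content)  # -> hi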