fix(utils.py): fix cached responses - translate dict to objects

Krrish Dholakia 2023-11-10 10:38:20 -08:00
parent 84460b8222
commit a4c9e6bd46
4 changed files with 108 additions and 21 deletions

View file

@@ -144,7 +144,11 @@ async def acompletion(*args, **kwargs):
             response = completion(*args, **kwargs)
         else:
             # Await normally
-            response = await completion(*args, **kwargs)
+            init_response = completion(*args, **kwargs)
+            if isinstance(init_response, dict):
+                response = init_response
+            else:
+                response = await init_response
     else:
         # Call the synchronous function using run_in_executor
         response = await loop.run_in_executor(None, func_with_context)
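Note on the change above: on a cache hit, completion() hands back a plain dict (the cached response) rather than an awaitable, and awaiting a dict raises a TypeError, so acompletion() now only awaits when it actually receives a coroutine. A minimal standalone sketch of the same guard (the function and variable names below are illustrative, not from litellm):

import asyncio

async def call_or_passthrough(fetch):
    # fetch may return either a ready dict (e.g. a cache hit) or a coroutine
    init_response = fetch()
    if isinstance(init_response, dict):
        return init_response      # already materialized; awaiting it would raise TypeError
    return await init_response    # normal async path

async def simulated_api_call():
    return {"source": "api"}

print(asyncio.run(call_or_passthrough(lambda: {"source": "cache"})))  # cache-hit branch
print(asyncio.run(call_or_passthrough(simulated_api_call)))           # coroutine branch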

View file

@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Literal
 import random, threading, time
 import litellm
 import logging
@@ -29,12 +29,16 @@ class Router:
                  redis_host: Optional[str] = None,
                  redis_port: Optional[int] = None,
                  redis_password: Optional[str] = None,
-                 cache_responses: bool = False) -> None:
+                 cache_responses: bool = False,
+                 routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
         if model_list:
             self.set_model_list(model_list)
-            self.healthy_deployments: List = []
-        ### HEALTH CHECK THREAD ### - commenting out as further testing required
-        self._start_health_check_thread()
+            self.healthy_deployments: List = self.model_list
+        self.routing_strategy = routing_strategy
+        ### HEALTH CHECK THREAD ###
+        if self.routing_strategy == "least-busy":
+            self._start_health_check_thread()
         ### CACHING ###
         if redis_host is not None and redis_port is not None and redis_password is not None:
@@ -104,13 +108,15 @@ class Router:
         """
         Returns the deployment with the shortest queue
         """
-        ### COMMENTING OUT AS IT NEEDS FURTHER TESTING
         logging.debug(f"self.healthy_deployments: {self.healthy_deployments}")
-        if len(self.healthy_deployments) > 0:
-            for item in self.healthy_deployments:
-                if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
-                    return item[0]
-        else:
+        if self.routing_strategy == "least-busy":
+            if len(self.healthy_deployments) > 0:
+                for item in self.healthy_deployments:
+                    if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
+                        return item[0]
+            else:
+                raise ValueError("No models available.")
+        elif self.routing_strategy == "simple-shuffle":
             potential_deployments = []
             for item in self.model_list:
                 if item["model_name"] == model:

View file

@@ -134,7 +134,7 @@ Who among the mentioned figures from Ancient Greece contributed to the domain of
     print(results)
-test_multiple_deployments()
+# test_multiple_deployments()
 ### FUNCTION CALLING
 def test_function_calling():
@@ -228,6 +228,7 @@ def test_function_calling():
 def test_acompletion_on_router():
     try:
+        litellm.set_verbose = True
         model_list = [
             {
                 "model_name": "gpt-3.5-turbo",
@@ -245,16 +246,69 @@ def test_acompletion_on_router():
         ]
         async def get_response():
-            router = Router(model_list=model_list)
-            response = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
-            return response
-        response = asyncio.run(get_response())
-        assert isinstance(response['choices'][0]['message']['content'], str)
+            router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True)
+            response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            print(f"response1: {response1}")
+            response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            print(f"response2: {response2}")
+            assert response1["choices"][0]["message"]["content"] == response2["choices"][0]["message"]["content"]
+        asyncio.run(get_response())
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
 test_acompletion_on_router()
+def test_function_calling_on_router():
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ]
+        function1 = [
+            {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            }
+        ]
+        router = Router(
+            model_list=model_list,
+            redis_host=os.getenv("REDIS_HOST"),
+            redis_password=os.getenv("REDIS_PASSWORD"),
+            redis_port=os.getenv("REDIS_PORT")
+        )
+        async def get_response():
+            messages = [
+                {
+                    "role": "user",
+                    "content": "what's the weather in boston"
+                }
+            ]
+            response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, functions=function1)
+            print(f"response1: {response1}")
+            return response1
+        response = asyncio.run(get_response())
+        assert isinstance(response["choices"][0]["message"]["content"]["function_call"], str)
+    except Exception as e:
+        print(f"An exception occurred: {e}")
+# test_function_calling_on_router()
 def test_aembedding_on_router():
     try:

View file

@@ -907,6 +907,29 @@ def client(original_function):
             # [Non-Blocking Error]
             pass
+    def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None):
+        try:
+            if response_object is None or model_response_object is None:
+                raise OpenAIError(status_code=500, message="Error in response object format")
+            choice_list = []
+            for idx, choice in enumerate(response_object["choices"]):
+                message = Message(content=choice["message"]["content"], role=choice["message"]["role"])
+                choice = Choices(finish_reason=choice["finish_reason"], index=idx, message=message)
+                choice_list.append(choice)
+            model_response_object.choices = choice_list
+            if "usage" in response_object:
+                model_response_object.usage = response_object["usage"]
+            if "id" in response_object:
+                model_response_object.id = response_object["id"]
+            if "model" in response_object:
+                model_response_object.model = response_object["model"]
+            return model_response_object
+        except Exception:
+            raise OpenAIError(status_code=500, message="Invalid response object.")
     def wrapper(*args, **kwargs):
         start_time = datetime.datetime.now()
         result = None
@@ -932,7 +955,7 @@ def client(original_function):
             # [OPTIONAL] CHECK CACHE
             # remove this after deprecating litellm.caching
-            print_verbose(f"litellm.caching: {litellm.caching}; litellm.caching_with_models: {litellm.caching_with_models}")
+            print_verbose(f"litellm.caching: {litellm.caching}; litellm.caching_with_models: {litellm.caching_with_models}; litellm.cache: {litellm.cache}")
             if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
                 litellm.cache = Cache()
@@ -945,7 +968,7 @@ def client(original_function):
                 cached_result = litellm.cache.get_cache(*args, **kwargs)
                 if cached_result != None:
                     print_verbose(f"Cache Hit!")
-                    return cached_result
+                    return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
             # MODEL CALL
             result = original_function(*args, **kwargs)
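Putting the utils.py pieces together: on a cache hit the wrapper no longer returns the raw stored dict, it rehydrates it into a ModelResponse so callers keep attribute access. A rough sketch of the effect; rehydrate_cached_response below is a hypothetical standalone mirror of the nested helper (which is defined inside client() and not importable), and the cached_result payload is invented for illustration:

from litellm.utils import ModelResponse, Message, Choices

def rehydrate_cached_response(response_object: dict, model_response_object: ModelResponse) -> ModelResponse:
    # Standalone mirror of the convert_to_model_response_object helper added above.
    choice_list = []
    for idx, choice in enumerate(response_object["choices"]):
        message = Message(content=choice["message"]["content"], role=choice["message"]["role"])
        choice_list.append(Choices(finish_reason=choice["finish_reason"], index=idx, message=message))
    model_response_object.choices = choice_list
    if "usage" in response_object:
        model_response_object.usage = response_object["usage"]
    if "id" in response_object:
        model_response_object.id = response_object["id"]
    if "model" in response_object:
        model_response_object.model = response_object["model"]
    return model_response_object

cached_result = {   # illustrative shape of a cached chat completion
    "id": "chatcmpl-123",
    "model": "gpt-3.5-turbo",
    "choices": [
        {"finish_reason": "stop", "message": {"role": "assistant", "content": "Hello!"}}
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
}

# Before this commit the cache hit returned cached_result itself, so attribute
# access like response.choices[0].message.content failed on the plain dict.
response = rehydrate_cached_response(cached_result, ModelResponse())
print(response.choices[0].message.content)  # -> "Hello!"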