Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(utils.py): fix cached responses - translate dict to objects
This commit is contained in:
parent 84460b8222
commit a4c9e6bd46
4 changed files with 108 additions and 21 deletions
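In plain terms: when a completion is served from the cache, the cached payload comes back as a plain dict rather than a ModelResponse object, and acompletion would try to await it. This commit adds a convert_to_model_response_object helper in utils.py to rebuild the typed response on a cache hit, teaches acompletion to return a synchronously produced dict instead of awaiting it, and extends the Router with a routing_strategy option plus Redis-backed caching tests. A self-contained sketch of the dict-to-object translation follows the utils.py hunk below.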
@@ -144,7 +144,11 @@ async def acompletion(*args, **kwargs):
             response = completion(*args, **kwargs)
         else:
             # Await normally
-            response = await completion(*args, **kwargs)
+            init_response = completion(*args, **kwargs)
+            if isinstance(init_response, dict):
+                response = init_response
+            else:
+                response = await init_response
     else:
         # Call the synchronous function using run_in_executor
         response = await loop.run_in_executor(None, func_with_context)

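Why the isinstance check matters, shown as a minimal self-contained sketch (fake_completion and acompletion_like are hypothetical stand-ins, not litellm functions): on a cache hit the wrapped completion() call hands back the cached dict synchronously rather than an awaitable, and awaiting a dict raises TypeError, so the new branch only awaits when there is something to await.

import asyncio

def fake_completion(cache_hit: bool):
    # Hypothetical stand-in for the patched completion() call: a cache hit returns
    # a plain dict immediately, otherwise an awaitable is returned.
    async def _call_api():
        return {"choices": [{"message": {"content": "fresh"}}]}
    if cache_hit:
        return {"choices": [{"message": {"content": "cached"}}]}
    return _call_api()

async def acompletion_like(cache_hit: bool):
    init_response = fake_completion(cache_hit)
    if isinstance(init_response, dict):  # cached result, already materialized
        return init_response
    return await init_response           # normal async path

print(asyncio.run(acompletion_like(True))["choices"][0]["message"]["content"])   # cached
print(asyncio.run(acompletion_like(False))["choices"][0]["message"]["content"])  # fresh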
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Literal
 import random, threading, time
 import litellm
 import logging

@@ -29,12 +29,16 @@ class Router:
                  redis_host: Optional[str] = None,
                  redis_port: Optional[int] = None,
                  redis_password: Optional[str] = None,
-                 cache_responses: bool = False) -> None:
+                 cache_responses: bool = False,
+                 routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
         if model_list:
             self.set_model_list(model_list)
-            self.healthy_deployments: List = []
-        ### HEALTH CHECK THREAD ### - commenting out as further testing required
-        self._start_health_check_thread()
+            self.healthy_deployments: List = self.model_list
+
+        self.routing_strategy = routing_strategy
+        ### HEALTH CHECK THREAD ###
+        if self.routing_strategy == "least-busy":
+            self._start_health_check_thread()
 
         ### CACHING ###
         if redis_host is not None and redis_port is not None and redis_password is not None:

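For orientation, a hedged usage sketch of the new constructor (assuming Router is importable from the litellm package; the model-list shape is copied from the tests later in this diff):

import os
from litellm import Router  # assumed import path

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo-0613",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
]

# Default strategy is "simple-shuffle"; no health-check thread is started.
router = Router(model_list=model_list)

# Opting in to "least-busy" is what now triggers _start_health_check_thread().
busy_router = Router(model_list=model_list, routing_strategy="least-busy")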
@@ -104,13 +108,15 @@ class Router:
         """
         Returns the deployment with the shortest queue
         """
-        ### COMMENTING OUT AS IT NEEDS FURTHER TESTING
         logging.debug(f"self.healthy_deployments: {self.healthy_deployments}")
-        if len(self.healthy_deployments) > 0:
-            for item in self.healthy_deployments:
-                if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
-                    return item[0]
-        else:
+        if self.routing_strategy == "least-busy":
+            if len(self.healthy_deployments) > 0:
+                for item in self.healthy_deployments:
+                    if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
+                        return item[0]
+            else:
+                raise ValueError("No models available.")
+        elif self.routing_strategy == "simple-shuffle":
             potential_deployments = []
             for item in self.model_list:
                 if item["model_name"] == model:

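For readers skimming the hunk: "simple-shuffle" simply means picking one of the matching deployments at random. A generic sketch of that idea, not the library's exact code (the selection logic continues past this hunk's context lines):

import random

def simple_shuffle(model_list, model):
    # Gather every deployment registered under the requested model name...
    potential_deployments = [m for m in model_list if m["model_name"] == model]
    if not potential_deployments:
        raise ValueError("No models available.")
    # ...and pick one uniformly at random.
    return random.choice(potential_deployments)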
@@ -134,7 +134,7 @@ Who among the mentioned figures from Ancient Greece contributed to the domain of
     print(results)
 
 
-test_multiple_deployments()
+# test_multiple_deployments()
 ### FUNCTION CALLING
 
 def test_function_calling():

@@ -228,6 +228,7 @@ def test_function_calling():
 
 def test_acompletion_on_router():
     try:
+        litellm.set_verbose = True
         model_list = [
             {
                 "model_name": "gpt-3.5-turbo",

@@ -245,16 +246,69 @@ def test_acompletion_on_router():
         ]
 
         async def get_response():
-            router = Router(model_list=model_list)
-            response = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
-            return response
-        response = asyncio.run(get_response())
-        assert isinstance(response['choices'][0]['message']['content'], str)
+            router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True)
+            response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            print(f"response1: {response1}")
+            response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            print(f"response2: {response2}")
+            assert response1["choices"][0]["message"]["content"] == response2["choices"][0]["message"]["content"]
+        asyncio.run(get_response())
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
 
+test_acompletion_on_router()
 
+def test_function_calling_on_router():
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ]
+        function1 = [
+            {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            }
+        ]
+        router = Router(
+            model_list=model_list,
+            redis_host=os.getenv("REDIS_HOST"),
+            redis_password=os.getenv("REDIS_PASSWORD"),
+            redis_port=os.getenv("REDIS_PORT")
+        )
+        async def get_response():
+            messages=[
+                {
+                    "role": "user",
+                    "content": "what's the weather in boston"
+                }
+            ],
+            response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, functions=function1)
+            print(f"response1: {response1}")
+            return response
+        response = asyncio.run(get_response())
+        assert isinstance(response["choices"][0]["message"]["content"]["function_call"], str)
+    except Exception as e:
+        print(f"An exception occurred: {e}")
+
+# test_function_calling_on_router()
 
 def test_aembedding_on_router():
     try:

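Note that the rewritten test_acompletion_on_router now exercises the Redis-backed cache end to end: it reads REDIS_HOST, REDIS_PASSWORD, and REDIS_PORT from the environment, enables cache_responses, and asserts that the second acompletion call returns the same content as the first (i.e. that it was served from the cache). The new test_function_calling_on_router is added but its call site is left commented out.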
@@ -907,6 +907,29 @@ def client(original_function):
             # [Non-Blocking Error]
             pass
 
+    def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None):
+        try:
+            if response_object is None or model_response_object is None:
+                raise OpenAIError(status_code=500, message="Error in response object format")
+            choice_list=[]
+            for idx, choice in enumerate(response_object["choices"]):
+                message = Message(content=choice["message"]["content"], role=choice["message"]["role"])
+                choice = Choices(finish_reason=choice["finish_reason"], index=idx, message=message)
+                choice_list.append(choice)
+            model_response_object.choices = choice_list
+
+            if "usage" in response_object:
+                model_response_object.usage = response_object["usage"]
+
+            if "id" in response_object:
+                model_response_object.id = response_object["id"]
+
+            if "model" in response_object:
+                model_response_object.model = response_object["model"]
+            return model_response_object
+        except:
+            OpenAIError(status_code=500, message="Invalid response object.")
+
     def wrapper(*args, **kwargs):
         start_time = datetime.datetime.now()
         result = None

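To make the helper's intent concrete, here is a self-contained sketch of the dict-to-object translation using simplified stand-in classes (the dataclasses below are illustrative only; litellm's real Message, Choices, and ModelResponse types live in utils.py and carry more fields):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Message:
    content: str
    role: str

@dataclass
class Choices:
    finish_reason: str
    index: int
    message: Message

@dataclass
class ModelResponse:
    id: Optional[str] = None
    model: Optional[str] = None
    choices: List[Choices] = field(default_factory=list)
    usage: Optional[dict] = None

def convert_to_model_response_object(response_object: dict, model_response_object: ModelResponse) -> ModelResponse:
    # Mirror of the helper above: rebuild typed choices from the cached dict...
    model_response_object.choices = [
        Choices(
            finish_reason=c["finish_reason"],
            index=idx,
            message=Message(content=c["message"]["content"], role=c["message"]["role"]),
        )
        for idx, c in enumerate(response_object["choices"])
    ]
    # ...then copy the scalar fields when present.
    for key in ("usage", "id", "model"):
        if key in response_object:
            setattr(model_response_object, key, response_object[key])
    return model_response_object

# Shape of a payload the cache hands back: a plain dict, not a ModelResponse.
cached = {
    "id": "chatcmpl-123",
    "model": "gpt-3.5-turbo",
    "choices": [{"finish_reason": "stop", "message": {"role": "assistant", "content": "Hello!"}}],
    "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
}

response = convert_to_model_response_object(cached, ModelResponse())
print(response.choices[0].message.content)  # attribute access works again on a cache hit

This is why the cache-hit branch in the wrapper (last hunk below) can return convert_to_model_response_object(...) instead of the raw cached dict.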
@@ -932,7 +955,7 @@ def client(original_function):
 
             # [OPTIONAL] CHECK CACHE
             # remove this after deprecating litellm.caching
-            print_verbose(f"litellm.caching: {litellm.caching}; litellm.caching_with_models: {litellm.caching_with_models}")
+            print_verbose(f"litellm.caching: {litellm.caching}; litellm.caching_with_models: {litellm.caching_with_models}; litellm.cache: {litellm.cache}")
             if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
                 litellm.cache = Cache()
 
@@ -945,7 +968,7 @@ def client(original_function):
                 cached_result = litellm.cache.get_cache(*args, **kwargs)
                 if cached_result != None:
                     print_verbose(f"Cache Hit!")
-                    return cached_result
+                    return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
 
             # MODEL CALL
             result = original_function(*args, **kwargs)