fix(utils.py): fix cached responses - translate dict to objects
parent 84460b8222
commit a4c9e6bd46
4 changed files with 108 additions and 21 deletions
@@ -144,7 +144,11 @@ async def acompletion(*args, **kwargs):
             response = completion(*args, **kwargs)
         else:
             # Await normally
-            response = await completion(*args, **kwargs)
+            init_response = completion(*args, **kwargs)
+            if isinstance(init_response, dict):
+                response = init_response
+            else:
+                response = await init_response
     else:
         # Call the synchronous function using run_in_executor
         response = await loop.run_in_executor(None, func_with_context)
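The hunk above is the core of the fix on the async path: with caching enabled, completion() can return a plain dict (a cache hit) instead of a coroutine, so acompletion must only await when there is actually something to await. A minimal standalone sketch of the same pattern; the names maybe_cached_call and fetch_response are hypothetical, and it checks awaitability with inspect.isawaitable rather than the isinstance(..., dict) test used in the diff:

import asyncio
import inspect

def maybe_cached_call(cache_hit: bool = True):
    # Hypothetical stand-in for completion(): a cache hit returns a plain dict,
    # otherwise a coroutine that still has to be awaited.
    async def _call_backend():
        return {"choices": [{"message": {"content": "fresh"}}]}
    if cache_hit:
        return {"choices": [{"message": {"content": "cached"}}]}
    return _call_backend()

async def fetch_response(cache_hit: bool):
    init_response = maybe_cached_call(cache_hit)
    # Await only when the value is awaitable; a cached dict is returned as-is.
    if inspect.isawaitable(init_response):
        return await init_response
    return init_response

print(asyncio.run(fetch_response(cache_hit=True)))   # served from the "cache"
print(asyncio.run(fetch_response(cache_hit=False)))  # awaited normally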
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Literal
 import random, threading, time
 import litellm
 import logging
@@ -29,12 +29,16 @@ class Router:
                  redis_host: Optional[str] = None,
                  redis_port: Optional[int] = None,
                  redis_password: Optional[str] = None,
-                 cache_responses: bool = False) -> None:
+                 cache_responses: bool = False,
+                 routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
         if model_list:
             self.set_model_list(model_list)
-            self.healthy_deployments: List = []
-        ### HEALTH CHECK THREAD ### - commenting out as further testing required
-        self._start_health_check_thread()
+            self.healthy_deployments: List = self.model_list
+
+        self.routing_strategy = routing_strategy
+        ### HEALTH CHECK THREAD ###
+        if self.routing_strategy == "least-busy":
+            self._start_health_check_thread()
 
         ### CACHING ###
         if redis_host is not None and redis_port is not None and redis_password is not None:
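With the new parameter, callers choose the routing behaviour at construction time: "least-busy" is what now starts the health-check thread, while "simple-shuffle" remains the default. A short usage sketch, assuming Router is importable from the litellm package as in the tests further down and that the model-list shape matches theirs:

import os
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo-0613",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
]

# Default is "simple-shuffle"; "least-busy" also spawns the health-check thread.
router = Router(model_list=model_list, routing_strategy="least-busy")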
@@ -104,13 +108,15 @@ class Router:
         """
         Returns the deployment with the shortest queue
         """
-        ### COMMENTING OUT AS IT NEEDS FURTHER TESTING
-        logging.debug(f"self.healthy_deployments: {self.healthy_deployments}")
-        if len(self.healthy_deployments) > 0:
-            for item in self.healthy_deployments:
-                if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
-                    return item[0]
-        else:
+        if self.routing_strategy == "least-busy":
+            if len(self.healthy_deployments) > 0:
+                for item in self.healthy_deployments:
+                    if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
+                        return item[0]
+            else:
+                raise ValueError("No models available.")
+        elif self.routing_strategy == "simple-shuffle":
+            potential_deployments = []
             for item in self.model_list:
                 if item["model_name"] == model:
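The simple-shuffle branch is truncated by the hunk boundary above; the idea it implements is to collect every deployment registered under the requested model_name and pick one at random. A minimal sketch of that selection step (an illustration of the strategy, not the literal continuation of the file):

import random

def simple_shuffle(model_list, model):
    # Keep only deployments whose model_name matches the request,
    # then choose among them uniformly at random.
    potential_deployments = [item for item in model_list if item["model_name"] == model]
    if not potential_deployments:
        raise ValueError(f"No deployments available for model: {model}")
    return random.choice(potential_deployments)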
@@ -134,7 +134,7 @@ Who among the mentioned figures from Ancient Greece contributed to the domain of
     print(results)
 
 
-test_multiple_deployments()
+# test_multiple_deployments()
 ### FUNCTION CALLING
 
 def test_function_calling():
@@ -228,6 +228,7 @@ def test_function_calling():
 
 def test_acompletion_on_router():
     try:
+        litellm.set_verbose = True
         model_list = [
             {
                 "model_name": "gpt-3.5-turbo",
@@ -245,16 +246,69 @@ def test_acompletion_on_router():
         ]
 
         async def get_response():
-            router = Router(model_list=model_list)
-            response = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
-            return response
-        response = asyncio.run(get_response())
-
-        assert isinstance(response['choices'][0]['message']['content'], str)
+            router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True)
+            response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            print(f"response1: {response1}")
+            response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            print(f"response2: {response2}")
+            assert response1["choices"][0]["message"]["content"] == response2["choices"][0]["message"]["content"]
+        asyncio.run(get_response())
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
 
 test_acompletion_on_router()
 
+def test_function_calling_on_router():
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ]
+        function1 = [
+            {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            }
+        ]
+        router = Router(
+            model_list=model_list,
+            redis_host=os.getenv("REDIS_HOST"),
+            redis_password=os.getenv("REDIS_PASSWORD"),
+            redis_port=os.getenv("REDIS_PORT")
+        )
+        async def get_response():
+            messages=[
+                {
+                    "role": "user",
+                    "content": "what's the weather in boston"
+                }
+            ],
+            response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, functions=function1)
+            print(f"response1: {response1}")
+            return response
+        response = asyncio.run(get_response())
+        assert isinstance(response["choices"][0]["message"]["content"]["function_call"], str)
+    except Exception as e:
+        print(f"An exception occurred: {e}")
+
+# test_function_calling_on_router()
+
 def test_aembedding_on_router():
     try:
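Running the updated cache test requires a reachable Redis instance and an OpenAI key supplied through the environment; the variable names below are the ones the test reads, and the values are placeholders:

import os

os.environ["OPENAI_API_KEY"] = "sk-..."        # used by the completion calls
os.environ["REDIS_HOST"] = "localhost"         # backs the shared response cache
os.environ["REDIS_PORT"] = "6379"
os.environ["REDIS_PASSWORD"] = "my-redis-password"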
@@ -907,6 +907,29 @@ def client(original_function):
             # [Non-Blocking Error]
             pass
 
+    def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None):
+        try:
+            if response_object is None or model_response_object is None:
+                raise OpenAIError(status_code=500, message="Error in response object format")
+            choice_list=[]
+            for idx, choice in enumerate(response_object["choices"]):
+                message = Message(content=choice["message"]["content"], role=choice["message"]["role"])
+                choice = Choices(finish_reason=choice["finish_reason"], index=idx, message=message)
+                choice_list.append(choice)
+            model_response_object.choices = choice_list
+
+            if "usage" in response_object:
+                model_response_object.usage = response_object["usage"]
+
+            if "id" in response_object:
+                model_response_object.id = response_object["id"]
+
+            if "model" in response_object:
+                model_response_object.model = response_object["model"]
+            return model_response_object
+        except:
+            OpenAIError(status_code=500, message="Invalid response object.")
+
     def wrapper(*args, **kwargs):
         start_time = datetime.datetime.now()
         result = None
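This helper is what "translate dict to objects" refers to: a cached response comes back as a plain dict, and the function rebuilds a ModelResponse so attribute access keeps working. A standalone sketch of the same translation, assuming ModelResponse, Message, and Choices can be imported from litellm.utils (the helper itself is nested inside client() and is not exported):

from litellm.utils import ModelResponse, Message, Choices

# A cached response as it comes back from the cache: a plain dict.
cached = {
    "id": "chatcmpl-123",
    "model": "gpt-3.5-turbo",
    "choices": [
        {"finish_reason": "stop", "message": {"role": "assistant", "content": "Hello!"}}
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
}

# Rebuild object-style access the same way the new helper does.
response = ModelResponse()
response.choices = [
    Choices(
        finish_reason=c["finish_reason"],
        index=i,
        message=Message(content=c["message"]["content"], role=c["message"]["role"]),
    )
    for i, c in enumerate(cached["choices"])
]
response.id = cached["id"]
response.model = cached["model"]
response.usage = cached["usage"]

print(response.choices[0].message.content)  # "Hello!" via attribute access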
@@ -932,7 +955,7 @@ def client(original_function):
 
             # [OPTIONAL] CHECK CACHE
             # remove this after deprecating litellm.caching
-            print_verbose(f"litellm.caching: {litellm.caching}; litellm.caching_with_models: {litellm.caching_with_models}")
+            print_verbose(f"litellm.caching: {litellm.caching}; litellm.caching_with_models: {litellm.caching_with_models}; litellm.cache: {litellm.cache}")
             if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
                 litellm.cache = Cache()
 
@@ -945,7 +968,7 @@ def client(original_function):
                 cached_result = litellm.cache.get_cache(*args, **kwargs)
                 if cached_result != None:
                     print_verbose(f"Cache Hit!")
-                    return cached_result
+                    return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
 
             # MODEL CALL
             result = original_function(*args, **kwargs)