fix(router.py): add model_group to call metadata

Krrish Dholakia 2023-11-23 20:55:49 -08:00
parent db8ed601b5
commit 187403c5cc
3 changed files with 71 additions and 131 deletions

router.py

@@ -115,7 +115,7 @@ class Router:
if cache_responses:
litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
self.cache_responses = cache_responses
- self.cache = DualCache(redis_cache=redis_cache) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
+ self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
## USAGE TRACKING ##
if isinstance(litellm.success_callback, list):
litellm.success_callback.append(self.deployment_callback)
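
A note on the DualCache change above: passing in_memory_cache=InMemoryCache() explicitly makes the fast layer deterministic instead of relying on a default. As a minimal standalone sketch of the dual-cache idea (illustrative only, not litellm's actual implementation; the Redis client is assumed to expose .get/.set):

    class DualCacheSketch:
        # Reads hit the in-memory layer first and fall back to Redis;
        # writes go to both layers.
        def __init__(self, redis_cache=None, in_memory_cache=None):
            self.in_memory = in_memory_cache if in_memory_cache is not None else {}
            self.redis = redis_cache

        def get_cache(self, key):
            value = self.in_memory.get(key)
            if value is None and self.redis is not None:
                value = self.redis.get(key)
                if value is not None:
                    self.in_memory[key] = value  # warm the fast layer
            return value

        def set_cache(self, key, value):
            self.in_memory[key] = value
            if self.redis is not None:
                self.redis.set(key, value)
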
@@ -143,6 +143,7 @@ class Router:
kwargs["messages"] = messages
kwargs["original_function"] = self._completion
kwargs["num_retries"] = self.num_retries
+ kwargs.setdefault("metadata", {}).update({"model_group": model})
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
# Submit the function to the executor with a timeout
future = executor.submit(self.function_with_fallbacks, **kwargs)
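
The metadata line added in this hunk merges model_group into any caller-supplied metadata rather than overwriting it: dict.setdefault returns the existing dict when the key is present, or installs and returns the fresh {} when it is not. A quick standalone check:

    kwargs = {"metadata": {"user_id": "abc"}}
    kwargs.setdefault("metadata", {}).update({"model_group": "gpt-3.5-turbo"})
    assert kwargs["metadata"] == {"user_id": "abc", "model_group": "gpt-3.5-turbo"}

    kwargs = {}  # caller passed no metadata at all
    kwargs.setdefault("metadata", {}).update({"model_group": "gpt-3.5-turbo"})
    assert kwargs["metadata"] == {"model_group": "gpt-3.5-turbo"}

One caveat: if a caller passes metadata=None explicitly, setdefault keeps the None and .update raises AttributeError; the pattern assumes metadata is either absent or a dict.
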
@@ -180,7 +181,7 @@ class Router:
kwargs["messages"] = messages
kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = self.num_retries
+ kwargs.setdefault("metadata", {}).update({"model_group": model})
# Use asyncio.timeout to enforce the timeout
async with asyncio.timeout(self.timeout): # type: ignore
response = await self.async_function_with_fallbacks(**kwargs)
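
asyncio.timeout here is a Python 3.11+ API, which is presumably why the # type: ignore is needed for older type checkers. On earlier interpreters the same deadline could be enforced with asyncio.wait_for (available since Python 3.4, raising asyncio.TimeoutError on expiry); a sketch of the equivalent call:

    response = await asyncio.wait_for(
        self.async_function_with_fallbacks(**kwargs),
        timeout=self.timeout,
    )
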
@@ -215,6 +216,7 @@ class Router:
is_async: Optional[bool] = False,
**kwargs):
try:
+ kwargs.setdefault("metadata", {}).update({"model_group": model})
messages=[{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, messages=messages)
@@ -241,6 +243,7 @@ class Router:
is_async: Optional[bool] = False,
**kwargs) -> Union[List[float], None]:
# pick the one that is available (lowest TPM/RPM)
+ kwargs.setdefault("metadata", {}).update({"model_group": model})
deployment = self.get_available_deployment(model=model, input=input)
data = deployment["litellm_params"]
@@ -256,6 +259,7 @@ class Router:
is_async: Optional[bool] = True,
**kwargs) -> Union[List[float], None]:
# pick the one that is available (lowest TPM/RPM)
+ kwargs.setdefault("metadata", {}).update({"model_group": model})
deployment = self.get_available_deployment(model=model, input=input)
data = deployment["litellm_params"]
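
The three hunks above repeat one routing pattern: tag the call with its logical model group, resolve the group to a concrete deployment (the one with the lowest TPM/RPM), then forward that deployment's litellm_params. As a schematic standalone illustration (names here are assumptions, not litellm's internals):

    def route_embedding(router, model_group, input, **kwargs):
        # record the logical group for callbacks/usage tracking
        kwargs.setdefault("metadata", {}).update({"model_group": model_group})
        # pick one concrete deployment for this group (lowest TPM/RPM)
        deployment = router.get_available_deployment(model=model_group, input=input)
        data = deployment["litellm_params"]  # provider model name, api_key, api_base, ...
        return {**data, "input": input, **kwargs}  # what gets forwarded downstream
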
@@ -420,8 +424,6 @@ class Router:
raise e
raise original_exception
def function_with_retries(self, *args, **kwargs):
"""
Try calling the model 3 times. Shuffle between available deployments.
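
The docstring above describes retry-with-shuffle. As a minimal sketch of that idea, assuming a fixed retry budget over interchangeable deployments (not the method's exact code):

    import random

    def call_with_retries(fn, deployments, num_retries=3):
        last_exc = None
        for _ in range(num_retries):
            candidates = list(deployments)
            random.shuffle(candidates)  # spread retries across deployments
            try:
                return fn(candidates[0])
            except Exception as exc:
                last_exc = exc
        raise last_exc
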
@@ -761,4 +763,6 @@ class Router:
return self.get_usage_based_available_deployment(model=model, messages=messages, input=input)
raise ValueError("No models available.")
+ def flush_cache(self):
+ self.cache.flush_cache()
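
The new flush_cache helper simply delegates to the router's DualCache, clearing the tracked cooldown/usage state; handy between test runs. Usage sketch (model_list assumed defined as in litellm's Router examples):

    router = Router(model_list=model_list)
    router.flush_cache()  # reset cooldown/usage tracking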