fix(main.py): misrouting ollama models to nlp cloud

Krrish Dholakia 2023-11-14 18:55:01 -08:00
parent 465f427465
commit 1738341dcb
5 changed files with 94 additions and 47 deletions


@@ -1,7 +1,7 @@
 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal
 import random, threading, time
-import litellm
+import litellm, openai
 import logging
 
 class Router:
@@ -37,7 +37,7 @@ class Router:
         self.healthy_deployments: List = self.model_list
 
         if num_retries:
-            litellm.num_retries = num_retries
+            self.num_retries = num_retries
 
         self.routing_strategy = routing_strategy
         ### HEALTH CHECK THREAD ###
@@ -131,6 +131,35 @@ class Router:
             return item or item[0]
         raise ValueError("No models available.")
 
+    def function_with_retries(self, *args, **kwargs):
+        try:
+            import tenacity
+        except Exception as e:
+            raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")
+
+        retry_info = {"attempts": 0, "final_result": None}
+
+        def after_callback(retry_state):
+            retry_info["attempts"] = retry_state.attempt_number
+            retry_info["final_result"] = retry_state.outcome.result()
+
+        try:
+            original_exception = kwargs.pop("original_exception")
+            original_function = kwargs.pop("original_function")
+            if isinstance(original_exception, openai.RateLimitError):
+                retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10),
+                                            stop=tenacity.stop_after_attempt(self.num_retries),
+                                            reraise=True,
+                                            after=after_callback)
+            elif isinstance(original_exception, openai.APIError):
+                retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(self.num_retries),
+                                            reraise=True,
+                                            after=after_callback)
+
+            return retryer(original_function, *args, **kwargs)
+        except Exception as e:
+            raise Exception(f"Error in function_with_retries: {e}\n\nRetry Info: {retry_info}")
+
     ### COMPLETION + EMBEDDING FUNCTIONS
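The new `function_with_retries` leans entirely on tenacity's `Retrying` object: exponential backoff for rate-limit errors, plain bounded retries for other API errors. A minimal standalone sketch of that pattern, using a hypothetical `flaky_call` in place of the router's `original_function`:

```python
import tenacity

attempts = {"count": 0}

def flaky_call():
    # Hypothetical stand-in for the router's original_function.
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ConnectionError("transient failure")
    return "ok"

def after_callback(retry_state):
    # tenacity invokes this after each failed attempt, before the next try.
    print(f"attempt {retry_state.attempt_number} failed")

retryer = tenacity.Retrying(
    wait=tenacity.wait_exponential(multiplier=1, max=10),  # back off between tries
    stop=tenacity.stop_after_attempt(3),                   # self.num_retries in the router
    reraise=True,                                          # re-raise the last error, not RetryError
    after=after_callback,
)
print(retryer(flaky_call))  # two failure lines, then "ok" on the third attempt
```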
@@ -148,9 +177,6 @@
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, messages=messages)
         data = deployment["litellm_params"]
-        # call via litellm.completion()
-        # return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
-        # litellm.set_verbose = True
         return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
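`completion` forwards the chosen deployment's `litellm_params` merged with the caller's kwargs; since later keys win in a `{**a, **b}` merge, call-time arguments override the deployment defaults. A small sketch of that merge with an illustrative deployment entry (the field values are assumptions, not from this commit):

```python
deployment = {
    "model_name": "gpt-3.5-turbo",       # alias the caller routes against
    "litellm_params": {                  # base params forwarded to litellm.completion
        "model": "azure/chatgpt-v-2",    # hypothetical deployment target
        "api_key": "sk-...",             # placeholder credential
    },
}

data = deployment["litellm_params"]
messages = [{"role": "user", "content": "hello"}]
kwargs = {"temperature": 0.2}

merged = {**data, "messages": messages, "caching": False, **kwargs}
# merged == {"model": "azure/chatgpt-v-2", "api_key": "sk-...",
#            "messages": [...], "caching": False, "temperature": 0.2}
```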
@@ -161,10 +187,17 @@
                           is_retry: Optional[bool] = False,
                           is_fallback: Optional[bool] = False,
                           **kwargs):
-        # pick the one that is available (lowest TPM/RPM)
-        deployment = self.get_available_deployment(model=model, messages=messages)
-        data = deployment["litellm_params"]
-        return await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
+        try:
+            deployment = self.get_available_deployment(model=model, messages=messages)
+            data = deployment["litellm_params"]
+            response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
+            return response
+        except Exception as e:
+            kwargs["model"] = model
+            kwargs["messages"] = messages
+            kwargs["original_exception"] = e
+            kwargs["original_function"] = self.acompletion
+            return self.function_with_retries(**kwargs)
 
     def text_completion(self,
                         model: str,
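End to end, the retry path now works like this: `acompletion` tries the lowest-usage available deployment, and on failure repacks the call (model, messages, the exception, and itself) into kwargs for `function_with_retries` to replay. A hedged usage sketch, with an illustrative single-deployment `model_list` and a placeholder key:

```python
import asyncio
from litellm import Router

router = Router(
    model_list=[{
        "model_name": "gpt-3.5-turbo",                        # illustrative entry
        "litellm_params": {"model": "gpt-3.5-turbo",
                           "api_key": "sk-..."},              # placeholder credential
    }],
    num_retries=2,  # stored as self.num_retries, read by function_with_retries
)

async def main():
    # A RateLimitError here triggers exponential-backoff retries;
    # other openai.APIErrors retry without backoff.
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response)

asyncio.run(main())
```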