fix(main.py): misrouting ollama models to nlp cloud

Krrish Dholakia 2023-11-14 18:55:01 -08:00
parent 465f427465
commit 1738341dcb
5 changed files with 94 additions and 47 deletions

View file

@@ -1,5 +1,5 @@
 from typing import Optional, Union
-import types
+import types, time
 import httpx
 from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object, Usage
@@ -160,6 +160,7 @@ class OpenAIChatCompletion(BaseLLM):
         super().__init__()
         self._client_session = self.create_client_session()
         self._aclient_session = self.create_aclient_session()
+        self._num_retry_httpx_errors = 3 # httpx throws random errors - e.g. ReadError,

     def validate_environment(self, api_key):
         headers = {
@@ -169,6 +170,15 @@ class OpenAIChatCompletion(BaseLLM):
             headers["Authorization"] = f"Bearer {api_key}"
         return headers

+    def _retry_request(self, *args, **kwargs):
+        self._num_retry_httpx_errors -= 1
+        time.sleep(1)
+        original_function = kwargs.pop("original_function")
+        return original_function(*args, **kwargs)
+
     def completion(self,
                    model_response: ModelResponse,
                    model: Optional[str]=None,
@@ -253,8 +263,9 @@ class OpenAIChatCompletion(BaseLLM):
                           api_base: str,
                           data: dict, headers: dict,
                           model_response: ModelResponse):
+        kwargs = locals()
         client = self._aclient_session
+        try:
             response = await client.post(api_base, json=data, headers=headers)
             response_json = response.json()
             if response.status_code != 200:
@@ -262,6 +273,17 @@ class OpenAIChatCompletion(BaseLLM):
             ## RESPONSE OBJECT
             return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+        except httpx.ReadError or httpx.ReadTimeout:
+            if self._num_retry_httpx_errors > 0:
+                kwargs["original_function"] = self.acompletion
+                return self._retry_request(**kwargs)
+            else:
+                raise e
+        except Exception as e:
+            if response and hasattr(response, "text"):
+                raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
+            else:
+                raise OpenAIError(status_code=500, message=f"{str(e)}")

     def streaming(self,
                   logging_obj,
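
The hunk above wires a crude retry path for transient httpx failures (ReadError, ReadTimeout) into the async OpenAI call. A minimal self-contained sketch of that retry pattern, assuming only an httpx.AsyncClient; the helper name post_with_retries and the fixed 1-second backoff are illustrative, not litellm code:

import asyncio
import httpx

async def post_with_retries(client: httpx.AsyncClient, url: str, data: dict,
                            headers: dict, num_retries: int = 3) -> dict:
    # Retry only on transient httpx read failures; anything else propagates.
    for attempt in range(num_retries + 1):
        try:
            response = await client.post(url, json=data, headers=headers)
            response.raise_for_status()
            return response.json()
        except (httpx.ReadError, httpx.ReadTimeout):
            if attempt == num_retries:
                raise  # out of retries, surface the original error
            await asyncio.sleep(1)  # simple fixed backoff between attempts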

View file

@@ -713,7 +713,7 @@ def completion(
             response = CustomStreamWrapper(model_response, model, custom_llm_provider="anthropic", logging_obj=logging)
             return response
         response = model_response
-    elif model in litellm.nlp_cloud_models or custom_llm_provider == "nlp_cloud":
+    elif custom_llm_provider == "nlp_cloud":
         nlp_cloud_key = (
             api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key
         )
@@ -744,7 +744,7 @@ def completion(
             response = CustomStreamWrapper(model_response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
             return response
         response = model_response
-    elif model in litellm.aleph_alpha_models:
+    elif custom_llm_provider == "aleph_alpha":
         aleph_alpha_key = (
             api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key
         )
@@ -909,7 +909,7 @@ def completion(
             )
             return response
         response = model_response
-    elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
+    elif custom_llm_provider == "openrouter":
         api_base = (
             api_base
             or litellm.api_base
@@ -969,28 +969,6 @@ def completion(
             logging_obj=logging,
             acompletion=acompletion
         )
-        # if headers:
-        #     response = openai.chat.completions.create(
-        #         headers=headers, # type: ignore
-        #         **data, # type: ignore
-        #     )
-        # else:
-        #     openrouter_site_url = get_secret("OR_SITE_URL")
-        #     openrouter_app_name = get_secret("OR_APP_NAME")
-        #     # if openrouter_site_url is None, set it to https://litellm.ai
-        #     if openrouter_site_url is None:
-        #         openrouter_site_url = "https://litellm.ai"
-        #     # if openrouter_app_name is None, set it to liteLLM
-        #     if openrouter_app_name is None:
-        #         openrouter_app_name = "liteLLM"
-        #     response = openai.chat.completions.create( # type: ignore
-        #         extra_headers=httpx.Headers({ # type: ignore
-        #             "HTTP-Referer": openrouter_site_url, # type: ignore
-        #             "X-Title": openrouter_app_name, # type: ignore
-        #         }), # type: ignore
-        #         **data,
-        #     )
         ## LOGGING
         logging.post_call(
             input=messages, api_key=openai.api_key, original_response=response
@@ -1093,7 +1071,7 @@ def completion(
             )
             return response
         response = model_response
-    elif model in litellm.ai21_models:
+    elif custom_llm_provider == "ai21":
         custom_llm_provider = "ai21"
         ai21_key = (
             api_key
@@ -1233,7 +1211,6 @@ def completion(
             )
         else:
             prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
         ## LOGGING
         if kwargs.get('acompletion', False) == True:
             if optional_params.get("stream", False) == True:
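
This main.py hunk is the fix named in the commit title: each provider branch now dispatches on custom_llm_provider instead of testing whether the model name happens to appear in that provider's model list, so a model meant for ollama can no longer fall into the nlp_cloud (or aleph_alpha, openrouter, ai21) branch by accident. A toy sketch of the difference, using a made-up model list rather than litellm's actual routing code:

from typing import Optional

nlp_cloud_models = {"dolphin", "chatdolphin"}  # hypothetical provider model list

def old_route(model: str, custom_llm_provider: Optional[str]) -> str:
    # membership in a provider's model list could shadow an explicit provider
    if model in nlp_cloud_models or custom_llm_provider == "nlp_cloud":
        return "nlp_cloud"
    return custom_llm_provider or "openai"

def new_route(model: str, custom_llm_provider: Optional[str]) -> str:
    # the provider resolved up front decides the branch
    if custom_llm_provider == "nlp_cloud":
        return "nlp_cloud"
    return custom_llm_provider or "openai"

print(old_route("dolphin", "ollama"))  # -> nlp_cloud (misrouted)
print(new_route("dolphin", "ollama"))  # -> ollama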

View file

@@ -113,7 +113,6 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
         print("\033[1;32mDone successfully\033[0m")
         return
     if model and "ollama" in model:
-        print(f"ollama called")
         run_ollama_serve()
     if test != False:
         click.echo('\nLiteLLM: Making a test ChatCompletions request to your proxy')

View file

@@ -1,7 +1,7 @@
 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal
 import random, threading, time
-import litellm
+import litellm, openai
 import logging

 class Router:
@@ -37,7 +37,7 @@ class Router:
         self.healthy_deployments: List = self.model_list
         if num_retries:
-            litellm.num_retries = num_retries
+            self.num_retries = num_retries
         self.routing_strategy = routing_strategy
         ### HEALTH CHECK THREAD ###
@@ -132,6 +132,35 @@ class Router:
         raise ValueError("No models available.")

+    def function_with_retries(self, *args, **kwargs):
+        try:
+            import tenacity
+        except Exception as e:
+            raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")
+
+        retry_info = {"attempts": 0, "final_result": None}
+
+        def after_callback(retry_state):
+            retry_info["attempts"] = retry_state.attempt_number
+            retry_info["final_result"] = retry_state.outcome.result()
+
+        try:
+            original_exception = kwargs.pop("original_exception")
+            original_function = kwargs.pop("original_function")
+            if isinstance(original_exception, openai.RateLimitError):
+                retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10),
+                                            stop=tenacity.stop_after_attempt(self.num_retries),
+                                            reraise=True,
+                                            after=after_callback)
+            elif isinstance(original_exception, openai.APIError):
+                retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(self.num_retries),
+                                            reraise=True,
+                                            after=after_callback)
+            return retryer(original_function, *args, **kwargs)
+        except Exception as e:
+            raise Exception(f"Error in function_with_retries: {e}\n\nRetry Info: {retry_info}")
+
     ### COMPLETION + EMBEDDING FUNCTIONS

     def completion(self,
@@ -148,9 +177,6 @@ class Router:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, messages=messages)
         data = deployment["litellm_params"]
-        # call via litellm.completion()
-        # return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
-        # litellm.set_verbose = True
         return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
@@ -161,10 +187,17 @@ class Router:
                           is_retry: Optional[bool] = False,
                           is_fallback: Optional[bool] = False,
                           **kwargs):
-        # pick the one that is available (lowest TPM/RPM)
+        try:
             deployment = self.get_available_deployment(model=model, messages=messages)
             data = deployment["litellm_params"]
-        return await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
+            response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
+            return response
+        except Exception as e:
+            kwargs["model"] = model
+            kwargs["messages"] = messages
+            kwargs["original_exception"] = e
+            kwargs["original_function"] = self.acompletion
+            return self.function_with_retries(**kwargs)

     def text_completion(self,
                         model: str,
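
Router.acompletion now falls back to function_with_retries, which wraps the original call in a tenacity Retrying object: exponential backoff for openai.RateLimitError, immediate retries for openai.APIError. The tenacity mechanics in isolation, with a made-up flaky fetch function standing in for a rate-limited API call:

import tenacity

calls = {"n": 0}

def fetch() -> str:
    # Fails twice, then succeeds; stands in for a rate-limited API call.
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("try again")
    return "ok"

retryer = tenacity.Retrying(
    wait=tenacity.wait_exponential(multiplier=1, max=10),  # wait 1s, 2s, 4s... capped at 10s
    stop=tenacity.stop_after_attempt(3),                   # give up after 3 attempts
    reraise=True,                                          # surface the original exception on failure
)
print(retryer(fetch))  # succeeds on the third attempt -> "ok"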

View file

@@ -25,6 +25,22 @@ async def main():
             "model": "gpt-3.5-turbo",
             "api_key": os.getenv("OPENAI_API_KEY"),
         },
+    }, {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "azure/chatgpt-v-2",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION")
+        },
+    }, {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "azure/chatgpt-functioncalling",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION")
+        },
     }]

     router = Router(model_list=model_list, num_retries=3)
@@ -35,13 +51,13 @@
     tasks = []
     # Launch 1000 tasks
-    for _ in range(1000):
+    for _ in range(100):
         task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
         tasks.append(task)

     # Wait for all tasks to complete
     responses = await asyncio.gather(*tasks)
     # Process responses as needed
+    print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")

 # Run the main function
 asyncio.run(main())
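
The load test drives Router.acompletion through a call_acompletion helper that is not part of this diff. Assuming it only bounds concurrency with the shared semaphore, a plausible shape for it (illustrative, not the file's actual helper):

import asyncio

async def call_acompletion(semaphore: asyncio.Semaphore, router, data: dict):
    # The semaphore caps how many acompletion calls are in flight at once.
    async with semaphore:
        return await router.acompletion(**data)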