feat(router.py): enable passing chat completion params for Router.chat.completion.create

Krrish Dholakia 2023-11-15 12:28:09 -08:00
parent 95f9c6779d
commit 03303033e5
2 changed files with 58 additions and 31 deletions

router.py

@@ -34,6 +34,7 @@ class Router:
                  cache_responses: bool = False,
                  num_retries: Optional[int] = None,
                  timeout: float = 600,
+                 chat_completion_params = {}, # default params for Router.chat.completion.create
                  routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
 
         if model_list:
@@ -43,6 +44,8 @@
         if num_retries:
             self.num_retries = num_retries
 
+        self.chat = litellm.Chat(params=chat_completion_params)
+
         litellm.request_timeout = timeout
         self.routing_strategy = routing_strategy
         ### HEALTH CHECK THREAD ###
@@ -65,7 +68,6 @@
 
             litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
         self.cache_responses = cache_responses
 
-        self.chat = litellm.Chat(params={})
 
     def _start_health_check_thread(self):
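
The router change moves self.chat construction into __init__, seeded with caller-supplied defaults instead of a hard-coded empty dict. A minimal usage sketch (the model-list entry and env vars mirror the test below; only the chat_completion_params argument is new API in this commit):

import os
from litellm import Router

# chat_completion_params becomes the default params for router.chat.completions.create;
# {"acompletion": True} requests the async completion path, as the test below exercises.
router = Router(
    model_list=[{
        "model_name": "gpt-3.5-turbo",  # model group name
        "litellm_params": {             # params for litellm completion/embedding call
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    }],
    chat_completion_params={"acompletion": True},
)
# router.chat is now litellm.Chat(params={"acompletion": True})

The second changed file, the LiteLLM Class test, is reworked to exercise this path through instructor: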

@@ -1,26 +1,54 @@
-#### What this tests ####
-# This tests the LiteLLM Class
+# #### What this tests ####
+# # This tests the LiteLLM Class
 
-import sys, os
-import traceback
-import pytest
-sys.path.insert(
-    0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
-import litellm
+# import sys, os
+# import traceback
+# import pytest
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# )  # Adds the parent directory to the system path
+# import litellm
+# import asyncio
 
-mr1 = litellm.ModelResponse(stream=True, model="gpt-3.5-turbo")
-mr1.choices[0].finish_reason = "stop"
-mr2 = litellm.ModelResponse(stream=True, model="gpt-3.5-turbo")
-print(mr2.choices[0].finish_reason)
-
 # litellm.set_verbose = True
 # from litellm import Router
 # import instructor
 # from pydantic import BaseModel
+
 # # This enables response_model keyword
-# # from client.chat.completions.create
-# client = instructor.patch(Router(model_list=[{
+# # # from client.chat.completions.create
+# # client = instructor.patch(Router(model_list=[{
+# #     "model_name": "gpt-3.5-turbo", # openai model name
+# #     "litellm_params": { # params for litellm completion/embedding call
+# #         "model": "azure/chatgpt-v-2",
+# #         "api_key": os.getenv("AZURE_API_KEY"),
+# #         "api_version": os.getenv("AZURE_API_VERSION"),
+# #         "api_base": os.getenv("AZURE_API_BASE")
+# #     }
+# # }]))
+
+# # class UserDetail(BaseModel):
+# #     name: str
+# #     age: int
+
+# # user = client.chat.completions.create(
+# #     model="gpt-3.5-turbo",
+# #     response_model=UserDetail,
+# #     messages=[
+# #         {"role": "user", "content": "Extract Jason is 25 years old"},
+# #     ]
+# # )
+# # assert isinstance(model, UserExtract)
+
+# # assert isinstance(user, UserDetail)
+# # assert user.name == "Jason"
+# # assert user.age == 25
+# # print(f"user: {user}")
+
+# import instructor
+# from openai import AsyncOpenAI
+
+# aclient = instructor.apatch(Router(model_list=[{
 #     "model_name": "gpt-3.5-turbo", # openai model name
 #     "litellm_params": { # params for litellm completion/embedding call
 #         "model": "azure/chatgpt-v-2",
@@ -28,22 +56,19 @@ print(mr2.choices[0].finish_reason)
 #         "api_version": os.getenv("AZURE_API_VERSION"),
 #         "api_base": os.getenv("AZURE_API_BASE")
 #     }
-# }]))
+# }], chat_completion_params={"acompletion": True}))
 
-# class UserDetail(BaseModel):
+# class UserExtract(BaseModel):
 #     name: str
 #     age: int
 
-# user = client.chat.completions.create(
-#     model="gpt-3.5-turbo",
-#     response_model=UserDetail,
-#     messages=[
-#         {"role": "user", "content": "Extract Jason is 25 years old"},
-#     ]
-# )
-
-# assert isinstance(user, UserDetail)
-# assert user.name == "Jason"
-# assert user.age == 25
-# print(f"user: {user}")
-
+# async def main():
+#     model = await aclient.chat.completions.create(
+#         model="gpt-3.5-turbo",
+#         response_model=UserExtract,
+#         messages=[
+#             {"role": "user", "content": "Extract jason is 25 years old"},
+#         ],
+#     )
+#     print(f"model: {model}")
+# asyncio.run(main())