diff --git a/litellm/router.py b/litellm/router.py
index f2e42068d..ce66028b5 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -34,6 +34,7 @@ class Router:
                  cache_responses: bool = False,
                  num_retries: Optional[int] = None,
                  timeout: float = 600,
+                 chat_completion_params = {}, # default params for Router.chat.completion.create
                  routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
 
         if model_list:
@@ -42,6 +43,8 @@ class Router:
 
         if num_retries:
             self.num_retries = num_retries
+
+        self.chat = litellm.Chat(params=chat_completion_params)
 
         litellm.request_timeout = timeout
         self.routing_strategy = routing_strategy
@@ -65,7 +68,6 @@ class Router:
             litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
         self.cache_responses = cache_responses
-        self.chat = litellm.Chat(params={})
 
     def _start_health_check_thread(self):
diff --git a/litellm/tests/test_class.py b/litellm/tests/test_class.py
index 909c8e939..6a814155f 100644
--- a/litellm/tests/test_class.py
+++ b/litellm/tests/test_class.py
@@ -1,26 +1,54 @@
-#### What this tests ####
-# This tests the LiteLLM Class
+# #### What this tests ####
+# # This tests the LiteLLM Class
 
-import sys, os
-import traceback
-import pytest
-sys.path.insert(
-    0, os.path.abspath("../..")
-) # Adds the parent directory to the system path
-import litellm
+# import sys, os
+# import traceback
+# import pytest
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# ) # Adds the parent directory to the system path
+# import litellm
+# import asyncio
 
-mr1 = litellm.ModelResponse(stream=True, model="gpt-3.5-turbo")
-mr1.choices[0].finish_reason = "stop"
-mr2 = litellm.ModelResponse(stream=True, model="gpt-3.5-turbo")
-print(mr2.choices[0].finish_reason)
 # litellm.set_verbose = True
 # from litellm import Router
 # import instructor
 # from pydantic import BaseModel
 
 # # This enables response_model keyword
-# # from client.chat.completions.create
-# client = instructor.patch(Router(model_list=[{
+# # # from client.chat.completions.create
+# # client = instructor.patch(Router(model_list=[{
+# #     "model_name": "gpt-3.5-turbo", # openai model name
+# #     "litellm_params": { # params for litellm completion/embedding call
+# #         "model": "azure/chatgpt-v-2",
+# #         "api_key": os.getenv("AZURE_API_KEY"),
+# #         "api_version": os.getenv("AZURE_API_VERSION"),
+# #         "api_base": os.getenv("AZURE_API_BASE")
+# #     }
+# # }]))
+
+# # class UserDetail(BaseModel):
+# #     name: str
+# #     age: int
+
+# # user = client.chat.completions.create(
+# #     model="gpt-3.5-turbo",
+# #     response_model=UserDetail,
+# #     messages=[
+# #         {"role": "user", "content": "Extract Jason is 25 years old"},
+# #     ]
+# # )
+# # assert isinstance(model, UserExtract)
+
+# # assert isinstance(user, UserDetail)
+# # assert user.name == "Jason"
+# # assert user.age == 25
+
+# # print(f"user: {user}")
+# import instructor
+# from openai import AsyncOpenAI
+
+# aclient = instructor.apatch(Router(model_list=[{
 #     "model_name": "gpt-3.5-turbo", # openai model name
 #     "litellm_params": { # params for litellm completion/embedding call
 #         "model": "azure/chatgpt-v-2",
@@ -28,22 +56,19 @@ print(mr2.choices[0].finish_reason)
 #         "api_version": os.getenv("AZURE_API_VERSION"),
 #         "api_base": os.getenv("AZURE_API_BASE")
 #     }
-# }]))
+# }], chat_completion_params={"acompletion": True}))
 
-# class UserDetail(BaseModel):
+# class UserExtract(BaseModel):
 #     name: str
 #     age: int
 
+# async def main():
+#     model = await aclient.chat.completions.create(
+#         model="gpt-3.5-turbo",
+#         response_model=UserExtract,
+#         messages=[
+#             {"role": "user", "content": "Extract jason is 25 years old"},
+#         ],
+#     )
+#     print(f"model: {model}")
-# user = client.chat.completions.create(
-#     model="gpt-3.5-turbo",
-#     response_model=UserDetail,
-#     messages=[
-#         {"role": "user", "content": "Extract Jason is 25 years old"},
-#     ]
-# )
-
-# assert isinstance(user, UserDetail)
-# assert user.name == "Jason"
-# assert user.age == 25
-
-# print(f"user: {user}")
\ No newline at end of file
+# asyncio.run(main())
\ No newline at end of file
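
For reference, a minimal usage sketch of the new chat_completion_params kwarg (hypothetical, not part of the patch; it reuses the Azure deployment from the commented-out test and assumes litellm.Chat exposes chat.completions.create the way that test implies):

    import os
    from litellm import Router

    router = Router(
        model_list=[{
            "model_name": "gpt-3.5-turbo",  # alias callers pass as `model`
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
        }],
        # forwarded to litellm.Chat as default params for every
        # router.chat.completions.create call (here: request async completion)
        chat_completion_params={"acompletion": True},
    )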