diff --git a/litellm/main.py b/litellm/main.py
index 817dc5510..3e875815e 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -116,24 +116,54 @@ class LiteLLM:
         default_headers: Optional[Mapping[str, str]] = None,
     ):
         self.params = locals()
-        self.chat = Chat(self.params)
+        self.chat = Chat(self.params, router_obj=None)


 class Chat:
-    def __init__(self, params):
+    def __init__(self, params, router_obj: Optional[Any]):
         self.params = params
-        self.completions = Completions(self.params)
+        if self.params.get("acompletion", False) == True:
+            self.params.pop("acompletion")
+            self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
+                self.params, router_obj=router_obj
+            )
+        else:
+            self.completions = Completions(self.params, router_obj=router_obj)


 class Completions:
-    def __init__(self, params):
+    def __init__(self, params, router_obj: Optional[Any]):
         self.params = params
+        self.router_obj = router_obj

     def create(self, messages, model=None, **kwargs):
         for k, v in kwargs.items():
             self.params[k] = v
         model = model or self.params.get("model")
-        response = completion(model=model, messages=messages, **self.params)
+        if self.router_obj is not None:
+            response = self.router_obj.completion(
+                model=model, messages=messages, **self.params
+            )
+        else:
+            response = completion(model=model, messages=messages, **self.params)
+        return response
+
+
+class AsyncCompletions:
+    def __init__(self, params, router_obj: Optional[Any]):
+        self.params = params
+        self.router_obj = router_obj
+
+    async def create(self, messages, model=None, **kwargs):
+        for k, v in kwargs.items():
+            self.params[k] = v
+        model = model or self.params.get("model")
+        if self.router_obj is not None:
+            response = await self.router_obj.acompletion(
+                model=model, messages=messages, **self.params
+            )
+        else:
+            response = await acompletion(model=model, messages=messages, **self.params)
         return response


diff --git a/litellm/router.py b/litellm/router.py
index b39b67a09..7bcaf7faf 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -230,7 +230,7 @@ class Router:
         )  # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group

         # make Router.chat.completions.create compatible for openai.chat.completions.create
-        self.chat = litellm.Chat(params=default_litellm_params)
+        self.chat = litellm.Chat(params=default_litellm_params, router_obj=self)

         # default litellm args
         self.default_litellm_params = default_litellm_params
diff --git a/litellm/tests/test_class.py b/litellm/tests/test_class.py
index 3520d870d..7f1fc9065 100644
--- a/litellm/tests/test_class.py
+++ b/litellm/tests/test_class.py
@@ -4,6 +4,7 @@
 # import sys, os
 # import traceback
 # import pytest
+
 # sys.path.insert(
 #     0, os.path.abspath("../..")
 # )  # Adds the parent directory to the system path
@@ -16,51 +17,68 @@
 # from pydantic import BaseModel

 # # This enables response_model keyword
-# # # from client.chat.completions.create
-# # client = instructor.patch(Router(model_list=[{
-# #     "model_name": "gpt-3.5-turbo", # openai model name
-# #     "litellm_params": { # params for litellm completion/embedding call
-# #         "model": "azure/chatgpt-v-2",
-# #         "api_key": os.getenv("AZURE_API_KEY"),
-# #         "api_version": os.getenv("AZURE_API_VERSION"),
-# #         "api_base": os.getenv("AZURE_API_BASE")
-# #     }
-# # }]))
+# # from client.chat.completions.create
+# client = instructor.patch(
+#     Router(
+#         model_list=[
+#             {
+#                 "model_name": "gpt-3.5-turbo",  # openai model name
+#                 "litellm_params": {  # params for litellm completion/embedding call
+#                     "model": "azure/chatgpt-v-2",
+#                     "api_key": os.getenv("AZURE_API_KEY"),
+#                     "api_version": os.getenv("AZURE_API_VERSION"),
+#                     "api_base": os.getenv("AZURE_API_BASE"),
+#                 },
+#             }
+#         ]
+#     )
+# )

-# # class UserDetail(BaseModel):
-# #     name: str
-# #     age: int

-# # user = client.chat.completions.create(
-# #     model="gpt-3.5-turbo",
-# #     response_model=UserDetail,
-# #     messages=[
-# #         {"role": "user", "content": "Extract Jason is 25 years old"},
-# #     ]
-# # )
-# # assert isinstance(model, UserExtract)
+# class UserDetail(BaseModel):
+#     name: str
+#     age: int

-# # assert isinstance(user, UserDetail)
-# # assert user.name == "Jason"
-# # assert user.age == 25
-# # print(f"user: {user}")
-# import instructor
-# from openai import AsyncOpenAI
+# user = client.chat.completions.create(
+#     model="gpt-3.5-turbo",
+#     response_model=UserDetail,
+#     messages=[
+#         {"role": "user", "content": "Extract Jason is 25 years old"},
+#     ],
+# )
+
+# assert isinstance(user, UserDetail)
+# assert user.name == "Jason"
+# assert user.age == 25
+
+# print(f"user: {user}")
+# # import instructor
+# # from openai import AsyncOpenAI
+
+# aclient = instructor.apatch(
+#     Router(
+#         model_list=[
+#             {
+#                 "model_name": "gpt-3.5-turbo",  # openai model name
+#                 "litellm_params": {  # params for litellm completion/embedding call
+#                     "model": "azure/chatgpt-v-2",
+#                     "api_key": os.getenv("AZURE_API_KEY"),
+#                     "api_version": os.getenv("AZURE_API_VERSION"),
+#                     "api_base": os.getenv("AZURE_API_BASE"),
+#                 },
+#             }
+#         ],
+#         default_litellm_params={"acompletion": True},
+#     )
+# )

-# aclient = instructor.apatch(Router(model_list=[{
-#     "model_name": "gpt-3.5-turbo", # openai model name
-#     "litellm_params": { # params for litellm completion/embedding call
-#         "model": "azure/chatgpt-v-2",
-#         "api_key": os.getenv("AZURE_API_KEY"),
-#         "api_version": os.getenv("AZURE_API_VERSION"),
-#         "api_base": os.getenv("AZURE_API_BASE")
-#     }
-# }], default_litellm_params={"acompletion": True}))


 # class UserExtract(BaseModel):
 #     name: str
 #     age: int
+
+
 # async def main():
 #     model = await aclient.chat.completions.create(
 #         model="gpt-3.5-turbo",
@@ -71,4 +89,5 @@
 #     )
 #     print(f"model: {model}")

+
 # asyncio.run(main())