mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(main.py): support router.chat.completions.create
allows using router with instructor https://github.com/BerriAI/litellm/issues/2673
This commit is contained in:
parent
9e9de7f6e2
commit
f98aead602
3 changed files with 91 additions and 42 deletions
|
@ -116,27 +116,57 @@ class LiteLLM:
|
|||
default_headers: Optional[Mapping[str, str]] = None,
|
||||
):
|
||||
self.params = locals()
|
||||
self.chat = Chat(self.params)
|
||||
self.chat = Chat(self.params, router_obj=None)
|
||||
|
||||
|
||||
class Chat:
    """OpenAI-style ``chat`` namespace that exposes a ``completions`` attribute.

    Depending on the ``"acompletion"`` entry in ``params``, ``completions`` is
    either an ``AsyncCompletions`` (awaitable ``create``) or a ``Completions``
    (blocking ``create``) instance.
    """

    def __init__(self, params, router_obj: Optional[Any] = None):
        """Build the ``completions`` handler.

        Args:
            params: Dict of default call parameters. It is stored by reference
                and handed to the completions handler as-is.
            router_obj: Optional object the completions handler delegates to
                instead of the module-level completion functions. Defaults to
                ``None`` so the earlier one-argument form ``Chat(params)``
                keeps working.
        """
        self.params = params
        # The flag is removed from the caller's dict in place (not from a copy),
        # so it is absent from the params the handler later spreads into calls.
        # The `== True` comparison is kept deliberately: only values equal to
        # True select the async handler, exactly as before.
        if self.params.get("acompletion", False) == True:  # noqa: E712
            self.params.pop("acompletion")
            self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
                self.params, router_obj=router_obj
            )
        else:
            self.completions = Completions(self.params, router_obj=router_obj)
|
||||
|
||||
|
||||
class Completions:
    """Synchronous ``completions`` namespace with an OpenAI-style ``create``."""

    def __init__(self, params, router_obj: Optional[Any] = None):
        """Store the default call parameters and the optional delegate.

        Args:
            params: Dict of default keyword arguments applied to every call.
            router_obj: Optional object whose ``completion`` method is used
                instead of the module-level ``completion`` function. Defaults
                to ``None`` so ``Completions(params)`` remains valid.
        """
        self.params = params
        self.router_obj = router_obj

    def create(self, messages, model=None, **kwargs):
        """Run one completion call and return its response.

        Args:
            messages: Chat messages forwarded unchanged.
            model: Model name; falls back to the ``"model"`` default in
                ``params`` when falsy.
            **kwargs: Per-call overrides layered on top of the defaults.

        Returns:
            Whatever ``router_obj.completion`` (when a router object is set)
            or ``completion`` returns.
        """
        # Merge into a fresh dict: per-call overrides must not leak into the
        # shared defaults (which may be the same dict object another owner
        # holds) and therefore must not affect later calls.
        call_params = {**self.params, **kwargs}
        # `model` and `messages` are passed explicitly below; leaving them in
        # the spread dict would raise "got multiple values for keyword argument".
        default_model = call_params.pop("model", None)
        call_params.pop("messages", None)
        model = model or default_model
        if self.router_obj is not None:
            return self.router_obj.completion(
                model=model, messages=messages, **call_params
            )
        return completion(model=model, messages=messages, **call_params)
|
||||
|
||||
|
||||
class AsyncCompletions:
    """Asynchronous ``completions`` namespace with an awaitable ``create``."""

    def __init__(self, params, router_obj: Optional[Any] = None):
        """Store the default call parameters and the optional delegate.

        Args:
            params: Dict of default keyword arguments applied to every call.
            router_obj: Optional object whose ``acompletion`` coroutine is
                awaited instead of the module-level ``acompletion``. Defaults
                to ``None`` so ``AsyncCompletions(params)`` remains valid.
        """
        self.params = params
        self.router_obj = router_obj

    async def create(self, messages, model=None, **kwargs):
        """Await one completion call and return its response.

        Args:
            messages: Chat messages forwarded unchanged.
            model: Model name; falls back to the ``"model"`` default in
                ``params`` when falsy.
            **kwargs: Per-call overrides layered on top of the defaults.

        Returns:
            Whatever ``router_obj.acompletion`` (when a router object is set)
            or ``acompletion`` resolves to.
        """
        # Merge into a fresh dict: per-call overrides must not leak into the
        # shared defaults and therefore must not affect later calls.
        call_params = {**self.params, **kwargs}
        # `model` and `messages` are passed explicitly below; leaving them in
        # the spread dict would raise "got multiple values for keyword argument".
        default_model = call_params.pop("model", None)
        call_params.pop("messages", None)
        model = model or default_model
        if self.router_obj is not None:
            return await self.router_obj.acompletion(
                model=model, messages=messages, **call_params
            )
        return await acompletion(model=model, messages=messages, **call_params)
|
||||
|
||||
|
||||
@client
|
||||
async def acompletion(
|
||||
model: str,
|
||||
|
|
|
@ -230,7 +230,7 @@ class Router:
|
|||
) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
|
||||
|
||||
# make Router.chat.completions.create compatible with openai.chat.completions.create
|
||||
self.chat = litellm.Chat(params=default_litellm_params)
|
||||
self.chat = litellm.Chat(params=default_litellm_params, router_obj=self)
|
||||
|
||||
# default litellm args
|
||||
self.default_litellm_params = default_litellm_params
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# import sys, os
|
||||
# import traceback
|
||||
# import pytest
|
||||
|
||||
# sys.path.insert(
|
||||
# 0, os.path.abspath("../..")
|
||||
# ) # Adds the parent directory to the system path
|
||||
|
@ -16,51 +17,68 @@
|
|||
# from pydantic import BaseModel
|
||||
|
||||
# # This enables response_model keyword
|
||||
# # # from client.chat.completions.create
|
||||
# # client = instructor.patch(Router(model_list=[{
|
||||
# # "model_name": "gpt-3.5-turbo", # openai model name
|
||||
# # "litellm_params": { # params for litellm completion/embedding call
|
||||
# # "model": "azure/chatgpt-v-2",
|
||||
# # "api_key": os.getenv("AZURE_API_KEY"),
|
||||
# # "api_version": os.getenv("AZURE_API_VERSION"),
|
||||
# # "api_base": os.getenv("AZURE_API_BASE")
|
||||
# # }
|
||||
# # }]))
|
||||
|
||||
# # class UserDetail(BaseModel):
|
||||
# # name: str
|
||||
# # age: int
|
||||
|
||||
# # user = client.chat.completions.create(
|
||||
# # model="gpt-3.5-turbo",
|
||||
# # response_model=UserDetail,
|
||||
# # messages=[
|
||||
# # {"role": "user", "content": "Extract Jason is 25 years old"},
|
||||
# # ]
|
||||
# # )
|
||||
# # assert isinstance(model, UserExtract)
|
||||
|
||||
# # assert isinstance(user, UserDetail)
|
||||
# # assert user.name == "Jason"
|
||||
# # assert user.age == 25
|
||||
|
||||
# # print(f"user: {user}")
|
||||
# import instructor
|
||||
# from openai import AsyncOpenAI
|
||||
|
||||
# aclient = instructor.apatch(Router(model_list=[{
|
||||
# # from client.chat.completions.create
|
||||
# client = instructor.patch(
|
||||
# Router(
|
||||
# model_list=[
|
||||
# {
|
||||
# "model_name": "gpt-3.5-turbo", # openai model name
|
||||
# "litellm_params": { # params for litellm completion/embedding call
|
||||
# "model": "azure/chatgpt-v-2",
|
||||
# "api_key": os.getenv("AZURE_API_KEY"),
|
||||
# "api_version": os.getenv("AZURE_API_VERSION"),
|
||||
# "api_base": os.getenv("AZURE_API_BASE")
|
||||
# "api_base": os.getenv("AZURE_API_BASE"),
|
||||
# },
|
||||
# }
|
||||
# }], default_litellm_params={"acompletion": True}))
|
||||
# ]
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
# class UserDetail(BaseModel):
|
||||
# name: str
|
||||
# age: int
|
||||
|
||||
|
||||
# user = client.chat.completions.create(
|
||||
# model="gpt-3.5-turbo",
|
||||
# response_model=UserDetail,
|
||||
# messages=[
|
||||
# {"role": "user", "content": "Extract Jason is 25 years old"},
|
||||
# ],
|
||||
# )
|
||||
|
||||
# assert isinstance(user, UserDetail)
|
||||
# assert user.name == "Jason"
|
||||
# assert user.age == 25
|
||||
|
||||
# print(f"user: {user}")
|
||||
# # import instructor
|
||||
# # from openai import AsyncOpenAI
|
||||
|
||||
# aclient = instructor.apatch(
|
||||
# Router(
|
||||
# model_list=[
|
||||
# {
|
||||
# "model_name": "gpt-3.5-turbo", # openai model name
|
||||
# "litellm_params": { # params for litellm completion/embedding call
|
||||
# "model": "azure/chatgpt-v-2",
|
||||
# "api_key": os.getenv("AZURE_API_KEY"),
|
||||
# "api_version": os.getenv("AZURE_API_VERSION"),
|
||||
# "api_base": os.getenv("AZURE_API_BASE"),
|
||||
# },
|
||||
# }
|
||||
# ],
|
||||
# default_litellm_params={"acompletion": True},
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
# class UserExtract(BaseModel):
|
||||
# name: str
|
||||
# age: int
|
||||
|
||||
|
||||
# async def main():
|
||||
# model = await aclient.chat.completions.create(
|
||||
# model="gpt-3.5-turbo",
|
||||
|
@ -71,4 +89,5 @@
|
|||
# )
|
||||
# print(f"model: {model}")
|
||||
|
||||
|
||||
# asyncio.run(main())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue