forked from phoenix/litellm-mirror
feat(router.py): enable passing chat completion params for Router.chat.completion.create
commit 03303033e5
parent 95f9c6779d
2 changed files with 58 additions and 31 deletions
router.py:

@@ -34,6 +34,7 @@ class Router:
                  cache_responses: bool = False,
                  num_retries: Optional[int] = None,
                  timeout: float = 600,
+                 chat_completion_params = {}, # default params for Router.chat.completion.create
                  routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:

         if model_list:
@@ -43,6 +44,8 @@ class Router:
         if num_retries:
             self.num_retries = num_retries

+        self.chat = litellm.Chat(params=chat_completion_params)
+
         litellm.request_timeout = timeout
         self.routing_strategy = routing_strategy
         ### HEALTH CHECK THREAD ###
@@ -65,7 +68,6 @@ class Router:
             litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
             self.cache_responses = cache_responses

-        self.chat = litellm.Chat(params={})


     def _start_health_check_thread(self):
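In the same file, the previously hard-coded self.chat = litellm.Chat(params={}) is dropped, so the new constructor argument is the only source of these defaults. For context, a rough sketch of how the argument is meant to be used; this is illustrative only (the model_list entry and the temperature default are placeholders, and the committed test only exercises {"acompletion": True}):

# Sketch: passing default params for Router.chat.completions.create via the constructor.
# The model_list entry and the temperature default are placeholders, not part of this commit.
from litellm import Router

router = Router(
    model_list=[{
        "model_name": "gpt-3.5-turbo",                 # alias callers request
        "litellm_params": {"model": "gpt-3.5-turbo"},  # params for the underlying completion call
    }],
    chat_completion_params={"temperature": 0},  # default params for Router.chat.completion.create
)

response = router.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response)

The second changed file is the test for the LiteLLM class: the synchronous instructor example is commented out one level deeper, and an async variant that passes chat_completion_params={"acompletion": True} is sketched in its place: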
@@ -1,26 +1,54 @@
-#### What this tests ####
-# This tests the LiteLLM Class
+# #### What this tests ####
+# # This tests the LiteLLM Class

-import sys, os
-import traceback
-import pytest
-sys.path.insert(
-    0, os.path.abspath("../..")
-) # Adds the parent directory to the system path
-import litellm
+# import sys, os
+# import traceback
+# import pytest
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# ) # Adds the parent directory to the system path
+# import litellm
+# import asyncio

-mr1 = litellm.ModelResponse(stream=True, model="gpt-3.5-turbo")
-mr1.choices[0].finish_reason = "stop"
-mr2 = litellm.ModelResponse(stream=True, model="gpt-3.5-turbo")
-print(mr2.choices[0].finish_reason)
 # litellm.set_verbose = True
 # from litellm import Router
 # import instructor
 # from pydantic import BaseModel

 # # This enables response_model keyword
-# # from client.chat.completions.create
-# client = instructor.patch(Router(model_list=[{
+# # # from client.chat.completions.create
+# # client = instructor.patch(Router(model_list=[{
+# #     "model_name": "gpt-3.5-turbo", # openai model name
+# #     "litellm_params": { # params for litellm completion/embedding call
+# #         "model": "azure/chatgpt-v-2",
+# #         "api_key": os.getenv("AZURE_API_KEY"),
+# #         "api_version": os.getenv("AZURE_API_VERSION"),
+# #         "api_base": os.getenv("AZURE_API_BASE")
+# #     }
+# # }]))
+
+# # class UserDetail(BaseModel):
+# #     name: str
+# #     age: int
+
+# # user = client.chat.completions.create(
+# #     model="gpt-3.5-turbo",
+# #     response_model=UserDetail,
+# #     messages=[
+# #         {"role": "user", "content": "Extract Jason is 25 years old"},
+# #     ]
+# # )
+# # assert isinstance(model, UserExtract)
+
+# # assert isinstance(user, UserDetail)
+# # assert user.name == "Jason"
+# # assert user.age == 25
+
+# # print(f"user: {user}")
+# import instructor
+# from openai import AsyncOpenAI
+
+# aclient = instructor.apatch(Router(model_list=[{
 #     "model_name": "gpt-3.5-turbo", # openai model name
 #     "litellm_params": { # params for litellm completion/embedding call
 #         "model": "azure/chatgpt-v-2",
@@ -28,22 +56,19 @@ print(mr2.choices[0].finish_reason)
 #         "api_version": os.getenv("AZURE_API_VERSION"),
 #         "api_base": os.getenv("AZURE_API_BASE")
 #     }
-# }]))
+# }], chat_completion_params={"acompletion": True}))

-# class UserDetail(BaseModel):
+# class UserExtract(BaseModel):
 #     name: str
 #     age: int
-# user = client.chat.completions.create(
+# async def main():
+#     model = await aclient.chat.completions.create(
 #     model="gpt-3.5-turbo",
-#     response_model=UserDetail,
+#     response_model=UserExtract,
 #     messages=[
-#         {"role": "user", "content": "Extract Jason is 25 years old"},
-#     ]
+#         {"role": "user", "content": "Extract jason is 25 years old"},
+#     ],
 # )
+#     print(f"model: {model}")

-# assert isinstance(user, UserDetail)
-# assert user.name == "Jason"
-# assert user.age == 25
-
-# print(f"user: {user}")
+# asyncio.run(main())
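Assembled and uncommented, the async flow the updated test sketches would look roughly like this. It is a sketch under the same assumptions as the test: the instructor package and its apatch helper, the Azure deployment env vars, and {"acompletion": True} so the patched create call can be awaited.

# Sketch assembled from the commented-out test above; not part of the committed diff.
import asyncio
import os

import instructor
from pydantic import BaseModel
from litellm import Router

# apatch enables the response_model keyword on the async create call
aclient = instructor.apatch(Router(
    model_list=[{
        "model_name": "gpt-3.5-turbo",  # openai model name
        "litellm_params": {             # params for litellm completion/embedding call
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    }],
    chat_completion_params={"acompletion": True},  # make Router.chat calls async
))

class UserExtract(BaseModel):
    name: str
    age: int

async def main():
    # instructor parses the reply into the UserExtract model
    model = await aclient.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserExtract,
        messages=[{"role": "user", "content": "Extract jason is 25 years old"}],
    )
    print(f"model: {model}")

asyncio.run(main())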