forked from phoenix/litellm-mirror
fix(router.py): enable additional params to be passe din
This commit is contained in:
parent
c4b550cfda
commit
05740fed9d
2 changed files with 98 additions and 43 deletions
|
@ -61,6 +61,8 @@ class Router:
|
||||||
|
|
||||||
data = deployment["litellm_params"]
|
data = deployment["litellm_params"]
|
||||||
data["messages"] = messages
|
data["messages"] = messages
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
data[key] = value
|
||||||
# call via litellm.completion()
|
# call via litellm.completion()
|
||||||
return litellm.completion(**data)
|
return litellm.completion(**data)
|
||||||
|
|
||||||
|
@ -78,6 +80,8 @@ class Router:
|
||||||
|
|
||||||
data = deployment["litellm_params"]
|
data = deployment["litellm_params"]
|
||||||
data["prompt"] = prompt
|
data["prompt"] = prompt
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
data[key] = value
|
||||||
# call via litellm.completion()
|
# call via litellm.completion()
|
||||||
return litellm.text_completion(**data)
|
return litellm.text_completion(**data)
|
||||||
|
|
||||||
|
@ -203,7 +207,10 @@ class Router:
|
||||||
# get value
|
# get value
|
||||||
cached_value = self.cache.get_cache(key)
|
cached_value = self.cache.get_cache(key)
|
||||||
# update value
|
# update value
|
||||||
|
try:
|
||||||
cached_value = cached_value + increment_value
|
cached_value = cached_value + increment_value
|
||||||
|
except:
|
||||||
|
cached_value = increment_value
|
||||||
# save updated value
|
# save updated value
|
||||||
self.cache.add_cache(result=cached_value, cache_key=key)
|
self.cache.add_cache(result=cached_value, cache_key=key)
|
||||||
|
|
||||||
|
|
|
@ -7,13 +7,15 @@ import pytest
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
from litellm import Router
|
from litellm import Router
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
model_list = [{ # list of model deployments
|
def test_multiple_deployments():
|
||||||
|
model_list = [{ # list of model deployments
|
||||||
"model_name": "gpt-3.5-turbo", # openai model name
|
"model_name": "gpt-3.5-turbo", # openai model name
|
||||||
"litellm_params": { # params for litellm completion/embedding call
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
"model": "azure/chatgpt-v-2",
|
"model": "azure/chatgpt-v-2",
|
||||||
|
@ -23,7 +25,7 @@ model_list = [{ # list of model deployments
|
||||||
},
|
},
|
||||||
"tpm": 240000,
|
"tpm": 240000,
|
||||||
"rpm": 1800
|
"rpm": 1800
|
||||||
}, {
|
}, {
|
||||||
"model_name": "gpt-3.5-turbo", # openai model name
|
"model_name": "gpt-3.5-turbo", # openai model name
|
||||||
"litellm_params": { # params for litellm completion/embedding call
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
"model": "azure/chatgpt-functioncalling",
|
"model": "azure/chatgpt-functioncalling",
|
||||||
|
@ -33,7 +35,7 @@ model_list = [{ # list of model deployments
|
||||||
},
|
},
|
||||||
"tpm": 240000,
|
"tpm": 240000,
|
||||||
"rpm": 1800
|
"rpm": 1800
|
||||||
}, {
|
}, {
|
||||||
"model_name": "gpt-3.5-turbo", # openai model name
|
"model_name": "gpt-3.5-turbo", # openai model name
|
||||||
"litellm_params": { # params for litellm completion/embedding call
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
|
@ -41,12 +43,12 @@ model_list = [{ # list of model deployments
|
||||||
},
|
},
|
||||||
"tpm": 1000000,
|
"tpm": 1000000,
|
||||||
"rpm": 9000
|
"rpm": 9000
|
||||||
}]
|
}]
|
||||||
|
|
||||||
router = Router(model_list=model_list, redis_host=os.getenv("REDIS_HOST"), redis_password=os.getenv("REDIS_PASSWORD"), redis_port=int(os.getenv("REDIS_PORT"))) # type: ignore
|
router = Router(model_list=model_list, redis_host=os.getenv("REDIS_HOST"), redis_password=os.getenv("REDIS_PASSWORD"), redis_port=int(os.getenv("REDIS_PORT"))) # type: ignore
|
||||||
|
|
||||||
completions = []
|
completions = []
|
||||||
with ThreadPoolExecutor(max_workers=100) as executor:
|
with ThreadPoolExecutor(max_workers=100) as executor:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
"messages": [{"role": "user", "content": "Hey, how's it going?"}]
|
"messages": [{"role": "user", "content": "Hey, how's it going?"}]
|
||||||
|
@ -55,7 +57,53 @@ with ThreadPoolExecutor(max_workers=100) as executor:
|
||||||
future = executor.submit(router.completion, **kwargs) # type: ignore
|
future = executor.submit(router.completion, **kwargs) # type: ignore
|
||||||
completions.append(future)
|
completions.append(future)
|
||||||
|
|
||||||
# Retrieve the results from the futures
|
# Retrieve the results from the futures
|
||||||
results = [future.result() for future in completions]
|
results = [future.result() for future in completions]
|
||||||
|
|
||||||
print(results)
|
print(results)
|
||||||
|
|
||||||
|
### FUNCTION CALLING
|
||||||
|
|
||||||
|
def test_function_calling():
|
||||||
|
litellm.set_verbose =True
|
||||||
|
model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo-0613",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-3.5-turbo-0613",
|
||||||
|
"api_key": "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",
|
||||||
|
},
|
||||||
|
"tpm": 100000,
|
||||||
|
"rpm": 10000,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "What is the weather like in Boston?"}
|
||||||
|
]
|
||||||
|
functions = [
|
||||||
|
{
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
router = Router(model_list=model_list)
|
||||||
|
response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
test_function_calling()
|
Loading…
Add table
Add a link
Reference in a new issue