forked from phoenix/litellm-mirror
Merge pull request #679 from mc-marcocheng/router-kwargs
Fix data being overwritten
This commit is contained in:
commit
d244978247
2 changed files with 55 additions and 15 deletions
|
@ -61,10 +61,8 @@ class Router:
|
||||||
|
|
||||||
data = deployment["litellm_params"]
|
data = deployment["litellm_params"]
|
||||||
data["messages"] = messages
|
data["messages"] = messages
|
||||||
for key, value in kwargs.items():
|
|
||||||
data[key] = value
|
|
||||||
# call via litellm.completion()
|
# call via litellm.completion()
|
||||||
return litellm.completion(**data)
|
return litellm.completion(**{**data, **kwargs})
|
||||||
|
|
||||||
def text_completion(self,
|
def text_completion(self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -80,10 +78,8 @@ class Router:
|
||||||
|
|
||||||
data = deployment["litellm_params"]
|
data = deployment["litellm_params"]
|
||||||
data["prompt"] = prompt
|
data["prompt"] = prompt
|
||||||
for key, value in kwargs.items():
|
|
||||||
data[key] = value
|
|
||||||
# call via litellm.completion()
|
# call via litellm.completion()
|
||||||
return litellm.text_completion(**data)
|
return litellm.text_completion(**{**data, **kwargs})
|
||||||
|
|
||||||
def embedding(self,
|
def embedding(self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -96,7 +92,7 @@ class Router:
|
||||||
data = deployment["litellm_params"]
|
data = deployment["litellm_params"]
|
||||||
data["input"] = input
|
data["input"] = input
|
||||||
# call via litellm.embedding()
|
# call via litellm.embedding()
|
||||||
return litellm.embedding(**data)
|
return litellm.embedding(**{**data, **kwargs})
|
||||||
|
|
||||||
def set_model_list(self, model_list: list):
|
def set_model_list(self, model_list: list):
|
||||||
self.model_list = model_list
|
self.model_list = model_list
|
||||||
|
|
|
@ -71,7 +71,7 @@ def test_function_calling():
|
||||||
"model_name": "gpt-3.5-turbo-0613",
|
"model_name": "gpt-3.5-turbo-0613",
|
||||||
"litellm_params": {
|
"litellm_params": {
|
||||||
"model": "gpt-3.5-turbo-0613",
|
"model": "gpt-3.5-turbo-0613",
|
||||||
"api_key": "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",
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
},
|
},
|
||||||
"tpm": 100000,
|
"tpm": 100000,
|
||||||
"rpm": 10000,
|
"rpm": 10000,
|
||||||
|
@ -106,4 +106,48 @@ def test_function_calling():
|
||||||
response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions)
|
response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
test_function_calling()
|
### FUNCTION CALLING -> NORMAL COMPLETION
|
||||||
|
def test_litellm_params_not_overwritten_by_function_calling():
|
||||||
|
try:
|
||||||
|
model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo-0613",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-3.5-turbo-0613",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
"tpm": 100000,
|
||||||
|
"rpm": 10000,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "What is the weather like in Boston?"}
|
||||||
|
]
|
||||||
|
functions = [
|
||||||
|
{
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
router = Router(model_list=model_list)
|
||||||
|
_ = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions)
|
||||||
|
response = router.completion(model="gpt-3.5-turbo-0613", messages=messages)
|
||||||
|
assert response.choices[0].finish_reason != "function_call"
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue