avoid overwriting litellm_params

mc-marcocheng 2023-10-27 15:30:34 +08:00
parent 895cb5d0f9
commit f43d59fff8

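Before this change, each call path took a reference to the deployment's stored litellm_params dict (data = deployment["litellm_params"]) and mutated it in place, so per-request keys such as messages, prompt, input, and caching were written back into the router's own deployment configuration and persisted between requests. A minimal sketch of the aliasing problem, using a made-up deployment dict rather than code from this commit:

    # `data` is a reference to the stored params, not a copy, so writing
    # through it rewrites the deployment's configuration in place.
    deployment = {"litellm_params": {"model": "gpt-3.5-turbo"}}

    data = deployment["litellm_params"]
    data["messages"] = [{"role": "user", "content": "hi"}]  # per-request key leaks in

    print(deployment["litellm_params"])
    # {'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'hi'}]}
    # The request-specific messages are now baked into the shared config.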

@@ -1,5 +1,6 @@
-from typing import Union, List, Dict, Optional
+from datetime import datetime
+from typing import Dict, List, Optional, Union
 import litellm
@@ -30,8 +31,7 @@ class Router:
                  redis_password: Optional[str] = None,
                  cache_responses: bool = False) -> None:
         if model_list:
-            self.model_list = model_list
-            self.model_names = [m["model_name"] for m in model_list]
+            self.set_model_list(model_list)
         if redis_host is not None and redis_port is not None and redis_password is not None:
             cache_config = {
                 'type': 'redis',
@@ -63,10 +63,8 @@ class Router:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, messages=messages)
         data = deployment["litellm_params"]
-        data["messages"] = messages
-        data["caching"] = self.cache_responses
         # call via litellm.completion()
-        return litellm.completion(**{**data, **kwargs})
+        return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})

     async def acompletion(self,
                           model: str,
@@ -77,9 +75,7 @@ class Router:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, messages=messages)
         data = deployment["litellm_params"]
-        data["messages"] = messages
-        data["caching"] = self.cache_responses
-        return await litellm.acompletion(**{**data, **kwargs})
+        return await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})

     def text_completion(self,
                         model: str,
@@ -94,10 +90,8 @@ class Router:
         deployment = self.get_available_deployment(model=model, messages=messages)
         data = deployment["litellm_params"]
-        data["prompt"] = prompt
-        data["caching"] = self.cache_responses
         # call via litellm.completion()
-        return litellm.text_completion(**{**data, **kwargs})
+        return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs})

     def embedding(self,
                   model: str,
@@ -108,10 +102,8 @@ class Router:
         deployment = self.get_available_deployment(model=model, input=input)
         data = deployment["litellm_params"]
-        data["input"] = input
-        data["caching"] = self.cache_responses
         # call via litellm.embedding()
-        return litellm.embedding(**{**data, **kwargs})
+        return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})

     async def aembedding(self,
                          model: str,
@@ -122,12 +114,11 @@ class Router:
         deployment = self.get_available_deployment(model=model, input=input)
         data = deployment["litellm_params"]
-        data["input"] = input
-        data["caching"] = self.cache_responses
-        return await litellm.aembedding(**{**data, **kwargs})
+        return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})

     def set_model_list(self, model_list: list):
         self.model_list = model_list
         self.model_names = [m["model_name"] for m in model_list]

     def get_model_names(self):
         return self.model_names
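The fix builds a fresh mapping per call with dict unpacking instead of mutating data. Because later entries win in a dict literal, caller-supplied kwargs still override both the stored params and the router-injected defaults. A small illustration of the merge order, with hypothetical values:

    # A fresh dict is assembled per call; `data` is never mutated, and
    # kwargs, unpacked last, take precedence over the injected defaults.
    data = {"model": "gpt-3.5-turbo"}   # stands in for deployment["litellm_params"]
    kwargs = {"caching": True}          # caller-supplied override

    merged = {**data, "messages": [{"role": "user", "content": "hi"}], "caching": False, **kwargs}
    print(merged["caching"])  # True  -- the caller's value wins
    print(data)               # {'model': 'gpt-3.5-turbo'}  -- left untouched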