Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00

commit b730482aaf (parent 6134b655e8): v0
1 changed file with 64 additions and 6 deletions
@@ -242,22 +242,80 @@ class Router:
     ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS

     def completion(
-        self, model: str, messages: List[Dict[str, str]], **kwargs
+        self,
+        model: str,
+        # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+        messages: List = [],
+        functions: Optional[List] = None,
+        function_call: Optional[str] = None,
+        timeout: Optional[Union[float, int]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        n: Optional[int] = None,
+        stream: Optional[bool] = None,
+        stop=None,
+        max_tokens: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[dict] = None,
+        user: Optional[str] = None,
+        # openai v1.0+ new params
+        response_format: Optional[dict] = None,
+        seed: Optional[int] = None,
+        tools: Optional[List] = None,
+        tool_choice: Optional[str] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
+        deployment_id=None,
+        # set api_base, api_version, api_key
+        base_url: Optional[str] = None,
+        api_version: Optional[str] = None,
+        api_key: Optional[str] = None,
+        model_list: Optional[list] = None,  # pass in a list of api_base, keys, etc.
+        # Optional liteLLM function params
+        **kwargs,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         """
         Example usage:
         response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
         """
         try:
-            kwargs["model"] = model
-            kwargs["messages"] = messages
-            kwargs["original_function"] = self._completion
             timeout = kwargs.get("request_timeout", self.timeout)
+            completion_kwargs = {
+                "model": model,
+                "messages": messages,
+                "functions": functions,
+                "function_call": function_call,
+                "timeout": timeout,
+                "temperature": temperature,
+                "top_p": top_p,
+                "n": n,
+                "stream": stream,
+                "stop": stop,
+                "max_tokens": max_tokens,
+                "presence_penalty": presence_penalty,
+                "frequency_penalty": frequency_penalty,
+                "logit_bias": logit_bias,
+                "user": user,
+                "response_format": response_format,
+                "seed": seed,
+                "tools": tools,
+                "tool_choice": tool_choice,
+                "logprobs": logprobs,
+                "top_logprobs": top_logprobs,
+                "deployment_id": deployment_id,
+                "base_url": base_url,
+                "api_version": api_version,
+                "api_key": api_key,
+                "model_list": model_list,
+                "original_function": self._completion,
+            }
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
-                # Submit the function to the executor with a timeout
-                future = executor.submit(self.function_with_fallbacks, **kwargs)
+                future = executor.submit(
+                    self.function_with_fallbacks, **kwargs, **completion_kwargs
+                )
                 response = future.result(timeout=timeout)  # type: ignore

             return response
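With this change, the common OpenAI parameters are explicit keyword arguments on router.completion instead of travelling only through **kwargs, so callers get signature help for them. A minimal usage sketch, assuming a Router built from litellm's usual model_list format; the deployment entry and API key handling here are illustrative:

    import os

    from litellm import Router

    # One deployment in the routing table; "model_name" is the group
    # that router.completion(model=...) selects on.
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": os.environ["OPENAI_API_KEY"],
                },
            }
        ]
    )

    # OpenAI params are now first-class keyword arguments.
    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        temperature=0.2,
        max_tokens=100,
    )
    print(response["choices"][0]["message"]["content"])

At call time the explicit parameters are collected into completion_kwargs and merged with the remaining **kwargs at the executor.submit call; because model and messages are no longer copied into kwargs, the two mappings stay disjoint in normal use and the double-splat avoids a duplicate-keyword TypeError.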
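The timeout handling is unchanged in shape: the routed call runs on a single-worker ThreadPoolExecutor and the caller blocks on future.result(timeout=...). A standalone sketch of that pattern, with a hypothetical slow_call standing in for self.function_with_fallbacks:

    import concurrent.futures
    import time

    def slow_call(prompt: str) -> str:
        # Hypothetical stand-in for the routed completion call.
        time.sleep(2)
        return f"echo: {prompt}"

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(slow_call, "hello")
        try:
            # Waits at most 5 seconds, then raises TimeoutError.
            result = future.result(timeout=5)
        except concurrent.futures.TimeoutError:
            result = None  # caller decides how to surface the timeout

    print(result)

One caveat of this pattern: a timed-out worker thread keeps running, and the with block waits for it on shutdown, so the timeout bounds how long result() blocks, not how long the underlying request lives.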