Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
(feat) text_completion add docstring

parent 5cf2239aaa
commit cac3148dff

1 changed file with 128 additions and 24 deletions

litellm/main.py
@@ -53,7 +53,7 @@ from .llms.openai import OpenAIChatCompletion
 from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, List, Optional, Dict
+from typing import Callable, List, Optional, Dict, Union

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
@@ -1858,38 +1858,129 @@ def embedding(


 ###### Text Completion ################
-def text_completion(*args, **kwargs):
+def text_completion(
+    model: str,  # Required: ID of the model to use.
+    prompt: Union[str, List[Union[str, List[Union[str, List[int]]]]]],  # Required: The prompt(s) to generate completions for.
+    best_of: Optional[int] = None,  # Optional: Generates best_of completions server-side.
+    echo: Optional[bool] = None,  # Optional: Echo back the prompt in addition to the completion.
+    frequency_penalty: Optional[float] = None,  # Optional: Penalize new tokens based on their existing frequency.
+    logit_bias: Optional[Dict[int, int]] = None,  # Optional: Modify the likelihood of specified tokens.
+    logprobs: Optional[int] = None,  # Optional: Include the log probabilities on the most likely tokens.
+    max_tokens: Optional[int] = None,  # Optional: The maximum number of tokens to generate in the completion.
+    n: Optional[int] = None,  # Optional: How many completions to generate for each prompt.
+    presence_penalty: Optional[float] = None,  # Optional: Penalize new tokens based on whether they appear in the text so far.
+    stop: Optional[Union[str, List[str]]] = None,  # Optional: Sequences where the API will stop generating further tokens.
+    stream: Optional[bool] = None,  # Optional: Whether to stream back partial progress.
+    suffix: Optional[str] = None,  # Optional: The suffix that comes after a completion of inserted text.
+    temperature: Optional[float] = None,  # Optional: Sampling temperature to use.
+    top_p: Optional[float] = None,  # Optional: Nucleus sampling parameter.
+    user: Optional[str] = None,  # Optional: A unique identifier representing your end-user.
+
+    # set api_base, api_version, api_key
+    api_base: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_key: Optional[str] = None,
+    model_list: Optional[list] = None,  # pass in a list of api_base, keys, etc.
+
+    # Optional liteLLM function params
+    custom_llm_provider: Optional[str] = None,
+    *args,
+    **kwargs
+):
     global print_verbose
     import copy
     """
-    This maps to the Openai.Completion.create format, which has a different I/O (accepts prompt, returning ["choices"]["text"].
+    Generate text completions using the OpenAI API.
+
+    Args:
+        model (str): ID of the model to use.
+        prompt (Union[str, List[Union[str, List[Union[str, List[int]]]]]]): The prompt(s) to generate completions for.
+        best_of (Optional[int], optional): Generates best_of completions server-side. Defaults to 1.
+        echo (Optional[bool], optional): Echo back the prompt in addition to the completion. Defaults to False.
+        frequency_penalty (Optional[float], optional): Penalize new tokens based on their existing frequency. Defaults to 0.
+        logit_bias (Optional[Dict[int, int]], optional): Modify the likelihood of specified tokens. Defaults to None.
+        logprobs (Optional[int], optional): Include the log probabilities on the most likely tokens. Defaults to None.
+        max_tokens (Optional[int], optional): The maximum number of tokens to generate in the completion. Defaults to 16.
+        n (Optional[int], optional): How many completions to generate for each prompt. Defaults to 1.
+        presence_penalty (Optional[float], optional): Penalize new tokens based on whether they appear in the text so far. Defaults to 0.
+        stop (Optional[Union[str, List[str]]], optional): Sequences where the API will stop generating further tokens. Defaults to None.
+        stream (Optional[bool], optional): Whether to stream back partial progress. Defaults to False.
+        suffix (Optional[str], optional): The suffix that comes after a completion of inserted text. Defaults to None.
+        temperature (Optional[float], optional): Sampling temperature to use. Defaults to 1.
+        top_p (Optional[float], optional): Nucleus sampling parameter. Defaults to 1.
+        user (Optional[str], optional): A unique identifier representing your end-user.
+
+    Returns:
+        TextCompletionResponse: A response object containing the generated completion and associated metadata.
+
+    Example:
+        Your example of how to use this function goes here.
     """
     if "engine" in kwargs:
-        kwargs["model"] = kwargs["engine"]
+        model = kwargs["engine"]
         kwargs.pop("engine")

-    # input validation
-    if "prompt" not in kwargs:
-        raise ValueError("please pass prompt into the `text_completion` endpoint - `text_completion(model, prompt='hello world')`")

     text_completion_response = TextCompletionResponse()
-    model = kwargs["model"]
-    prompt = kwargs["prompt"]
+    optional_params = {}
+    if best_of is not None:
+        optional_params["best_of"] = best_of
+    if echo is not None:
+        optional_params["echo"] = echo
+    if frequency_penalty is not None:
+        optional_params["frequency_penalty"] = frequency_penalty
+    if logit_bias is not None:
+        optional_params["logit_bias"] = logit_bias
+    if logprobs is not None:
+        optional_params["logprobs"] = logprobs
+    if max_tokens is not None:
+        optional_params["max_tokens"] = max_tokens
+    if n is not None:
+        optional_params["n"] = n
+    if presence_penalty is not None:
+        optional_params["presence_penalty"] = presence_penalty
+    if stop is not None:
+        optional_params["stop"] = stop
+    if stream is not None:
+        optional_params["stream"] = stream
+    if suffix is not None:
+        optional_params["suffix"] = suffix
+    if temperature is not None:
+        optional_params["temperature"] = temperature
+    if top_p is not None:
+        optional_params["top_p"] = top_p
+    if user is not None:
+        optional_params["user"] = user
+    if api_base is not None:
+        optional_params["api_base"] = api_base
+    if api_version is not None:
+        optional_params["api_version"] = api_version
+    if api_key is not None:
+        optional_params["api_key"] = api_key
+    if custom_llm_provider is not None:
+        optional_params["custom_llm_provider"] = custom_llm_provider

     # get custom_llm_provider
-    _, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model)
+    _, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)

     if custom_llm_provider == "text-completion-openai":
         # text-davinci-003 and openai text completion models
-        messages = [{"role": "system", "content": kwargs["prompt"]}]
-        kwargs["messages"] = messages
-        kwargs.pop("prompt")
-        response = completion(*args, **kwargs) # assume the response is the openai response object
+        messages = [{"role": "system", "content": prompt}]
+        kwargs.pop("prompt", None)
+        response = completion(
+            model = model,
+            messages=messages,
+            *args,
+            **kwargs,
+            **optional_params
+        )
+        # assume the response is the openai response object
         # return raw response from openai
         return response._hidden_params.get("original_response", None)

     elif custom_llm_provider == "huggingface":
         # if echo == True, for TGI llms we need to set top_n_tokens to 3
-        if kwargs.get("echo", False) == True:
+        if echo == True:
             # for tgi llms
             if "top_n_tokens" not in kwargs:
                 kwargs["top_n_tokens"] = 3
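This hunk makes the Completion-style parameters explicit instead of smuggling them through **kwargs; anything left as None is simply not forwarded, since only non-None values are copied into optional_params. A minimal usage sketch of the new interface (the model name and prompt are illustrative, and it assumes litellm is installed with an OpenAI key configured in the environment):

import litellm

# only the params passed explicitly end up in optional_params and get forwarded
response = litellm.text_completion(
    model="text-davinci-003",  # illustrative OpenAI text-completion model
    prompt="Say this is a test",
    max_tokens=16,
    temperature=0.7,
)
# per the diff, the text-completion-openai branch returns the raw provider response
print(response["choices"][0]["text"])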
@@ -1902,9 +1993,13 @@ def text_completion(*args, **kwargs):
             responses = [None for x in prompt] # init responses
             for i, request in enumerate(prompt):
                 decoded_prompt = tokenizer.decode(request)
-                new_kwargs = copy.deepcopy(kwargs)
-                new_kwargs["prompt"] = decoded_prompt
-                response = text_completion(**new_kwargs)
+                response = text_completion(
+                    model = model,
+                    prompt=decoded_prompt,
+                    *args,
+                    **kwargs,
+                    **optional_params
+                )
                 responses[i] = response["choices"][0]

             text_completion_response["id"] = response["id"]
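This hunk drops the deepcopy-and-mutate pattern for batched token prompts: each token-id list is decoded back to text and re-dispatched through text_completion with the explicit model and prompt, plus the shared kwargs and optional_params. A rough sketch of just the decode step, as a standalone illustration; the diff does not show how `tokenizer` is constructed, so the cl100k_base encoding and the token ids below are assumptions:

import tiktoken

tokenizer = tiktoken.get_encoding("cl100k_base")  # assumed encoding for this sketch
token_prompts = [[9906, 1917], [2028, 374, 264, 1296]]  # illustrative token-id lists

responses = [None for _ in token_prompts]  # init responses, mirroring the diff
for i, request in enumerate(token_prompts):
    decoded_prompt = tokenizer.decode(request)  # token ids back to plain text
    responses[i] = decoded_prompt  # the committed loop stores response["choices"][0] here instead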
@@ -1916,10 +2011,19 @@ def text_completion(*args, **kwargs):

         return text_completion_response
     else:
-        messages = [{"role": "system", "content": kwargs["prompt"]}]
-        kwargs["messages"] = messages
-        kwargs.pop("prompt")
-        response = completion(*args, **kwargs) # assume the response is the openai response object
+        # check if non default values passed in for best_of, echo, logprobs, suffix
+        # these are the params supported by Completion() but not ChatCompletion
+
+        # default case, non OpenAI requests go through here
+        messages = [{"role": "system", "content": prompt}]
+        kwargs.pop("prompt", None)
+        response = completion(
+            model = model,
+            messages=messages,
+            *args,
+            **kwargs,
+            **optional_params,
+        )

     transformed_logprobs = None
     # only supported for TGI models
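For providers without a native text-completion endpoint, the default branch now wraps the prompt in a single system message and delegates to completion() with the collected optional_params. A hand-rolled equivalent of that branch, as a sketch only (the model name is illustrative and a matching provider key is assumed to be configured):

import litellm

prompt = "Write a one-line summary of this change."
messages = [{"role": "system", "content": prompt}]  # prompt becomes a system message

# roughly what the default branch does after popping `prompt` from kwargs
response = litellm.completion(
    model="claude-instant-1",  # illustrative non-OpenAI model
    messages=messages,
    max_tokens=64,
)
print(response["choices"][0]["message"]["content"])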