mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
clean out args passed to completion
This commit is contained in:
parent
8d11a0c722
commit
1b74cd2790
1 changed file with 7 additions and 11 deletions
|
@@ -84,7 +84,8 @@ async def acompletion(*args, **kwargs):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
# Use a partial function to pass your keyword arguments
|
# Use a partial function to pass your keyword arguments
|
||||||
func = partial(completion, *args, **kwargs, acompletion=True)
|
kwargs["acompletion"] = True
|
||||||
|
func = partial(completion, *args, **kwargs)
|
||||||
|
|
||||||
# Add the context to the function
|
# Add the context to the function
|
||||||
ctx = contextvars.copy_context()
|
ctx = contextvars.copy_context()
|
||||||
|
@@ -209,17 +210,12 @@ def completion(
|
||||||
num_beams = kwargs.get('num_beams', 1)
|
num_beams = kwargs.get('num_beams', 1)
|
||||||
logger_fn = kwargs.get('logger_fn', None)
|
logger_fn = kwargs.get('logger_fn', None)
|
||||||
verbose = kwargs.get('verbose', False)
|
verbose = kwargs.get('verbose', False)
|
||||||
azure = kwargs.get('azure', False)
|
|
||||||
custom_llm_provider = kwargs.get('custom_llm_provider', None)
|
custom_llm_provider = kwargs.get('custom_llm_provider', None)
|
||||||
litellm_call_id = kwargs.get('litellm_call_id', None)
|
|
||||||
litellm_logging_obj = kwargs.get('litellm_logging_obj', None)
|
litellm_logging_obj = kwargs.get('litellm_logging_obj', None)
|
||||||
use_client = kwargs.get('use_client', False)
|
|
||||||
id = kwargs.get('id', None)
|
id = kwargs.get('id', None)
|
||||||
metadata = kwargs.get('metadata', None)
|
metadata = kwargs.get('metadata', None)
|
||||||
top_k=40,# used by text-bison only
|
|
||||||
task = kwargs.get('task', "text-generation-inference")
|
task = kwargs.get('task', "text-generation-inference")
|
||||||
return_full_text = kwargs.get('return_full_text', False)
|
return_full_text = kwargs.get('return_full_text', False)
|
||||||
remove_input = kwargs.get('remove_input', True)
|
|
||||||
request_timeout = kwargs.get('request_timeout', 0)
|
request_timeout = kwargs.get('request_timeout', 0)
|
||||||
fallbacks = kwargs.get('fallbacks', [])
|
fallbacks = kwargs.get('fallbacks', [])
|
||||||
caching = kwargs.get('caching', False)
|
caching = kwargs.get('caching', False)
|
||||||
|
@@ -274,9 +270,9 @@ def completion(
|
||||||
# params to identify the model
|
# params to identify the model
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider=custom_llm_provider,
|
custom_llm_provider=custom_llm_provider,
|
||||||
top_k=top_k,
|
top_k=kwargs.get('top_k', 40),
|
||||||
task=task,
|
task=task,
|
||||||
remove_input=remove_input,
|
remove_input=kwargs.get('remove_input', True),
|
||||||
return_full_text=return_full_text,
|
return_full_text=return_full_text,
|
||||||
)
|
)
|
||||||
# For logging - save the values of the litellm-specific params passed in
|
# For logging - save the values of the litellm-specific params passed in
|
||||||
|
@@ -288,7 +284,7 @@ def completion(
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
custom_llm_provider=custom_llm_provider,
|
custom_llm_provider=custom_llm_provider,
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
litellm_call_id=litellm_call_id,
|
litellm_call_id=kwargs.get('litellm_call_id', None),
|
||||||
model_alias_map=litellm.model_alias_map,
|
model_alias_map=litellm.model_alias_map,
|
||||||
completion_call_id=id,
|
completion_call_id=id,
|
||||||
metadata=metadata
|
metadata=metadata
|
||||||
|
@@ -1044,7 +1040,7 @@ def completion(
|
||||||
logging.pre_call(
|
logging.pre_call(
|
||||||
input=prompt, api_key=None, additional_args={"endpoint": endpoint}
|
input=prompt, api_key=None, additional_args={"endpoint": endpoint}
|
||||||
)
|
)
|
||||||
if acompletion == True:
|
if kwargs.get('acompletion', False) == True:
|
||||||
async_generator = ollama.async_get_ollama_response_stream(endpoint, model, prompt)
|
async_generator = ollama.async_get_ollama_response_stream(endpoint, model, prompt)
|
||||||
return async_generator
|
return async_generator
|
||||||
|
|
||||||
|
@@ -1158,7 +1154,7 @@ def completion(
|
||||||
'max_tokens': max_tokens,
|
'max_tokens': max_tokens,
|
||||||
'temperature': temperature,
|
'temperature': temperature,
|
||||||
'top_p': top_p,
|
'top_p': top_p,
|
||||||
'top_k': top_k,
|
'top_k': kwargs.get('top_k', 40),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
response_json = resp.json()
|
response_json = resp.json()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue