clean out args passed to completion

ishaan-jaff 2023-09-28 17:43:54 -07:00
parent 8d11a0c722
commit 1b74cd2790

@@ -84,7 +84,8 @@ async def acompletion(*args, **kwargs):
     loop = asyncio.get_event_loop()
     # Use a partial function to pass your keyword arguments
-    func = partial(completion, *args, **kwargs, acompletion=True)
+    kwargs["acompletion"] = True
+    func = partial(completion, *args, **kwargs)
     # Add the context to the function
     ctx = contextvars.copy_context()
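
Note on the pattern above: acompletion() wraps the synchronous completion() by binding the arguments with functools.partial and running the result in the default executor inside a copied contextvars context; after this change the async flag travels inside kwargs instead of being injected as an extra keyword on the partial. A minimal runnable sketch of that pattern follows; the stub completion() body and the func_with_context continuation are illustrative assumptions, not lines from this diff.

import asyncio
import contextvars
from functools import partial

def completion(*args, **kwargs):
    # stand-in for litellm's synchronous completion()
    return {"ran_as_async": kwargs.get("acompletion", False)}

async def acompletion(*args, **kwargs):
    loop = asyncio.get_event_loop()
    # after this commit the flag rides along inside kwargs
    kwargs["acompletion"] = True
    func = partial(completion, *args, **kwargs)
    # copy the current context so contextvars survive the executor hop
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)  # assumed continuation of the hunk
    return await loop.run_in_executor(None, func_with_context)

print(asyncio.run(acompletion(model="example-model")))  # {'ran_as_async': True}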
@@ -209,17 +210,12 @@ def completion(
     num_beams = kwargs.get('num_beams', 1)
     logger_fn = kwargs.get('logger_fn', None)
     verbose = kwargs.get('verbose', False)
     azure = kwargs.get('azure', False)
     custom_llm_provider = kwargs.get('custom_llm_provider', None)
-    litellm_call_id = kwargs.get('litellm_call_id', None)
     litellm_logging_obj = kwargs.get('litellm_logging_obj', None)
     use_client = kwargs.get('use_client', False)
     id = kwargs.get('id', None)
     metadata = kwargs.get('metadata', None)
-    top_k=40,# used by text-bison only
     task = kwargs.get('task', "text-generation-inference")
     return_full_text = kwargs.get('return_full_text', False)
-    remove_input = kwargs.get('remove_input', True)
     request_timeout = kwargs.get('request_timeout', 0)
     fallbacks = kwargs.get('fallbacks', [])
     caching = kwargs.get('caching', False)
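
The hunk above and the ones below share one idea: options that used to be hoisted into locals up front are now read with kwargs.get(...) at their point of use, so completion() keeps fewer litellm-internal names in scope. A hypothetical before/after contrast (not litellm's actual function):

def handler_before(**kwargs):
    # every option unpacked eagerly, whether or not this call path uses it
    top_k = kwargs.get('top_k', 40)
    remove_input = kwargs.get('remove_input', True)
    litellm_call_id = kwargs.get('litellm_call_id', None)
    return {'top_k': top_k, 'remove_input': remove_input, 'call_id': litellm_call_id}

def handler_after(**kwargs):
    # each option pulled from kwargs where it is consumed
    return {
        'top_k': kwargs.get('top_k', 40),
        'remove_input': kwargs.get('remove_input', True),
        'call_id': kwargs.get('litellm_call_id', None),
    }

assert handler_before(top_k=10) == handler_after(top_k=10)

The trade-off is that each default (40 for top_k, True for remove_input) is now repeated at every call site, as the later hunks in this diff show.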
@@ -274,9 +270,9 @@ def completion(
         # params to identify the model
         model=model,
         custom_llm_provider=custom_llm_provider,
-        top_k=top_k,
+        top_k=kwargs.get('top_k', 40),
         task=task,
-        remove_input=remove_input,
+        remove_input=kwargs.get('remove_input', True),
         return_full_text=return_full_text,
     )
     # For logging - save the values of the litellm-specific params passed in
@@ -288,7 +284,7 @@ def completion(
         verbose=verbose,
         custom_llm_provider=custom_llm_provider,
         api_base=api_base,
-        litellm_call_id=litellm_call_id,
+        litellm_call_id=kwargs.get('litellm_call_id', None),
         model_alias_map=litellm.model_alias_map,
         completion_call_id=id,
         metadata=metadata
@@ -1044,7 +1040,7 @@ def completion(
         logging.pre_call(
             input=prompt, api_key=None, additional_args={"endpoint": endpoint}
         )
-        if acompletion == True:
+        if kwargs.get('acompletion', False) == True:
            async_generator = ollama.async_get_ollama_response_stream(endpoint, model, prompt)
            return async_generator
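
In this branch completion() returns the async generator unconsumed; the caller drains it. A minimal sketch of that consumption, with a stub standing in for ollama.async_get_ollama_response_stream (the chunk shape here is an assumption):

import asyncio

async def async_get_ollama_response_stream(endpoint, model, prompt):
    # stub for the real streaming HTTP call to the ollama endpoint
    for token in ["Hello", ",", " world"]:
        yield {"choices": [{"delta": {"content": token}}]}

async def main():
    stream = async_get_ollama_response_stream("http://localhost:11434", "llama2", "hi")
    async for chunk in stream:
        print(chunk["choices"][0]["delta"]["content"], end="")
    print()

asyncio.run(main())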
@@ -1158,7 +1154,7 @@ def completion(
                 'max_tokens': max_tokens,
                 'temperature': temperature,
                 'top_p': top_p,
-                'top_k': top_k,
+                'top_k': kwargs.get('top_k', 40),
             }
         })
         response_json = resp.json()
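
Same change at the payload level: top_k is read straight from kwargs when the sampling parameters are assembled. A small sketch of that assembly (build_sampling_params is a hypothetical name, not litellm's API):

def build_sampling_params(max_tokens, temperature, top_p, **kwargs):
    return {
        'max_tokens': max_tokens,
        'temperature': temperature,
        'top_p': top_p,
        'top_k': kwargs.get('top_k', 40),  # default mirrors the removed local
    }

print(build_sampling_params(256, 0.7, 0.95, top_k=20))
# {'max_tokens': 256, 'temperature': 0.7, 'top_p': 0.95, 'top_k': 20}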