diff --git a/litellm/main.py b/litellm/main.py
index 511e8fa5f5..6d6033e1cc 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -84,7 +84,8 @@ async def acompletion(*args, **kwargs):
     loop = asyncio.get_event_loop()

     # Use a partial function to pass your keyword arguments
-    func = partial(completion, *args, **kwargs, acompletion=True)
+    kwargs["acompletion"] = True
+    func = partial(completion, *args, **kwargs)

     # Add the context to the function
     ctx = contextvars.copy_context()
@@ -209,17 +210,12 @@ def completion(
     num_beams = kwargs.get('num_beams', 1)
     logger_fn = kwargs.get('logger_fn', None)
     verbose = kwargs.get('verbose', False)
-    azure = kwargs.get('azure', False)
     custom_llm_provider = kwargs.get('custom_llm_provider', None)
-    litellm_call_id = kwargs.get('litellm_call_id', None)
     litellm_logging_obj = kwargs.get('litellm_logging_obj', None)
-    use_client = kwargs.get('use_client', False)
     id = kwargs.get('id', None)
     metadata = kwargs.get('metadata', None)
-    top_k=40,# used by text-bison only
     task = kwargs.get('task', "text-generation-inference")
     return_full_text = kwargs.get('return_full_text', False)
-    remove_input = kwargs.get('remove_input', True)
     request_timeout = kwargs.get('request_timeout', 0)
     fallbacks = kwargs.get('fallbacks', [])
     caching = kwargs.get('caching', False)
@@ -274,9 +270,9 @@ def completion(
         # params to identify the model
         model=model,
         custom_llm_provider=custom_llm_provider,
-        top_k=top_k,
+        top_k=kwargs.get('top_k', 40),
         task=task,
-        remove_input=remove_input,
+        remove_input=kwargs.get('remove_input', True),
         return_full_text=return_full_text,
     )
     # For logging - save the values of the litellm-specific params passed in
@@ -288,7 +284,7 @@ def completion(
         verbose=verbose,
         custom_llm_provider=custom_llm_provider,
         api_base=api_base,
-        litellm_call_id=litellm_call_id,
+        litellm_call_id=kwargs.get('litellm_call_id', None),
         model_alias_map=litellm.model_alias_map,
         completion_call_id=id,
         metadata=metadata
@@ -1044,7 +1040,7 @@ def completion(
         logging.pre_call(
             input=prompt, api_key=None, additional_args={"endpoint": endpoint}
         )
-        if acompletion == True:
+        if kwargs.get('acompletion', False) == True:
             async_generator = ollama.async_get_ollama_response_stream(endpoint, model, prompt)
             return async_generator

@@ -1158,7 +1154,7 @@ def completion(
                 'max_tokens': max_tokens,
                 'temperature': temperature,
                 'top_p': top_p,
-                'top_k': top_k,
+                'top_k': kwargs.get('top_k', 40),
             }
         })
         response_json = resp.json()
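
A minimal standalone sketch (not part of the patch) of the acompletion change in the first hunk. With the old form, `partial(completion, *args, **kwargs, acompletion=True)` raises `TypeError: ... got multiple values for keyword argument 'acompletion'` whenever the caller already supplied `acompletion` in kwargs; writing the flag into kwargs first avoids that, and `completion` can then read it back with `kwargs.get('acompletion', False)` as the later hunk does. The `completion` stub below is a stand-in for the real litellm function.

# Illustrative sketch only; `completion` is a placeholder, not litellm's implementation.
import asyncio
import contextvars
from functools import partial

def completion(*args, **kwargs):
    # The real function now reads the flag the same way: kwargs.get('acompletion', False)
    return {"is_async": kwargs.get("acompletion", False), "model": args[0]}

async def acompletion(*args, **kwargs):
    loop = asyncio.get_event_loop()
    kwargs["acompletion"] = True                # flag travels inside kwargs, never as a second keyword
    func = partial(completion, *args, **kwargs)
    ctx = contextvars.copy_context()            # preserve context variables across the executor hop
    func_with_context = partial(ctx.run, func)
    return await loop.run_in_executor(None, func_with_context)

print(asyncio.run(acompletion("gpt-3.5-turbo")))
# -> {'is_async': True, 'model': 'gpt-3.5-turbo'}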