diff --git a/litellm/__init__.py b/litellm/__init__.py
index 7c6864eac..689038cf7 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -280,6 +280,7 @@ model_list = (
 provider_list: List = [
     "openai",
     "custom_openai",
+    "text-completion-openai",
     "cohere",
     "anthropic",
     "replicate",
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 9d9120745..64473590b 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -521,12 +521,14 @@ class OpenAITextCompletion(BaseLLM):
             else:
                 prompt = " ".join([message["content"] for message in messages]) # type: ignore
 
+            # don't send max retries to the api, if set
+            optional_params.pop("max_retries", None)
+
             data = {
                 "model": model,
                 "prompt": prompt,
                 **optional_params
             }
-
             ## LOGGING
             logging_obj.pre_call(
                 input=messages,
diff --git a/litellm/main.py b/litellm/main.py
index 973d21395..189241b32 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -193,7 +193,6 @@ async def acompletion(*args, **kwargs):
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)
         if kwargs.get("stream", False): # return an async generator
-            print_verbose(f"ENTERS STREAMING FOR ACOMPLETION")
             return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
         else:
             return response
@@ -664,17 +663,6 @@ def completion(
                 prompt = messages[0]["content"]
             else:
                 prompt = " ".join([message["content"] for message in messages]) # type: ignore
-            ## LOGGING
-            logging.pre_call(
-                input=prompt,
-                api_key=api_key,
-                additional_args={
-                    "openai_organization": litellm.organization,
-                    "headers": headers,
-                    "api_base": api_base,
-                    "api_type": openai.api_type,
-                },
-            )
             ## COMPLETION CALL
             model_response = openai_text_completions.completion(
                 model=model,
@@ -1991,6 +1979,59 @@ def embedding(
 
 ###### Text Completion ################
 
+async def atext_completion(*args, **kwargs):
+    """
+    Implemented to handle async streaming for the text completion endpoint
+    """
+    loop = asyncio.get_event_loop()
+    model = args[0] if len(args) > 0 else kwargs["model"]
+    ### PASS ARGS TO COMPLETION ###
+    kwargs["acompletion"] = True
+    custom_llm_provider = None
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(text_completion, *args, **kwargs)
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+
+        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
+
+        if (custom_llm_provider == "openai"
+            or custom_llm_provider == "azure"
+            or custom_llm_provider == "custom_openai"
+            or custom_llm_provider == "anyscale"
+            or custom_llm_provider == "mistral"
+            or custom_llm_provider == "openrouter"
+            or custom_llm_provider == "deepinfra"
+            or custom_llm_provider == "perplexity"
+            or custom_llm_provider == "text-completion-openai"
+            or custom_llm_provider == "huggingface"
+            or custom_llm_provider == "ollama"
+            or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
+            if kwargs.get("stream", False):
+                response = text_completion(*args, **kwargs)
+            else:
+                # Await normally
+                init_response = await loop.run_in_executor(None, func_with_context)
+                if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
+                    response = init_response
+                elif asyncio.iscoroutine(init_response):
+                    response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+        if kwargs.get("stream", False): # return an async generator
+            return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
+        else:
+            return response
+    except Exception as e:
+        custom_llm_provider = custom_llm_provider or "openai"
+        raise exception_type(
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+        )
+
 def text_completion(
     prompt: Union[str, List[Union[str, List[Union[str, List[int]]]]]], # Required: The prompt(s) to generate completions for.
     model: Optional[str]=None,                 # Optional: either `model` or `engine` can be set
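
A minimal usage sketch for the new litellm.atext_completion entrypoint added to litellm/main.py above; the model name and prompt simply mirror the test added later in this diff and are illustrative, not values prescribed by the change:

import asyncio
import litellm

async def main():
    # Non-streaming: awaits a full TextCompletionResponse.
    response = await litellm.atext_completion(
        model="text-completion-openai/text-davinci-003",  # illustrative model, mirrors the new test
        prompt="good morning",
        max_tokens=10,
    )
    print(response)

    # Streaming: the awaited call returns an async iterator of chunks.
    stream = await litellm.atext_completion(
        model="text-completion-openai/text-davinci-003",
        prompt="good morning",
        stream=True,
        max_tokens=10,
    )
    async for chunk in stream:
        print(chunk)

asyncio.run(main())
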
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index a55f18da1..2819f9d2a 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -797,37 +797,6 @@ async def async_data_generator(response, user_api_key_dict):
     except:
         yield f"data: {json.dumps(chunk)}\n\n"
 
-def litellm_completion(*args, **kwargs):
-    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
-    call_type = kwargs.pop("call_type")
-    # override with user settings, these are params passed via cli
-    if user_temperature:
-        kwargs["temperature"] = user_temperature
-    if user_request_timeout:
-        kwargs["request_timeout"] = user_request_timeout
-    if user_max_tokens:
-        kwargs["max_tokens"] = user_max_tokens
-    if user_api_base:
-        kwargs["api_base"] = user_api_base
-    ## ROUTE TO CORRECT ENDPOINT ##
-    router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
-    try:
-        if llm_router is not None and kwargs["model"] in router_model_names: # model in router model list
-            if call_type == "chat_completion":
-                response = llm_router.completion(*args, **kwargs)
-            elif call_type == "text_completion":
-                response = llm_router.text_completion(*args, **kwargs)
-        else:
-            if call_type == "chat_completion":
-                response = litellm.completion(*args, **kwargs)
-            elif call_type == "text_completion":
-                response = litellm.text_completion(*args, **kwargs)
-    except Exception as e:
-        raise e
-    if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
-        return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    return response
-
 def get_litellm_model_info(model: dict = {}):
     model_info = model.get("model_info", {})
     model_to_lookup = model.get("litellm_params", {}).get("model", None)
@@ -907,7 +876,8 @@ def model_list():
 @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
-async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
+async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
+    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     try:
         body = await request.body()
         body_str = body.decode()
@@ -925,17 +895,44 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
         )
         if user_model:
             data["model"] = user_model
-        data["call_type"] = "text_completion"
         if "metadata" in data:
             data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         else:
             data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
-        return litellm_completion(
-            **data
-        )
+        # override with user settings, these are params passed via cli
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_request_timeout:
+            data["request_timeout"] = user_request_timeout
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+
+        ### CALL HOOKS ### - modify incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
+
+        ### ROUTE THE REQUEST ###
+        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+        if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            response = await llm_router.atext_completion(**data)
+        elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.atext_completion(**data, specific_deployment = True)
+        elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
+            response = await llm_router.atext_completion(**data)
+        else: # router is not set
+            response = await litellm.atext_completion(**data)
+
+        print(f"final response: {response}")
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
+
+        background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
+        return response
     except Exception as e:
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
+        traceback.print_exc()
         error_traceback = traceback.format_exc()
         error_msg = f"{str(e)}\n\n{error_traceback}"
         try:
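
A rough client-side sketch of exercising the rewritten /completions route; the base URL, port, proxy key, and model name are assumptions about a locally running proxy rather than values taken from this diff:

import requests

resp = requests.post(
    "http://0.0.0.0:8000/v1/completions",         # assumed local proxy address
    headers={"Authorization": "Bearer sk-1234"},  # placeholder proxy API key
    json={
        "model": "gpt-3.5-turbo-instruct",        # assumed model name configured on the proxy
        "prompt": "good morning",
        "max_tokens": 10,
    },
)
print(resp.json())
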
diff --git a/litellm/router.py b/litellm/router.py
index 6a4d04815..586a59bc0 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -310,6 +310,45 @@ class Router:
                 return self.function_with_retries(**kwargs)
             else:
                 raise e
+
+    async def atext_completion(self,
+                    model: str,
+                    prompt: str,
+                    is_retry: Optional[bool] = False,
+                    is_fallback: Optional[bool] = False,
+                    is_async: Optional[bool] = False,
+                    **kwargs):
+        try:
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            messages=[{"role": "user", "content": prompt}]
+            # pick the one that is available (lowest TPM/RPM)
+            deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
+
+            data = deployment["litellm_params"].copy()
+            for k, v in self.default_litellm_params.items():
+                if k not in data: # prioritize model-specific params > default router params
+                    data[k] = v
+            ########## remove -ModelID-XXXX from model ##############
+            original_model_string = data["model"]
+            # Find the index of "ModelID" in the string
+            index_of_model_id = original_model_string.find("-ModelID")
+            # Remove everything after "-ModelID" if it exists
+            if index_of_model_id != -1:
+                data["model"] = original_model_string[:index_of_model_id]
+            else:
+                data["model"] = original_model_string
+            # call via litellm.atext_completion()
+            response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
+            return response
+        except Exception as e:
+            if self.num_retries > 0:
+                kwargs["model"] = model
+                kwargs["messages"] = messages
+                kwargs["original_exception"] = e
+                kwargs["original_function"] = self.completion
+                return self.function_with_retries(**kwargs)
+            else:
+                raise e
 
     def embedding(self,
                   model: str,
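
A sketch of driving the new Router.atext_completion method directly; the model group name and deployment entry are illustrative assumptions, though the model_name/litellm_params structure matches what the router code above reads:

import asyncio
import os
from litellm import Router

model_list = [
    {
        "model_name": "text-davinci-003",  # model group name (assumed)
        "litellm_params": {
            "model": "text-completion-openai/text-davinci-003",
            "api_key": os.environ.get("OPENAI_API_KEY"),
        },
    }
]
router = Router(model_list=model_list)

async def main():
    response = await router.atext_completion(
        model="text-davinci-003",  # resolved against the model group above
        prompt="good morning",
        max_tokens=10,
    )
    print(response)

asyncio.run(main())
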
diff --git a/litellm/tests/test_text_completion.py b/litellm/tests/test_text_completion.py
index 11c6d0e5b..9257a07f3 100644
--- a/litellm/tests/test_text_completion.py
+++ b/litellm/tests/test_text_completion.py
@@ -1,4 +1,4 @@
-import sys, os
+import sys, os, asyncio
 import traceback
 from dotenv import load_dotenv
 
@@ -10,7 +10,7 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 import pytest
 import litellm
-from litellm import embedding, completion, text_completion, completion_cost
+from litellm import embedding, completion, text_completion, completion_cost, atext_completion
 
 from litellm import RateLimitError
 
@@ -61,7 +61,7 @@ def test_completion_openai_engine():
         #print(response.choices[0].text)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_openai_engine()
+# test_completion_openai_engine()
 
 
 def test_completion_chatgpt_prompt():
@@ -163,8 +163,23 @@ def test_text_completion_stream():
             max_tokens=10,
         )
         for chunk in response:
-            print(chunk)
+            print(f"chunk: {chunk}")
     except Exception as e:
         pytest.fail(f"GOT exception for HF In streaming{e}")
 
-test_text_completion_stream()
+# test_text_completion_stream()
+
+async def test_text_completion_async_stream():
+    try:
+        response = await atext_completion(
+            model="text-completion-openai/text-davinci-003",
+            prompt="good morning",
+            stream=True,
+            max_tokens=10,
+        )
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+    except Exception as e:
+        pytest.fail(f"GOT exception for HF In streaming{e}")
+
+asyncio.run(test_text_completion_async_stream())
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 8ca4a8953..148d9f76f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5872,31 +5872,43 @@ class TextCompletionStreamWrapper:
 
     def __aiter__(self):
         return self
+
+    def convert_to_text_completion_object(self, chunk: ModelResponse):
+        response = TextCompletionResponse()
+        response["id"] = chunk.get("id", None)
+        response["object"] = "text_completion"
+        response["created"] = response.get("created", None)
+        response["model"] = response.get("model", None)
+        text_choices = TextChoices()
+        text_choices["text"] = chunk["choices"][0]["delta"]["content"]
+        text_choices["index"] = response["choices"][0]["index"]
+        text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
+        response["choices"] = [text_choices]
+        return response
 
     def __next__(self):
         # model_response = ModelResponse(stream=True, model=self.model)
         response = TextCompletionResponse()
         try:
-            while True: # loop until a non-empty string is found
-                # return this for all models
-                chunk = next(self.completion_stream)
-                response["id"] = chunk.get("id", None)
-                response["object"] = "text_completion"
-                response["created"] = response.get("created", None)
-                response["model"] = response.get("model", None)
-                text_choices = TextChoices()
-                text_choices["text"] = chunk["choices"][0]["delta"]["content"]
-                text_choices["index"] = response["choices"][0]["index"]
-                text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
-                response["choices"] = [text_choices]
-                return response
+            for chunk in self.completion_stream:
+                if chunk == "None" or chunk is None:
+                    raise Exception
+                processed_chunk = self.convert_to_text_completion_object(chunk=chunk)
+                return processed_chunk
+            raise StopIteration
         except StopIteration:
             raise StopIteration
         except Exception as e:
             print(f"got exception {e}") # noqa
+
     async def __anext__(self):
         try:
-            return next(self)
+            async for chunk in self.completion_stream:
+                if chunk == "None" or chunk is None:
+                    raise Exception
+                processed_chunk = self.convert_to_text_completion_object(chunk=chunk)
+                return processed_chunk
+            raise StopIteration
         except StopIteration:
             raise StopAsyncIteration
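
A small sketch of consuming the chunks emitted by the reworked TextCompletionStreamWrapper; the chunk["choices"][0]["text"] access mirrors the TextChoices fields set in convert_to_text_completion_object above, and the model name is illustrative:

import asyncio
import litellm

async def collect_streamed_text() -> str:
    stream = await litellm.atext_completion(
        model="text-completion-openai/text-davinci-003",  # illustrative model
        prompt="good morning",
        stream=True,
        max_tokens=10,
    )
    pieces = []
    async for chunk in stream:
        text = chunk["choices"][0]["text"]  # set by convert_to_text_completion_object
        if text:
            pieces.append(text)
    return "".join(pieces)

print(asyncio.run(collect_streamed_text()))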