diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index b1a3df7383..2539cfab58 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -4,9 +4,7 @@ from .base import BaseLLM from litellm.utils import ModelResponse, Choices, Message from typing import Callable, Optional from litellm import OpenAIConfig - -# This file just has the openai config classes. -# For implementation check out completion() in main.py +import aiohttp class AzureOpenAIError(Exception): def __init__(self, status_code, message): @@ -116,6 +114,7 @@ class AzureChatCompletion(BaseLLM): optional_params, litellm_params, logger_fn, + acompletion: bool = False, headers: Optional[dict]=None): super().completion() exception_mapping_worked = False @@ -157,6 +156,8 @@ class AzureChatCompletion(BaseLLM): ## RESPONSE OBJECT return response.iter_lines() + elif acompletion is True: + return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response) else: response = self._client_session.post( url=api_base, @@ -178,6 +179,18 @@ class AzureChatCompletion(BaseLLM): import traceback raise AzureOpenAIError(status_code=500, message=traceback.format_exc()) + async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse): + async with aiohttp.ClientSession() as session: + async with session.post(api_base, json=data, headers=headers) as response: + response_json = await response.json() + if response.status != 200: + raise AzureOpenAIError(status_code=response.status, message=response.text) + + + ## RESPONSE OBJECT + return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response) + + def embedding(self, model: str, input: list, diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 56cc74573a..52801d2f0e 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -3,9 +3,8 @@ import types, requests from .base import BaseLLM from litellm.utils import ModelResponse, Choices, Message from typing import Callable, Optional +import aiohttp -# This file just has the openai config classes. -# For implementation check out completion() in main.py class OpenAIError(Exception): def __init__(self, status_code, message): @@ -184,22 +183,24 @@ class OpenAIChatCompletion(BaseLLM): OpenAIError(status_code=500, message="Invalid response object.") def completion(self, - model: Optional[str]=None, - messages: Optional[list]=None, - model_response: Optional[ModelResponse]=None, - print_verbose: Optional[Callable]=None, - api_key: Optional[str]=None, - api_base: Optional[str]=None, - logging_obj=None, - optional_params=None, - litellm_params=None, - logger_fn=None, - headers: Optional[dict]=None): + model_response: ModelResponse, + model: Optional[str]=None, + messages: Optional[list]=None, + print_verbose: Optional[Callable]=None, + api_key: Optional[str]=None, + api_base: Optional[str]=None, + acompletion: bool = False, + logging_obj=None, + optional_params=None, + litellm_params=None, + logger_fn=None, + headers: Optional[dict]=None): super().completion() exception_mapping_worked = False try: if headers is None: headers = self.validate_environment(api_key=api_key) + api_base = f"{api_base}/chat/completions" if model is None or messages is None: raise OpenAIError(status_code=422, message=f"Missing model or messages") @@ -214,13 +215,13 @@ class OpenAIChatCompletion(BaseLLM): logging_obj.pre_call( input=messages, api_key=api_key, - additional_args={"headers": headers, "api_base": api_base}, + additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "data": data}, ) try: if "stream" in optional_params and optional_params["stream"] == True: response = self._client_session.post( - url=f"{api_base}/chat/completions", + url=api_base, json=data, headers=headers, stream=optional_params["stream"] @@ -230,9 +231,11 @@ class OpenAIChatCompletion(BaseLLM): ## RESPONSE OBJECT return response.iter_lines() + elif acompletion is True: + return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response) else: response = self._client_session.post( - url=f"{api_base}/chat/completions", + url=api_base, json=data, headers=headers, ) @@ -270,6 +273,17 @@ class OpenAIChatCompletion(BaseLLM): import traceback raise OpenAIError(status_code=500, message=traceback.format_exc()) + async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse): + async with aiohttp.ClientSession() as session: + async with session.post(api_base, json=data, headers=headers) as response: + response_json = await response.json() + if response.status != 200: + raise OpenAIError(status_code=response.status, message=response.text) + + + ## RESPONSE OBJECT + return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response) + def embedding(self, model: str, input: list, diff --git a/litellm/main.py b/litellm/main.py index 21eb580e99..4f2ab291ee 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -7,7 +7,7 @@ # # Thank you ! We ❤️ you! - Krrish & Ishaan -import os, openai, sys, json, inspect +import os, openai, sys, json, inspect, uuid, datetime, threading from typing import Any from functools import partial import dotenv, traceback, random, asyncio, time, contextvars @@ -77,7 +77,7 @@ openai_text_completions = OpenAITextCompletion() azure_chat_completions = AzureChatCompletion() ####### COMPLETION ENDPOINTS ################ -async def acompletion(*args, **kwargs): +async def acompletion(model: str, messages: List = [], *args, **kwargs): """ Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) @@ -118,16 +118,31 @@ async def acompletion(*args, **kwargs): """ loop = asyncio.get_event_loop() - # Use a partial function to pass your keyword arguments + ### INITIALIZE LOGGING OBJECT ### + kwargs["litellm_call_id"] = str(uuid.uuid4()) + start_time = datetime.datetime.now() + logging_obj = Logging(model=model, messages=messages, stream=kwargs.get("stream", False), litellm_call_id=kwargs["litellm_call_id"], function_id=kwargs.get("id", None), call_type="completion", start_time=start_time) + + ### PASS ARGS TO COMPLETION ### + kwargs["litellm_logging_obj"] = logging_obj kwargs["acompletion"] = True + kwargs["model"] = model + kwargs["messages"] = messages + # Use a partial function to pass your keyword arguments func = partial(completion, *args, **kwargs) # Add the context to the function ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) + _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None)) + + if custom_llm_provider == "openai" or custom_llm_provider == "azure": # currently implemented aiohttp calls for just azure and openai, soon all. + # Await normally + response = await completion(*args, **kwargs) + else: + # Call the synchronous function using run_in_executor + response = await loop.run_in_executor(None, func_with_context) if kwargs.get("stream", False): # return an async generator # do not change this # for stream = True, always return an async generator @@ -137,6 +152,16 @@ async def acompletion(*args, **kwargs): async for line in response ) else: + end_time = datetime.datetime.now() + # [OPTIONAL] ADD TO CACHE + if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object + litellm.cache.add_cache(response, *args, **kwargs) + + # LOG SUCCESS + logging_obj.success_handler(response, start_time, end_time) + # RETURN RESULT + response._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai + return response def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs): @@ -268,6 +293,7 @@ def completion( final_prompt_value = kwargs.get("final_prompt_value", None) bos_token = kwargs.get("bos_token", None) eos_token = kwargs.get("eos_token", None) + acompletion = kwargs.get("acompletion", False) ######## end of unpacking kwargs ########### openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"] litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"] @@ -409,6 +435,7 @@ def completion( litellm_params=litellm_params, logger_fn=logger_fn, logging_obj=logging, + acompletion=acompletion ) if "stream" in optional_params and optional_params["stream"] == True: response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging) @@ -472,6 +499,7 @@ def completion( print_verbose=print_verbose, api_key=api_key, api_base=api_base, + acompletion=acompletion, logging_obj=logging, optional_params=optional_params, litellm_params=litellm_params, diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py index 991c2006ea..c86bf89a9f 100644 --- a/litellm/tests/test_async_fn.py +++ b/litellm/tests/test_async_fn.py @@ -25,16 +25,19 @@ def test_sync_response(): def test_async_response(): import asyncio async def test_get_response(): - litellm.set_verbose = True user_message = "Hello, how are you?" messages = [{"content": user_message, "role": "user"}] try: response = await acompletion(model="gpt-3.5-turbo", messages=messages) + print(f"response: {response}") + response = await acompletion(model="azure/chatgpt-v-2", messages=messages) + print(f"response: {response}") except Exception as e: pytest.fail(f"An exception occurred: {e}") response = asyncio.run(test_get_response()) # print(response) +test_async_response() def test_get_response_streaming(): import asyncio diff --git a/litellm/tests/test_stream_chunk_builder.py b/litellm/tests/test_stream_chunk_builder.py index 7b8db04810..a969003336 100644 --- a/litellm/tests/test_stream_chunk_builder.py +++ b/litellm/tests/test_stream_chunk_builder.py @@ -1,6 +1,7 @@ from litellm import completion, stream_chunk_builder import litellm import os, dotenv +import pytest dotenv.load_dotenv() user_message = "What is the current weather in Boston?" @@ -23,6 +24,7 @@ function_schema = { }, } +@pytest.mark.skip def test_stream_chunk_builder(): litellm.set_verbose = False litellm.api_key = os.environ["OPENAI_API_KEY"] diff --git a/litellm/utils.py b/litellm/utils.py index a75822ff07..16ae291dac 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -763,7 +763,8 @@ class Logging: ) elif isinstance(callback, CustomLogger): # custom logger class callback.log_failure_event( - model=self.model, + start_time=start_time, + end_time=end_time, messages=self.messages, kwargs=self.model_call_details, ) @@ -908,7 +909,7 @@ def client(original_function): def wrapper(*args, **kwargs): start_time = datetime.datetime.now() result = None - logging_obj = None + logging_obj = kwargs.get("litellm_logging_obj", None) # only set litellm_call_id if its not in kwargs if "litellm_call_id" not in kwargs: @@ -919,7 +920,8 @@ def client(original_function): raise ValueError("model param not passed in.") try: - logging_obj = function_setup(start_time, *args, **kwargs) + if logging_obj is None: + logging_obj = function_setup(start_time, *args, **kwargs) kwargs["litellm_logging_obj"] = logging_obj # [OPTIONAL] CHECK BUDGET @@ -956,7 +958,8 @@ def client(original_function): return litellm.stream_chunk_builder(chunks) else: return result - + elif "acompletion" in kwargs and kwargs["acompletion"] == True: + return result # [OPTIONAL] ADD TO CACHE if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object @@ -1014,7 +1017,6 @@ def client(original_function): raise e return wrapper - ####### USAGE CALCULATOR ################ diff --git a/pyproject.toml b/pyproject.toml index d741b66449..06c84ed3e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ click = "*" jinja2 = "^3.1.2" certifi = "^2023.7.22" appdirs = "^1.4.4" +aiohttp = "*" [tool.poetry.scripts] litellm = 'litellm:run_server'