diff --git a/.circleci/config.yml b/.circleci/config.yml
index d155823c6..c1224159a 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -150,6 +150,7 @@ jobs:
             -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
             -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
             -e AWS_REGION_NAME=$AWS_REGION_NAME \
+            -e OPENAI_API_KEY=$OPENAI_API_KEY \
             --name my-app \
             -v $(pwd)/proxy_server_config.yaml:/app/config.yaml \
             my-app:latest \
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 905b9424e..84b09d726 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1079,7 +1079,7 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
     metadata = (
         litellm_params.get("metadata", {}) or {}
     )  # if litellm_params['metadata'] == None
-    call_type = kwargs.get("call_type", "litellm.completion")
+    call_type = kwargs.get("call_type")
     cache_hit = kwargs.get("cache_hit", False)
     usage = response_obj["usage"]
     if type(usage) == litellm.Usage:
@@ -1118,6 +1118,7 @@
         "completion_tokens": usage.get("completion_tokens", 0),
     }

+    verbose_proxy_logger.debug(f"SpendTable: created payload - payload: {payload}\n\n")
     json_fields = [
         field
         for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
diff --git a/litellm/utils.py b/litellm/utils.py
index fe899388f..587a44895 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -804,6 +804,7 @@ class Logging:
             "stream": self.stream,
             "user": user,
             "call_type": str(self.call_type),
+            "litellm_call_id": self.litellm_call_id,
             **self.optional_params,
             **additional_params,
         }
@@ -1056,6 +1057,7 @@ class Logging:
                 and (
                     isinstance(result, ModelResponse)
                     or isinstance(result, EmbeddingResponse)
+                    or isinstance(result, ImageResponse)
                 )
                 and self.stream != True
             ):  # handle streaming separately
@@ -1063,11 +1065,24 @@
                 if self.model_call_details.get("cache_hit", False) == True:
                     self.model_call_details["response_cost"] = 0.0
                 else:
-                    self.model_call_details[
-                        "response_cost"
-                    ] = litellm.completion_cost(
-                        completion_response=result,
-                    )
+                    result._hidden_params["optional_params"] = self.optional_params
+                    if (
+                        self.call_type == CallTypes.aimage_generation.value
+                        or self.call_type == CallTypes.image_generation.value
+                    ):
+                        self.model_call_details[
+                            "response_cost"
+                        ] = litellm.completion_cost(
+                            completion_response=result,
+                            model=self.model,
+                            call_type=self.call_type,
+                        )
+                    else:
+                        self.model_call_details[
+                            "response_cost"
+                        ] = litellm.completion_cost(
+                            completion_response=result, call_type=self.call_type
+                        )
                 verbose_logger.debug(
                     f"Model={self.model}; cost={self.model_call_details['response_cost']}"
                 )
@@ -3174,6 +3189,16 @@ def completion_cost(
     messages: List = [],
     completion="",
     total_time=0.0,  # used for replicate, sagemaker
+    call_type: Literal[
+        "completion",
+        "acompletion",
+        "embedding",
+        "aembedding",
+        "atext_completion",
+        "text_completion",
+        "image_generation",
+        "aimage_generation",
+    ] = "completion",
     ### REGION ###
     custom_llm_provider=None,
     region_name=None,  # used for bedrock pricing
@@ -3232,6 +3257,19 @@
         region_name = completion_response._hidden_params.get(
             "region_name", region_name
         )
+        size = completion_response._hidden_params.get(
+            "optional_params", {}
+        ).get(
+            "size", "1024-x-1024"
+        )  # openai default
+        quality = completion_response._hidden_params.get(
+            "optional_params", {}
+        ).get(
+            "quality", "standard"
+        )  # openai default
+        n = completion_response._hidden_params.get("optional_params", {}).get(
+            "n", 1
+        )  # openai default
     else:
         if len(messages) > 0:
             prompt_tokens = token_counter(model=model, messages=messages)
@@ -3243,7 +3281,10 @@
             f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
         )

-    if size is not None and n is not None:
+    if (
+        call_type == CallTypes.image_generation.value
+        or call_type == CallTypes.aimage_generation.value
+    ):
         ### IMAGE GENERATION COST CALCULATION ###
         image_gen_model_name = f"{size}/{model}"
         image_gen_model_name_with_quality = image_gen_model_name
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index 2c123d156..1d499aa7d 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -42,6 +42,9 @@ model_list:
       api_version: 2023-06-01-preview
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
       api_key: os.environ/AZURE_API_KEY
+  - model_name: openai-dall-e-3
+    litellm_params:
+      model: dall-e-3

 litellm_settings:
   drop_params: True
diff --git a/tests/test_keys.py b/tests/test_keys.py
index 9cbcc25e1..6740308ac 100644
--- a/tests/test_keys.py
+++ b/tests/test_keys.py
@@ -14,7 +14,11 @@ import litellm


 async def generate_key(
-    session, i, budget=None, budget_duration=None, models=["azure-models", "gpt-4"]
+    session,
+    i,
+    budget=None,
+    budget_duration=None,
+    models=["azure-models", "gpt-4", "dall-e-3"],
 ):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
@@ -129,6 +133,39 @@ async def chat_completion(session, key, model="gpt-4"):
             pass


+async def image_generation(session, key, model="dall-e-3"):
+    url = "http://0.0.0.0:4000/v1/images/generations"
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Content-Type": "application/json",
+    }
+    data = {
+        "model": model,
+        "prompt": "A cute baby sea otter",
+    }
+
+    for i in range(3):
+        try:
+            async with session.post(url, headers=headers, json=data) as response:
+                status = response.status
+                response_text = await response.text()
+
+                print(response_text)
+                print()
+
+                if status != 200:
+                    raise Exception(
+                        f"Request did not return a 200 status code: {status}. Response: {response_text}"
+                    )
+
+                return await response.json()
+        except Exception as e:
+            if "Request did not return a 200 status code" in str(e):
+                raise e
+            else:
+                pass
+
+
 async def chat_completion_streaming(session, key, model="gpt-4"):
     client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
     messages = [
@@ -357,6 +394,40 @@ async def test_key_info_spend_values_streaming():
     assert rounded_response_cost == rounded_key_info_spend


+@pytest.mark.asyncio
+async def test_key_info_spend_values_image_generation():
+    """
+    Test to ensure spend is correctly calculated
+    - create key
+    - make image gen call
+    - assert cost is expected value
+    """
+
+    async def retry_request(func, *args, _max_attempts=5, **kwargs):
+        for attempt in range(_max_attempts):
+            try:
+                return await func(*args, **kwargs)
+            except aiohttp.client_exceptions.ClientOSError as e:
+                if attempt + 1 == _max_attempts:
+                    raise  # re-raise the last ClientOSError if all attempts failed
+                print(f"Attempt {attempt+1} failed, retrying...")
+
+    async with aiohttp.ClientSession(
+        timeout=aiohttp.ClientTimeout(total=600)
+    ) as session:
+        ## Test Spend Update ##
+        # image generation
+        key_gen = await generate_key(session=session, i=0)
+        key = key_gen["key"]
+        response = await image_generation(session=session, key=key)
+        await asyncio.sleep(5)
+        key_info = await retry_request(
+            get_key_info, session=session, get_key=key, call_key=key
+        )
+        spend = key_info["info"]["spend"]
+        assert spend > 0
+
+
 @pytest.mark.asyncio
 async def test_key_with_budgets():
     """
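A minimal sketch of the cost path this diff enables, assuming a litellm build that includes these changes and an OPENAI_API_KEY in the environment; the model name and prompt are illustrative, not part of the diff:

import litellm

# image_generation returns an ImageResponse, which the Logging success
# handler now recognizes when computing response_cost.
response = litellm.image_generation(
    model="dall-e-3",
    prompt="A cute baby sea otter",
)

# completion_cost now accepts call_type; for image generation it prices the
# call via the "{size}/{model}" key (e.g. "1024-x-1024/dall-e-3"), falling
# back to the OpenAI defaults for size ("1024-x-1024"), quality ("standard"),
# and n (1) when they are absent from the response's hidden params.
cost = litellm.completion_cost(
    completion_response=response,
    model="dall-e-3",
    call_type="image_generation",
)
print(f"dall-e-3 image cost: {cost}")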