fix(main.py): pass user_id + encoding_format for logging + to openai/azure

Krrish Dholakia 2023-12-12 15:44:04 -08:00
parent 35fa176c97
commit 8b07a6c046
10 changed files with 82 additions and 26 deletions

BIN (binary file, content not shown)

BIN dist/litellm-1.12.6.dev5.tar.gz (vendored, new file; binary content not shown)

View file

@@ -415,7 +415,6 @@ class OpenAIChatCompletion(BaseLLM):
    max_retries = data.pop("max_retries", 2)
    if not isinstance(max_retries, int):
        raise OpenAIError(status_code=422, message="max retries must be an int")
    ## LOGGING
    logging_obj.pre_call(
        input=input,

View file

@@ -31,7 +31,8 @@ from litellm.utils import (
    mock_completion_streaming_obj,
    convert_to_model_response_object,
    token_counter,
-   Usage
+   Usage,
+   get_optional_params_embeddings
)
from .llms import (
    anthropic,
@@ -1828,18 +1829,16 @@ def embedding(
    tpm = kwargs.pop("tpm", None)
    model_info = kwargs.get("model_info", None)
    metadata = kwargs.get("metadata", None)
+   encoding_format = kwargs.get("encoding_format", None)
    proxy_server_request = kwargs.get("proxy_server_request", None)
    aembedding = kwargs.pop("aembedding", None)
-   openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
+   openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"]
    litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
    default_params = openai_params + litellm_params
    non_default_params = {k: v for k,v in kwargs.items() if k not in default_params}  # model-specific params - pass them straight to the model/provider
-   optional_params = {}
-   for param in non_default_params:
-       optional_params[param] = kwargs[param]
    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
+   optional_params = get_optional_params_embeddings(user=user, encoding_format=encoding_format, custom_llm_provider=custom_llm_provider, **non_default_params)
    try:
        response = None
        logging = litellm_logging_obj
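
With this change, a `user` or `encoding_format` passed to `litellm.embedding()` is collected by `get_optional_params_embeddings` into `optional_params`, forwarded on the OpenAI/Azure request, and therefore visible to logging callbacks. A minimal sketch of a call exercising this path (the model name and the `OPENAI_API_KEY` environment variable are illustrative assumptions):

# Sketch: exercise the new optional-param path for embeddings.
# Assumes a valid OPENAI_API_KEY is set in the environment.
import litellm

response = litellm.embedding(
    model="text-embedding-ada-002",
    input=["hello world"],
    user="user-1234",          # forwarded to openai/azure and visible in callbacks
    encoding_format="float",   # forwarded as an OpenAI optional param
)
print(response)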

View file

@@ -1,3 +1,11 @@
+ import sys, os, traceback
+ # this file is to test litellm/proxy
+ sys.path.insert(
+     0, os.path.abspath("../..")
+ )  # Adds the parent directory to the system path
from litellm.integrations.custom_logger import CustomLogger
import litellm
import inspect

View file

@@ -916,7 +916,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
    except:
        data = json.loads(body_str)
-   data["user"] = user_api_key_dict.user_id
+   data["user"] = data.get("user", user_api_key_dict.user_id)
    data["model"] = (
        general_settings.get("completion_model", None)  # server default
        or user_model  # model name passed via cli args
@@ -1066,7 +1066,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
        "body": copy.copy(data)  # use copy instead of deepcopy
    }
-   data["user"] = user_api_key_dict.user_id
+   data["user"] = data.get("user", user_api_key_dict.user_id)
    data["model"] = (
        general_settings.get("embedding_model", None)  # server default
        or user_model  # model name passed via cli args
@@ -1081,7 +1081,6 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
    data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
    data["metadata"]["headers"] = dict(request.headers)
    router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
-   print(f"received data: {data['input']}")
    if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int):  # check if array of tokens passed in
        # check if non-openai/azure model called - e.g. for langchain integration
        if llm_model_list is not None and data["model"] in router_model_names:
@@ -1099,7 +1098,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
    ### CALL HOOKS ### - modify incoming data / reject request before calling the model
    data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
-   print(f'final data: {data}')
    ## ROUTE TO CORRECT ENDPOINT ##
    if llm_router is not None and data["model"] in router_model_names:  # model in router model list
        response = await llm_router.aembedding(**data)
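
The two proxy hunks above stop overwriting a client-supplied `user`: the API key's `user_id` is now only a fallback. A minimal sketch of that precedence, with hypothetical `request_body` and `key_user_id` names:

# Sketch of the fallback introduced above (variable names are hypothetical).
def resolve_user(request_body: dict, key_user_id: str) -> str:
    # A "user" sent in the request body wins; otherwise fall back to the
    # user_id attached to the API key (previously the key's user_id always won).
    return request_body.get("user", key_user_id)

assert resolve_user({"user": "John", "input": ["hi"]}, "key-user-1") == "John"
assert resolve_user({"input": ["hi"]}, "key-user-1") == "key-user-1"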

View file

@@ -197,6 +197,7 @@ class CompletionCustomHandler(CustomLogger):  # https://docs.litellm.ai/docs/obse
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], (dict, type(None)))
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())
@@ -507,7 +508,7 @@ async def test_async_embedding_openai():
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
-       assert len(customHandler_failure.states) == 3  # pre, post, success
+       assert len(customHandler_failure.states) == 3  # pre, post, failure
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

View file

@@ -26,9 +26,12 @@ class MyCustomHandler(CustomLogger):
        self.stream_collected_response = None  # type: ignore
        self.sync_stream_collected_response = None  # type: ignore
+       self.user = None  # type: ignore
+       self.data_sent_to_api: dict = {}

    def log_pre_api_call(self, model, messages, kwargs):
        print(f"Pre-API Call")
+       self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        print(f"Post-API Call")
@@ -49,6 +52,7 @@ class MyCustomHandler(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async success")
+       print(f"received kwargs user: {kwargs['user']}")
        self.async_success = True
        if kwargs.get("model") == "text-embedding-ada-002":
            self.async_success_embedding = True
@@ -57,6 +61,7 @@ class MyCustomHandler(CustomLogger):
        if kwargs.get("stream") == True:
            self.stream_collected_response = response_obj
        self.async_completion_kwargs = kwargs
+       self.user = kwargs.get("user", None)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async Failure")
@@ -95,14 +100,6 @@ def test_async_chat_openai_stream():
            print(complete_streaming_response)
        asyncio.run(call_gpt())
        complete_streaming_response = complete_streaming_response.strip("'")
-       print(f"complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content']}")
-       print(f"type of complete_streaming_response_in_callback: {type(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
-       print(f"hidden char complete_streaming_response_in_callback: {repr(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
-       print(f"encoding complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'].encode('utf-8')}")
-       print(f"complete_streaming_response: {complete_streaming_response}")
-       print(f"type(complete_streaming_response): {type(complete_streaming_response)}")
-       print(f"hidden char complete_streaming_response): {repr(complete_streaming_response)}")
-       print(f"encoding complete_streaming_response): {repr(complete_streaming_response).encode('utf-8')}")
        response1 = tmp_function.complete_streaming_response_in_callback["choices"][0]["message"]["content"]
        response2 = complete_streaming_response
        assert [ord(c) for c in response1] == [ord(c) for c in response2]
@@ -110,7 +107,7 @@ def test_async_chat_openai_stream():
    except Exception as e:
        print(e)
        pytest.fail(f"An error occurred - {str(e)}")
- test_async_chat_openai_stream()
+ # test_async_chat_openai_stream()

def test_completion_azure_stream_moderation_failure():
    try:
@@ -290,9 +287,29 @@ async def test_async_custom_handler_embedding():
        assert len(str(customHandler_embedding.async_embedding_kwargs_fail.get("exception"))) > 10  # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
- asyncio.run(test_async_custom_handler_embedding())
+ # asyncio.run(test_async_custom_handler_embedding())
- from litellm import Cache
+ @pytest.mark.asyncio
+ async def test_async_custom_handler_embedding_optional_param():
+     """
+     Tests if the openai optional params for embedding - user + encoding_format,
+     are logged
+     """
+     customHandler_optional_params = MyCustomHandler()
+     litellm.callbacks = [customHandler_optional_params]
+     response = await litellm.aembedding(
+         model="text-embedding-ada-002",
+         input = ["hello world"],
+         user = "John"
+     )
+     await asyncio.sleep(1)  # success callback is async
+     assert customHandler_optional_params.user == "John"
+     assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"]
+ # asyncio.run(test_async_custom_handler_embedding_optional_param())

def test_redis_cache_completion_stream():
+     from litellm import Cache
    # Important Test - This tests if we can add to streaming cache, when custom callbacks are set
    import random
    try:
@@ -325,4 +342,4 @@ def test_redis_cache_completion_stream():
        print(e)
        litellm.success_callback = []
        raise e
- test_redis_cache_completion_stream()
+ # test_redis_cache_completion_stream()

View file

@@ -2112,6 +2112,39 @@ def get_litellm_params(
    return litellm_params

+ def get_optional_params_embeddings(
+     # 2 optional params
+     user=None,
+     encoding_format=None,
+     custom_llm_provider="",
+     **kwargs
+ ):
+     # retrieve all parameters passed to the function
+     passed_params = locals()
+     custom_llm_provider = passed_params.pop("custom_llm_provider", None)
+     special_params = passed_params.pop("kwargs")
+     for k, v in special_params.items():
+         passed_params[k] = v
+     default_params = {
+         "user": None,
+         "encoding_format": None
+     }
+     non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])}
+     ## raise exception if non-default value passed for non-openai/azure embedding calls
+     if custom_llm_provider != "openai" and custom_llm_provider != "azure":
+         if len(non_default_params.keys()) > 0:
+             if litellm.drop_params is True:
+                 for k in non_default_params.keys():
+                     passed_params.pop(k, None)
+                 return passed_params
+             raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.")
+     final_params = {**non_default_params, **kwargs}
+     return final_params

def get_optional_params(  # use the openai defaults
    # 12 optional params
    functions=[],
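
The helper added above keeps `user`/`encoding_format` for OpenAI and Azure; for any other provider it either drops them (when `litellm.drop_params` is True) or raises `UnsupportedParamsError`. A rough usage sketch, assuming the helper lives in `litellm.utils` alongside `get_optional_params` (provider names are illustrative):

# Rough usage sketch; assumes get_optional_params_embeddings is exposed via litellm.utils.
import litellm
from litellm.utils import get_optional_params_embeddings

# openai/azure: non-default values are kept and forwarded with the request
params = get_optional_params_embeddings(user="John", custom_llm_provider="openai")
# params == {"user": "John"}

# any other provider: raises UnsupportedParamsError unless litellm.drop_params is True
litellm.drop_params = True
params = get_optional_params_embeddings(user="John", custom_llm_provider="huggingface")
# with drop_params set, "user" is stripped before the call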