fix(main.py): pass user_id + encoding_format for logging + to openai/azure

This commit is contained in:
Krrish Dholakia 2023-12-12 15:44:04 -08:00
parent 35fa176c97
commit 8b07a6c046
10 changed files with 82 additions and 26 deletions

Binary file not shown.

BIN
dist/litellm-1.12.6.dev5.tar.gz vendored Normal file

Binary file not shown.

View file

@ -415,7 +415,6 @@ class OpenAIChatCompletion(BaseLLM):
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
## LOGGING
logging_obj.pre_call(
input=input,

View file

@ -31,7 +31,8 @@ from litellm.utils import (
mock_completion_streaming_obj,
convert_to_model_response_object,
token_counter,
Usage
Usage,
get_optional_params_embeddings
)
from .llms import (
anthropic,
@ -1828,18 +1829,16 @@ def embedding(
tpm = kwargs.pop("tpm", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
encoding_format = kwargs.get("encoding_format", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.pop("aembedding", None)
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"]
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
optional_params = {}
for param in non_default_params:
optional_params[param] = kwargs[param]
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
optional_params = get_optional_params_embeddings(user=user, encoding_format=encoding_format, custom_llm_provider=custom_llm_provider, **non_default_params)
try:
response = None
logging = litellm_logging_obj

View file

@ -1,3 +1,11 @@
import sys, os, traceback
# this file is to test litellm/proxy
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm.integrations.custom_logger import CustomLogger
import litellm
import inspect

View file

@ -916,7 +916,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
except:
data = json.loads(body_str)
data["user"] = user_api_key_dict.user_id
data["user"] = data.get("user", user_api_key_dict.user_id)
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
@ -1066,7 +1066,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
"body": copy.copy(data) # use copy instead of deepcopy
}
data["user"] = user_api_key_dict.user_id
data["user"] = data.get("user", user_api_key_dict.user_id)
data["model"] = (
general_settings.get("embedding_model", None) # server default
or user_model # model name passed via cli args
@ -1081,7 +1081,6 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
print(f"received data: {data['input']}")
if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int): # check if array of tokens passed in
# check if non-openai/azure model called - e.g. for langchain integration
if llm_model_list is not None and data["model"] in router_model_names:
@ -1099,7 +1098,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
print(f'final data: {data}')
## ROUTE TO CORRECT ENDPOINT ##
if llm_router is not None and data["model"] in router_model_names: # model in router model list
response = await llm_router.aembedding(**data)

View file

@ -197,6 +197,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@ -507,7 +508,7 @@ async def test_async_embedding_openai():
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, success
assert len(customHandler_failure.states) == 3 # pre, post, failure
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")

View file

@ -26,9 +26,12 @@ class MyCustomHandler(CustomLogger):
self.stream_collected_response = None # type: ignore
self.sync_stream_collected_response = None # type: ignore
self.user = None # type: ignore
self.data_sent_to_api: dict = {}
def log_pre_api_call(self, model, messages, kwargs):
print(f"Pre-API Call")
self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
print(f"Post-API Call")
@ -49,6 +52,7 @@ class MyCustomHandler(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async success")
print(f"received kwargs user: {kwargs['user']}")
self.async_success = True
if kwargs.get("model") == "text-embedding-ada-002":
self.async_success_embedding = True
@ -57,6 +61,7 @@ class MyCustomHandler(CustomLogger):
if kwargs.get("stream") == True:
self.stream_collected_response = response_obj
self.async_completion_kwargs = kwargs
self.user = kwargs.get("user", None)
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Failure")
@ -95,14 +100,6 @@ def test_async_chat_openai_stream():
print(complete_streaming_response)
asyncio.run(call_gpt())
complete_streaming_response = complete_streaming_response.strip("'")
print(f"complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content']}")
print(f"type of complete_streaming_response_in_callback: {type(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
print(f"hidden char complete_streaming_response_in_callback: {repr(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
print(f"encoding complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'].encode('utf-8')}")
print(f"complete_streaming_response: {complete_streaming_response}")
print(f"type(complete_streaming_response): {type(complete_streaming_response)}")
print(f"hidden char complete_streaming_response): {repr(complete_streaming_response)}")
print(f"encoding complete_streaming_response): {repr(complete_streaming_response).encode('utf-8')}")
response1 = tmp_function.complete_streaming_response_in_callback["choices"][0]["message"]["content"]
response2 = complete_streaming_response
assert [ord(c) for c in response1] == [ord(c) for c in response2]
@ -110,7 +107,7 @@ def test_async_chat_openai_stream():
except Exception as e:
print(e)
pytest.fail(f"An error occurred - {str(e)}")
test_async_chat_openai_stream()
# test_async_chat_openai_stream()
def test_completion_azure_stream_moderation_failure():
try:
@ -290,9 +287,29 @@ async def test_async_custom_handler_embedding():
assert len(str(customHandler_embedding.async_embedding_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
asyncio.run(test_async_custom_handler_embedding())
from litellm import Cache
# asyncio.run(test_async_custom_handler_embedding())
@pytest.mark.asyncio
async def test_async_custom_handler_embedding_optional_param():
"""
Tests if the openai optional params for embedding - user + encoding_format,
are logged
"""
customHandler_optional_params = MyCustomHandler()
litellm.callbacks = [customHandler_optional_params]
response = await litellm.aembedding(
model="text-embedding-ada-002",
input = ["hello world"],
user = "John"
)
await asyncio.sleep(1) # success callback is async
assert customHandler_optional_params.user == "John"
assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"]
# asyncio.run(test_async_custom_handler_embedding_optional_param())
def test_redis_cache_completion_stream():
from litellm import Cache
# Important Test - This tests if we can add to streaming cache, when custom callbacks are set
import random
try:
@ -325,4 +342,4 @@ def test_redis_cache_completion_stream():
print(e)
litellm.success_callback = []
raise e
test_redis_cache_completion_stream()
# test_redis_cache_completion_stream()

View file

@ -2112,6 +2112,39 @@ def get_litellm_params(
return litellm_params
def get_optional_params_embeddings(
# 2 optional params
user=None,
encoding_format=None,
custom_llm_provider="",
**kwargs
):
# retrieve all parameters passed to the function
passed_params = locals()
custom_llm_provider = passed_params.pop("custom_llm_provider", None)
special_params = passed_params.pop("kwargs")
for k, v in special_params.items():
passed_params[k] = v
default_params = {
"user": None,
"encoding_format": None
}
non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])}
## raise exception if non-default value passed for non-openai/azure embedding calls
if custom_llm_provider != "openai" and custom_llm_provider != "azure":
if len(non_default_params.keys()) > 0:
if litellm.drop_params is True:
for k in non_default_params.keys():
passed_params.pop(k, None)
return passed_params
raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.")
final_params = {**non_default_params, **kwargs}
return final_params
def get_optional_params( # use the openai defaults
# 12 optional params
functions=[],