Merge pull request #1484 from BerriAI/litellm_access_key_metadata_in_callbacks

[Feat] Proxy - Access Key metadata in callbacks
This commit is contained in:
Ishaan Jaff 2024-01-17 18:08:08 -08:00 committed by GitHub
commit 15ae9182db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 56 additions and 51 deletions

View file

@ -166,13 +166,24 @@ class LangFuseLogger:
input, input,
response_obj, response_obj,
): ):
trace = self.Langfuse.trace( import langfuse
name=metadata.get("generation_name", "litellm-completion"),
input=input, tags = []
output=output, supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
user_id=metadata.get("trace_user_id", user_id),
id=metadata.get("trace_id", None), trace_params = {
) "name": metadata.get("generation_name", "litellm-completion"),
"input": input,
"output": output,
"user_id": metadata.get("trace_user_id", user_id),
"id": metadata.get("trace_id", None),
}
if supports_tags:
for key, value in metadata.items():
tags.append(f"{key}:{value}")
trace_params.update({"tags": tags})
trace = self.Langfuse.trace(**trace_params)
trace.generation( trace.generation(
name=metadata.get("generation_name", "litellm-completion"), name=metadata.get("generation_name", "litellm-completion"),

View file

@ -1423,16 +1423,14 @@ async def completion(
) )
if user_model: if user_model:
data["model"] = user_model data["model"] = user_model
if "metadata" in data: if "metadata" not in data:
data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"] = {}
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = dict(request.headers) data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
else: data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"] = { data["metadata"]["headers"] = dict(request.headers)
"user_api_key": user_api_key_dict.api_key, data["metadata"]["endpoint"] = str(request.url)
"user_api_key_user_id": user_api_key_dict.user_id,
}
data["metadata"]["headers"] = dict(request.headers)
# override with user settings, these are params passed via cli # override with user settings, these are params passed via cli
if user_temperature: if user_temperature:
data["temperature"] = user_temperature data["temperature"] = user_temperature
@ -1584,15 +1582,13 @@ async def chat_completion(
# if users are using user_api_key_auth, set `user` in `data` # if users are using user_api_key_auth, set `user` in `data`
data["user"] = user_api_key_dict.user_id data["user"] = user_api_key_dict.user_id
if "metadata" in data: if "metadata" not in data:
verbose_proxy_logger.debug(f'received metadata: {data["metadata"]}') data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["headers"] = dict(request.headers) data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
else: data["metadata"]["headers"] = dict(request.headers)
data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["endpoint"] = str(request.url)
data["metadata"]["headers"] = dict(request.headers)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
global user_temperature, user_request_timeout, user_max_tokens, user_api_base global user_temperature, user_request_timeout, user_max_tokens, user_api_base
# override with user settings, these are params passed via cli # override with user settings, these are params passed via cli
@ -1754,14 +1750,13 @@ async def embeddings(
) )
if user_model: if user_model:
data["model"] = user_model data["model"] = user_model
if "metadata" in data: if "metadata" not in data:
data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"] = {}
data["metadata"]["headers"] = dict(request.headers) data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
else: data["metadata"]["headers"] = dict(request.headers)
data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["headers"] = dict(request.headers) data["metadata"]["endpoint"] = str(request.url)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
router_model_names = ( router_model_names = (
[m["model_name"] for m in llm_model_list] [m["model_name"] for m in llm_model_list]
@ -1895,14 +1890,14 @@ async def image_generation(
) )
if user_model: if user_model:
data["model"] = user_model data["model"] = user_model
if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict.api_key if "metadata" not in data:
data["metadata"]["headers"] = dict(request.headers) data["metadata"] = {}
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id data["metadata"]["user_api_key"] = user_api_key_dict.api_key
else: data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["headers"] = dict(request.headers)
data["metadata"]["headers"] = dict(request.headers) data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id data["metadata"]["endpoint"] = str(request.url)
router_model_names = ( router_model_names = (
[m["model_name"] for m in llm_model_list] [m["model_name"] for m in llm_model_list]
@ -2471,15 +2466,13 @@ async def async_queue_request(
# if users are using user_api_key_auth, set `user` in `data` # if users are using user_api_key_auth, set `user` in `data`
data["user"] = user_api_key_dict.user_id data["user"] = user_api_key_dict.user_id
if "metadata" in data: if "metadata" not in data:
verbose_proxy_logger.debug(f'received metadata: {data["metadata"]}') data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = dict(request.headers) data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id data["metadata"]["headers"] = dict(request.headers)
else: data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["endpoint"] = str(request.url)
data["metadata"]["headers"] = dict(request.headers)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
global user_temperature, user_request_timeout, user_max_tokens, user_api_base global user_temperature, user_request_timeout, user_max_tokens, user_api_base
# override with user settings, these are params passed via cli # override with user settings, these are params passed via cli

View file

@ -182,6 +182,7 @@ def test_chat_completion(client):
print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata")) print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
assert metadata is not None assert metadata is not None
assert "user_api_key" in metadata assert "user_api_key" in metadata
assert "user_api_key_metadata" in metadata
assert "headers" in metadata assert "headers" in metadata
config_model_info = litellm_params.get("model_info") config_model_info = litellm_params.get("model_info")
proxy_server_request_object = litellm_params.get("proxy_server_request") proxy_server_request_object = litellm_params.get("proxy_server_request")