new logger client

Krrish Dholakia 2023-08-28 14:56:20 -07:00
parent d48763a92f
commit a0f882d507
9 changed files with 235 additions and 195 deletions
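
Note: the unifying change in this commit is that completion() and embedding() no longer construct their own Logging object. The caller now supplies one through the new litellm_logging_obj parameter, and the same object is threaded into every CustomStreamWrapper so streamed responses can be logged as well. A minimal sketch of the calling pattern this implies; the Logging stub below is a simplified stand-in for litellm's real class, and the wrapper that actually builds the object (presumably the client decorator, changed elsewhere in this commit) is not shown in this diff:

    # Simplified stand-in for litellm's Logging class (sketch only).
    class Logging:
        def __init__(self, model, messages, litellm_call_id=None):
            # pre-call identity, set before completion()/embedding() runs
            self.model = model
            self.messages = messages
            self.litellm_call_id = litellm_call_id
            self.optional_params = None
            self.litellm_params = None

        def update_environment_variables(self, optional_params, litellm_params):
            # call-time params, attached inside completion()/embedding()
            self.optional_params = optional_params
            self.litellm_params = litellm_params

        def post_call(self, input, api_key, original_response):
            print(f"post_call: {str(original_response)[:100]}")

    messages = [{"role": "user", "content": "Hey, how's it going?"}]
    logging_obj = Logging(model="gpt-3.5-turbo", messages=messages, litellm_call_id="call-123")
    # response = completion(model="gpt-3.5-turbo", messages=messages,
    #                       litellm_logging_obj=logging_obj)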


@@ -92,6 +92,7 @@ def completion(
     custom_llm_provider=None,
     custom_api_base=None,
     litellm_call_id=None,
+    litellm_logging_obj=None,
     # model specific optional params
     # used by text-bison only
     top_k=40,
@@ -100,6 +101,7 @@ def completion(
 ) -> ModelResponse:
     args = locals()
     try:
+        logging = litellm_logging_obj
         if fallbacks != []:
             return completion_with_fallbacks(**args)
         if litellm.model_alias_map and model in litellm.model_alias_map:
@@ -151,12 +153,7 @@ def completion(
             litellm_call_id=litellm_call_id,
             model_alias_map=litellm.model_alias_map,
         )
-        logging = Logging(
-            model=model,
-            messages=messages,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-        )
+        logging.update_environment_variables(optional_params=optional_params, litellm_params=litellm_params)
         if custom_llm_provider == "azure":
             # azure configs
             openai.api_type = "azure"
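
Note: logging here is whatever the caller passed in, and litellm_logging_obj defaults to None, so this hunk assumes the client wrapper always injects an object. A defensive variant (not part of this commit) might guard direct calls:

    # Hypothetical guard for callers that bypass the wrapper; not in this diff.
    if logging is None:
        raise ValueError("completion() expects litellm_logging_obj to be injected by the client wrapper")
    logging.update_environment_variables(optional_params=optional_params, litellm_params=litellm_params)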
@@ -306,7 +303,7 @@ def completion(
             response = openai.Completion.create(model=model, prompt=prompt, **optional_params)
             if "stream" in optional_params and optional_params["stream"] == True:
-                response = CustomStreamWrapper(response, model)
+                response = CustomStreamWrapper(response, model, logging_obj=logging)
                 return response
             ## LOGGING
             logging.post_call(
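
Note: every streaming return path in this file now passes the logging object into CustomStreamWrapper, as in the hunk above. The real wrapper is defined in litellm's utils and also handles provider-specific chunk formats; a minimal sketch of the shape these call sites assume:

    # Minimal sketch; logging_obj is the parameter these call sites now pass.
    class CustomStreamWrapper:
        def __init__(self, completion_stream, model, custom_llm_provider=None, logging_obj=None):
            self.completion_stream = completion_stream
            self.model = model
            self.custom_llm_provider = custom_llm_provider
            self.logging_obj = logging_obj

        def __iter__(self):
            return self

        def __next__(self):
            chunk = next(self.completion_stream)
            if self.logging_obj is not None:
                # hypothetical hook: the exact method the real wrapper calls
                # on the logger may differ
                self.logging_obj.post_call(input=None, api_key=None, original_response=chunk)
            return chunk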
@@ -363,7 +360,7 @@ def completion(
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
                 # let the stream handler know this is replicate
-                response = CustomStreamWrapper(output, "replicate")
+                response = CustomStreamWrapper(output, "replicate", logging_obj=logging)
                 return response
             response = ""
             for item in output:
@@ -413,7 +410,7 @@ def completion(
             )
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
-                response = CustomStreamWrapper(model_response, model)
+                response = CustomStreamWrapper(model_response, model, logging_obj=logging)
                 return response
             response = model_response
         elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
@@ -486,7 +483,7 @@ def completion(
             response = co.generate(model=model, prompt=prompt, **optional_params)
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
-                response = CustomStreamWrapper(response, model)
+                response = CustomStreamWrapper(response, model, logging_obj=logging)
                 return response
             ## LOGGING
             logging.post_call(
@@ -532,7 +529,7 @@ def completion(
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
-                    model_response, model, custom_llm_provider="huggingface"
+                    model_response, model, custom_llm_provider="huggingface", logging_obj=logging
                 )
                 return response
             response = model_response
@@ -572,7 +569,7 @@ def completion(
                     headers=headers,
                 )
                 response = CustomStreamWrapper(
-                    res.iter_lines(), model, custom_llm_provider="together_ai"
+                    res.iter_lines(), model, custom_llm_provider="together_ai", logging_obj=logging
                 )
                 return response
             else:
@@ -689,7 +686,7 @@ def completion(
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
-                    model_response, model, custom_llm_provider="ai21"
+                    model_response, model, custom_llm_provider="ai21", logging_obj=logging
                 )
                 return response
@@ -732,7 +729,7 @@ def completion(
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
-                    model_response, model, custom_llm_provider="baseten"
+                    model_response, model, custom_llm_provider="baseten", logging_obj=logging
                 )
                 return response
             response = model_response
@@ -775,8 +772,6 @@ def completion(
             )
             return response
     except Exception as e:
-        ## LOGGING
-        logging.post_call(input=messages, api_key=api_key, original_response=e)
         ## Map to OpenAI Exception
         raise exception_type(
             model=model, custom_llm_provider=custom_llm_provider, original_exception=e
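
Note: with post_call removed from the failure path, exceptions are presumably logged where the Logging object is created, so each provider branch no longer needs its own error logging. A hedged sketch of that decorator-side pattern, reusing the Logging stub from above; the real decorator is changed in another file in this commit:

    # Hypothetical outline; function and parameter names are assumptions.
    def client(fn):
        def wrapper(*args, **kwargs):
            logging_obj = Logging(
                model=kwargs.get("model"),
                messages=kwargs.get("messages"),
                litellm_call_id=kwargs.get("litellm_call_id"),
            )
            kwargs["litellm_logging_obj"] = logging_obj
            try:
                return fn(*args, **kwargs)
            except Exception as e:
                # one failure-logging site instead of one per provider branch
                logging_obj.post_call(input=kwargs.get("messages"), api_key=None, original_response=e)
                raise
        return wrapper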
@@ -816,21 +811,12 @@ def batch_completion(*args, **kwargs):
     60
 ) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def embedding(
-    model, input=[], azure=False, force_timeout=60, litellm_call_id=None, logger_fn=None
+    model, input=[], azure=False, force_timeout=60, litellm_call_id=None, litellm_logging_obj=None, logger_fn=None
 ):
     try:
         response = None
-        logging = Logging(
-            model=model,
-            messages=input,
-            optional_params={},
-            litellm_params={
-                "azure": azure,
-                "force_timeout": force_timeout,
-                "logger_fn": logger_fn,
-                "litellm_call_id": litellm_call_id,
-            },
-        )
+        logging = litellm_logging_obj
+        logging.update_environment_variables(optional_params={}, litellm_params={"force_timeout": force_timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn})
         if azure == True:
             # azure configs
             openai.api_type = "azure"
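
Note: embedding() gets the same treatment as completion(): the Logging object arrives via the new litellm_logging_obj parameter (again defaulting to None, so a direct call without the wrapper would fail on the update_environment_variables line) and only the call-time params are attached here. A usage sketch mirroring the one above, with the Logging stub standing in for the real class:

    input_texts = ["good morning from litellm"]
    logging_obj = Logging(model="text-embedding-ada-002", messages=input_texts, litellm_call_id="call-456")
    # response = embedding(model="text-embedding-ada-002", input=input_texts,
    #                      litellm_logging_obj=logging_obj)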
@@ -849,7 +835,6 @@ def embedding(
             )
             ## EMBEDDING CALL
             response = openai.Embedding.create(input=input, engine=model)
-            print_verbose(f"response_value: {str(response)[:100]}")
         elif model in litellm.open_ai_embedding_models:
             openai.api_type = "openai"
             openai.api_base = "https://api.openai.com/v1"
@@ -867,15 +852,13 @@ def embedding(
             )
             ## EMBEDDING CALL
             response = openai.Embedding.create(input=input, model=model)
-            print_verbose(f"response_value: {str(response)[:100]}")
         else:
             args = locals()
             raise ValueError(f"No valid embedding model args passed in - {args}")
         ## LOGGING
         logging.post_call(input=input, api_key=openai.api_key, original_response=response)
         return response
     except Exception as e:
-        ## LOGGING
-        logging.post_call(input=input, api_key=openai.api_key, original_response=e)
         ## Map to OpenAI Exception
         raise exception_type(
             model=model,