fix: fix linting errors

This commit is contained in:
Krrish Dholakia 2024-07-11 13:36:55 -07:00
parent 6e9f048618
commit 389a51e05d
4 changed files with 97 additions and 134 deletions

View file

@ -1934,51 +1934,7 @@ def completion(
"""
Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility
"""
custom_llm_provider = "together_ai"
together_ai_key = (
api_key
or litellm.togetherai_api_key
or get_secret("TOGETHER_AI_TOKEN")
or get_secret("TOGETHERAI_API_KEY")
or litellm.api_key
)
api_base = (
api_base
or litellm.api_base
or get_secret("TOGETHERAI_API_BASE")
or "https://api.together.xyz/inference"
)
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = together_ai.completion(
model=model,
messages=messages,
api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=together_ai_key,
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
)
if (
"stream_tokens" in optional_params
and optional_params["stream_tokens"] == True
):
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="together_ai",
logging_obj=logging,
)
return response
response = model_response
pass
elif custom_llm_provider == "palm":
palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key
@ -2461,10 +2417,10 @@ def completion(
## LOGGING
generator = ollama.get_ollama_response(
api_base,
model,
prompt,
optional_params,
api_base=api_base,
model=model,
prompt=prompt,
optional_params=optional_params,
logging_obj=logging,
acompletion=acompletion,
model_response=model_response,
@ -2490,11 +2446,11 @@ def completion(
)
## LOGGING
generator = ollama_chat.get_ollama_response(
api_base,
api_key,
model,
messages,
optional_params,
api_base=api_base,
api_key=api_key,
model=model,
messages=messages,
optional_params=optional_params,
logging_obj=logging,
acompletion=acompletion,
model_response=model_response,
@ -3465,7 +3421,7 @@ def embedding(
or api_base
or get_secret("OLLAMA_API_BASE")
or "http://localhost:11434"
)
) # type: ignore
if isinstance(input, str):
input = [input]
if not all(isinstance(item, str) for item in input):
@ -3475,9 +3431,11 @@ def embedding(
llm_provider="ollama", # type: ignore
)
ollama_embeddings_fn = (
ollama.ollama_aembeddings if aembedding else ollama.ollama_embeddings
ollama.ollama_aembeddings
if aembedding is True
else ollama.ollama_embeddings
)
response = ollama_embeddings_fn(
response = ollama_embeddings_fn( # type: ignore
api_base=api_base,
model=model,
prompts=input,