diff --git a/litellm/llms/oobabooga.py b/litellm/llms/oobabooga.py
index 3d0a2fc92..2a6e9c9ac 100644
--- a/litellm/llms/oobabooga.py
+++ b/litellm/llms/oobabooga.py
@@ -7,6 +7,7 @@ from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
 
+
 class OobaboogaError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -15,6 +16,7 @@ class OobaboogaError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs
 
+
 def validate_environment(api_key):
     headers = {
         "accept": "application/json",
@@ -24,6 +26,7 @@ def validate_environment(api_key):
         headers["Authorization"] = f"Token {api_key}"
     return headers
 
+
 def completion(
     model: str,
     messages: list,
@@ -44,11 +47,13 @@ def completion(
         completion_url = model
     elif api_base:
         completion_url = api_base
-    else:
-        raise OobaboogaError(status_code=404, message="API Base not set. Set one via completion(..,api_base='your-api-url')")
+    else:
+        raise OobaboogaError(
+            status_code=404,
+            message="API Base not set. Set one via completion(..,api_base='your-api-url')",
+        )
     model = model
 
-    completion_url = completion_url + "/v1/chat/completions"
 
     data = {
         "messages": messages,
@@ -57,40 +62,51 @@ def completion(
 
     ## LOGGING
     logging_obj.pre_call(
-            input=messages,
-            api_key=api_key,
-            additional_args={"complete_input_dict": data},
-        )
+        input=messages,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
     ## COMPLETION CALL
     response = requests.post(
-        completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
+        completion_url,
+        headers=headers,
+        data=json.dumps(data),
+        stream=optional_params["stream"] if "stream" in optional_params else False,
     )
     if "stream" in optional_params and optional_params["stream"] == True:
         return response.iter_lines()
     else:
         ## LOGGING
         logging_obj.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=response.text,
-                additional_args={"complete_input_dict": data},
-            )
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
         print_verbose(f"raw model_response: {response.text}")
         ## RESPONSE OBJECT
         try:
             completion_response = response.json()
         except:
-            raise OobaboogaError(message=response.text, status_code=response.status_code)
+            raise OobaboogaError(
+                message=response.text, status_code=response.status_code
+            )
         if "error" in completion_response:
-            raise OobaboogaError(message=completion_response["error"],status_code=response.status_code,)
+            raise OobaboogaError(
+                message=completion_response["error"],
+                status_code=response.status_code,
+            )
         else:
             try:
-                model_response["choices"][0]["message"]["content"] = completion_response["choices"][0]["message"]["content"]
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = completion_response["choices"][0]["message"]["content"]
             except:
-                raise OobaboogaError(message=json.dumps(completion_response), status_code=response.status_code)
-
-
+                raise OobaboogaError(
+                    message=json.dumps(completion_response),
+                    status_code=response.status_code,
+                )
 
     model_response["created"] = int(time.time())
     model_response["model"] = model
 
@@ -103,25 +119,26 @@ def completion(
     return model_response
 
 
-
 def embedding(
-        model: str,
-        input: list,
-        api_key: Optional[str] = None,
-        api_base: str = None,
-        logging_obj=None,
-        model_response=None,
-        optional_params=None,
-        encoding=None,
-    ):
-
+    model: str,
+    input: list,
+    api_key: Optional[str] = None,
+    api_base: Optional[str] = None,
+    logging_obj=None,
+    model_response=None,
+    optional_params=None,
+    encoding=None,
+):
     # Create completion URL
     if "https" in model:
         embeddings_url = model
     elif api_base:
         embeddings_url = f"{api_base}/v1/embeddings"
     else:
-        raise OobaboogaError(status_code=404, message="API Base not set. Set one via completion(..,api_base='your-api-url')")
+        raise OobaboogaError(
+            status_code=404,
+            message="API Base not set. Set one via completion(..,api_base='your-api-url')",
+        )
 
     # Prepare request data
     data = {"input": input}
@@ -130,7 +147,9 @@ def embedding(
 
     # Logging before API call
     if logging_obj:
-        logging_obj.pre_call(input=input, api_key=api_key, additional_args={"complete_input_dict": data})
+        logging_obj.pre_call(
+            input=input, api_key=api_key, additional_args={"complete_input_dict": data}
+        )
 
     # Send POST request
     headers = validate_environment(api_key)
@@ -141,16 +160,24 @@ def embedding(
 
     # Check for errors in response
     if "error" in completion_response:
-        raise OobaboogaError(message=completion_response["error"], status_code=completion_response.get('status_code', 500))
+        raise OobaboogaError(
+            message=completion_response["error"],
+            status_code=completion_response.get("status_code", 500),
+        )
 
     # Process response data
-    model_response["data"]=[{"embedding": completion_response["data"][0]["embedding"], "index": 0, "object": "embedding"}]
+    model_response["data"] = [
+        {
+            "embedding": completion_response["data"][0]["embedding"],
+            "index": 0,
+            "object": "embedding",
+        }
+    ]
     num_tokens = len(completion_response["data"][0]["embedding"])
 
-    #Adding metadata to response
-    model_response.usage = Usage(prompt_tokens=num_tokens,total_tokens=num_tokens)
-    model_response["object"]="list"
-    model_response["model"]=model
-
+    # Adding metadata to response
+    model_response.usage = Usage(prompt_tokens=num_tokens, total_tokens=num_tokens)
+    model_response["object"] = "list"
+    model_response["model"] = model
     return model_response
 
diff --git a/litellm/main.py b/litellm/main.py
index 2cf72ac33..4421f4c0e 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2322,7 +2322,7 @@ def embedding(
             api_base=api_base,
             logging_obj=logging,
             optional_params=optional_params,
-            model_response= EmbeddingResponse()
+            model_response=EmbeddingResponse(),
         )
     elif custom_llm_provider == "ollama":
         if aembedding == True: