diff --git a/litellm/llms/oobabooga.py b/litellm/llms/oobabooga.py
index 47d88cc79..3d0a2fc92 100644
--- a/litellm/llms/oobabooga.py
+++ b/litellm/llms/oobabooga.py
@@ -7,7 +7,6 @@ from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
 
-
 class OobaboogaError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -16,7 +15,6 @@ class OobaboogaError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs
 
-
 def validate_environment(api_key):
     headers = {
         "accept": "application/json",
@@ -26,7 +24,6 @@ def validate_environment(api_key):
         headers["Authorization"] = f"Token {api_key}"
     return headers
 
-
 def completion(
     model: str,
     messages: list,
@@ -47,93 +44,113 @@ def completion(
         completion_url = model
     elif api_base:
         completion_url = api_base
-    else:
-        raise OobaboogaError(
-            status_code=404,
-            message="API Base not set. Set one via completion(..,api_base='your-api-url')",
-        )
+    else:
+        raise OobaboogaError(status_code=404, message="API Base not set. Set one via completion(..,api_base='your-api-url')")
     model = model
-    if model in custom_prompt_dict:
-        # check if the model has a registered custom prompt
-        model_prompt_details = custom_prompt_dict[model]
-        prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages,
-        )
-    else:
-        prompt = prompt_factory(model=model, messages=messages)
-    completion_url = completion_url + "/api/v1/generate"
+
+    completion_url = completion_url + "/v1/chat/completions"
     data = {
-        "prompt": prompt,
+        "messages": messages,
         **optional_params,
     }
     ## LOGGING
+
     logging_obj.pre_call(
-        input=prompt,
-        api_key=api_key,
-        additional_args={"complete_input_dict": data},
-    )
+        input=messages,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
     ## COMPLETION CALL
+
     response = requests.post(
-        completion_url,
-        headers=headers,
-        data=json.dumps(data),
-        stream=optional_params["stream"] if "stream" in optional_params else False,
+        completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
     )
     if "stream" in optional_params and optional_params["stream"] == True:
         return response.iter_lines()
     else:
         ## LOGGING
         logging_obj.post_call(
-            input=prompt,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
         print_verbose(f"raw model_response: {response.text}")
         ## RESPONSE OBJECT
         try:
             completion_response = response.json()
         except:
-            raise OobaboogaError(
-                message=response.text, status_code=response.status_code
-            )
+            raise OobaboogaError(message=response.text, status_code=response.status_code)
         if "error" in completion_response:
-            raise OobaboogaError(
-                message=completion_response["error"],
-                status_code=response.status_code,
-            )
+            raise OobaboogaError(message=completion_response["error"], status_code=response.status_code)
         else:
             try:
-                model_response["choices"][0]["message"][
-                    "content"
-                ] = completion_response["results"][0]["text"]
+                model_response["choices"][0]["message"]["content"] = completion_response["choices"][0]["message"]["content"]
             except:
-                raise OobaboogaError(
-                    message=json.dumps(completion_response),
-                    status_code=response.status_code,
-                )
+                raise OobaboogaError(message=json.dumps(completion_response), status_code=response.status_code)
+
-        ## CALCULATING USAGE
-        prompt_tokens = len(encoding.encode(prompt))
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"]["content"])
-        )
         model_response["created"] = int(time.time())
         model_response["model"] = model
         usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens=completion_response["usage"]["prompt_tokens"],
+            completion_tokens=completion_response["usage"]["completion_tokens"],
+            total_tokens=completion_response["usage"]["total_tokens"],
         )
         model_response.usage = usage
         return model_response
 
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
+
+def embedding(
+    model: str,
+    input: list,
+    api_key: Optional[str] = None,
+    api_base: Optional[str] = None,
+    logging_obj=None,
+    model_response=None,
+    optional_params=None,
+    encoding=None,
+):
+    # Construct the embeddings URL
+    if "https" in model:
+        embeddings_url = model
+    elif api_base:
+        embeddings_url = f"{api_base}/v1/embeddings"
+    else:
+        raise OobaboogaError(status_code=404, message="API Base not set. Set one via embedding(..,api_base='your-api-url')")
+
+    # Prepare request data
+    data = {"input": input}
+    if optional_params:
+        data.update(optional_params)
+
+    # Logging before API call
+    if logging_obj:
+        logging_obj.pre_call(input=input, api_key=api_key, additional_args={"complete_input_dict": data})
+
+    # Send POST request
+    headers = validate_environment(api_key)
+    response = requests.post(embeddings_url, headers=headers, json=data)
+    if not response.ok:
+        raise OobaboogaError(message=response.text, status_code=response.status_code)
+    completion_response = response.json()
+
+    # Check for errors in response
+    if "error" in completion_response:
+        raise OobaboogaError(message=completion_response["error"], status_code=completion_response.get("status_code", 500))
+
+    # Process response data
+    model_response["data"] = [{"embedding": completion_response["data"][0]["embedding"], "index": 0, "object": "embedding"}]
+
+    num_tokens = len(completion_response["data"][0]["embedding"])
+    # Add usage metadata to the response
+    model_response.usage = Usage(prompt_tokens=num_tokens, total_tokens=num_tokens)
+    model_response["object"] = "list"
+    model_response["model"] = model
+
+    return model_response
diff --git a/litellm/main.py b/litellm/main.py
index 455549e55..2cf72ac33 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2314,6 +2314,16 @@ def embedding(
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
         )
+    elif custom_llm_provider == "oobabooga":
+        response = oobabooga.embedding(
+            model=model,
+            input=input,
+            encoding=encoding,
+            api_base=api_base,
+            logging_obj=logging,
+            optional_params=optional_params,
+            model_response=EmbeddingResponse(),
+        )
     elif custom_llm_provider == "ollama":
         if aembedding == True:
             response = ollama.ollama_aembeddings(
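For reference, a minimal usage sketch of the new embedding path (not part of the diff). The model name and `api_base` below are placeholder assumptions for a local text-generation-webui server; the `oobabooga/` prefix is what routes the call into the new `custom_llm_provider == "oobabooga"` branch in `litellm/main.py`, and `oobabooga.embedding()` appends `/v1/embeddings` to the base URL.

```python
import litellm

# Placeholder local deployment; adjust the model name and api_base to your server.
response = litellm.embedding(
    model="oobabooga/bge-large-en",       # hypothetical model; "oobabooga/" selects the provider
    input=["good morning from litellm"],  # embedding() expects a list input
    api_base="http://localhost:5000",     # "/v1/embeddings" is appended by oobabooga.embedding()
)
print(response)
```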