(feat) ollama use /api/chat

ishaan-jaff 2023-12-25 14:29:10 +05:30
parent da4ec6c8b6
commit d85c19394f


@@ -132,10 +132,10 @@ def get_ollama_response(
     model_response=None,
     encoding=None,
 ):
-    if api_base.endswith("/api/generate"):
+    if api_base.endswith("/api/chat"):
         url = api_base
     else:
-        url = f"{api_base}/api/generate"
+        url = f"{api_base}/api/chat"

     ## Load Config
     config = litellm.OllamaConfig.get_config()
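
For context, the /api/chat endpoint these URLs now point at takes an OpenAI-style messages list instead of the flat prompt string used by /api/generate, and a non-streaming response returns the assistant turn under a "message" key. A minimal sketch of that request shape, assuming a local Ollama server with the default llama2 model (illustrative only, not part of this diff):

import asyncio

import aiohttp


async def chat_example():
    # /api/chat payload: "messages" replaces the "prompt" field used by /api/generate
    data = {
        "model": "llama2",
        "messages": [{"role": "user", "content": "Why is the sky blue?"}],
        "stream": False,
    }
    async with aiohttp.ClientSession() as session:
        response = await session.post("http://localhost:11434/api/chat", json=data)
        response_json = await response.json()
        # non-streaming /api/chat responses nest the completion under "message"
        return response_json["message"]["content"]


print(asyncio.run(chat_example()))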
@@ -329,75 +329,3 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
     except Exception as e:
         traceback.print_exc()
         raise e
-
-
-async def ollama_aembeddings(
-    api_base="http://localhost:11434",
-    model="llama2",
-    prompt="Why is the sky blue?",
-    optional_params=None,
-    logging_obj=None,
-    model_response=None,
-    encoding=None,
-):
-    if api_base.endswith("/api/embeddings"):
-        url = api_base
-    else:
-        url = f"{api_base}/api/embeddings"
-
-    ## Load Config
-    config = litellm.OllamaConfig.get_config()
-    for k, v in config.items():
-        if (
-            k not in optional_params
-        ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
-            optional_params[k] = v
-
-    data = {
-        "model": model,
-        "prompt": prompt,
-    }
-    ## LOGGING
-    logging_obj.pre_call(
-        input=None,
-        api_key=None,
-        additional_args={"api_base": url, "complete_input_dict": data, "headers": {}},
-    )
-    timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
-    async with aiohttp.ClientSession(timeout=timeout) as session:
-        response = await session.post(url, json=data)
-
-        if response.status != 200:
-            text = await response.text()
-            raise OllamaError(status_code=response.status, message=text)
-
-        ## LOGGING
-        logging_obj.post_call(
-            input=prompt,
-            api_key="",
-            original_response=response.text,
-            additional_args={
-                "headers": None,
-                "api_base": api_base,
-            },
-        )
-
-        response_json = await response.json()
-        embeddings = response_json["embedding"]
-        ## RESPONSE OBJECT
-        output_data = []
-        for idx, embedding in enumerate(embeddings):
-            output_data.append(
-                {"object": "embedding", "index": idx, "embedding": embedding}
-            )
-        model_response["object"] = "list"
-        model_response["data"] = output_data
-        model_response["model"] = model
-
-        input_tokens = len(encoding.encode(prompt))
-
-        model_response["usage"] = {
-            "prompt_tokens": input_tokens,
-            "total_tokens": input_tokens,
-        }
-        return model_response
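
For reference, the handler removed above wraps Ollama's /api/embeddings endpoint, which keeps the flat "prompt" field and returns a single "embedding" vector in its JSON body. A minimal standalone sketch of that call, again assuming a local Ollama server with llama2 (illustrative only):

import asyncio

import aiohttp


async def embeddings_example():
    # /api/embeddings still takes "model" + "prompt" and returns one vector
    data = {"model": "llama2", "prompt": "Why is the sky blue?"}
    async with aiohttp.ClientSession() as session:
        response = await session.post("http://localhost:11434/api/embeddings", json=data)
        response_json = await response.json()
        return response_json["embedding"]  # list of floats


vector = asyncio.run(embeddings_example())
print(len(vector))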