xat/xatapi.py

271 líneas
8.1 KiB
Python

import os
import pdb
import sys

import requests
from requests.models import urlencode
###
# Dict helper class.
# Defined at top level so it can be pickled.
###
class AttribAccessDict(dict):
    """A ``dict`` whose keys can also be read as attributes.

    ``d.key`` returns ``d["key"]``.  Assigning an attribute whose name is
    already a key is refused, so the attribute view of the keys stays
    read-only; any other name becomes an ordinary instance attribute.
    """

    def __getattr__(self, attr):
        # Python only calls __getattr__ after normal lookup fails, so real
        # dict methods (keys, items, ...) are never shadowed by keys.
        try:
            return self[attr]
        except KeyError:
            raise AttributeError("Attribute not found: " + str(attr)) from None

    def __setattr__(self, attr, val):
        if attr not in self:
            super().__setattr__(attr, val)
            return
        raise AttributeError("Attribute-style access is read only")
class Xat:
    """Small client for the OpenAI HTTP API used by xatbot.

    The API base URL and access token are taken from the constructor
    arguments when given.  Any that are missing are read from
    ``config/xatbot.txt``; if that file does not exist, an interactive setup
    prompts for both values and writes the file.
    """

    name = 'OpenAI API Python wrapper'
    # Model used by completions() when the caller does not pass one.
    DEFAULT_COMPLETION_MODEL = "text-davinci-003"

    def __init__(self, api_base_url=None, access_token=None, session=None, headers=None):
        """Create a client.

        Args:
            api_base_url: Base URL, e.g. ``https://api.openai.com``.  Read from
                the config file (or prompted for) when ``None``.
            access_token: Bearer token.  Read from the config file (or
                prompted for) when ``None``.
            session: Object with a ``requests.Session``-compatible ``request``
                method.  A new ``requests.Session`` is created when ``None``.
            headers: Optional extra HTTP headers sent with every request.
                They are merged over the ``Authorization`` header.
        """
        self.__xatbot_config_path = "config/xatbot.txt"
        if api_base_url is None or access_token is None:
            if self.__check_setup():
                if api_base_url is None:
                    api_base_url = self.__get_parameter("api_base_url", self.__xatbot_config_path)
                if access_token is None:
                    access_token = self.__get_parameter("access_token", self.__xatbot_config_path)
            else:
                setup_url, setup_token = self.__setup()
                if api_base_url is None:
                    api_base_url = setup_url
                if access_token is None:
                    access_token = setup_token
        # Endpoints are appended as "/v1/...", so drop a trailing slash to
        # avoid producing "//v1/..." URLs.
        self.api_base_url = api_base_url.rstrip("/")
        self.access_token = access_token
        self.headers = {"Authorization": "Bearer " + access_token} if access_token else {}
        if headers:
            self.headers.update(headers)
        self.session = session if session is not None else requests.Session()

    def completions(self, model=None, prompt=None, max_tokens=750, temperature=0):
        """POST ``/v1/completions`` and return the parsed JSON response.

        Args:
            model: Model name; ``DEFAULT_COMPLETION_MODEL`` when ``None``.
            prompt: Prompt text.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
        """
        data = {
            'model': model if model is not None else self.DEFAULT_COMPLETION_MODEL,
            'prompt': prompt,
            'max_tokens': max_tokens,
            'temperature': temperature,
        }
        return self.__request_json('POST', '/v1/completions', data)

    def models(self):
        """GET ``/v1/models`` and return the parsed JSON response."""
        return self.__request_json('GET', '/v1/models')

    def get_model(self, model=None):
        """GET ``/v1/models/<model>`` and return the parsed JSON response."""
        return self.__request_json('GET', f'/v1/models/{model}')

    def create_image(self, model=None, prompt=None):
        """POST ``/v1/images/generations`` for one 512x512 base64-encoded image.

        ``model`` is sent only when given, so the server-side default applies
        otherwise.
        """
        data = {
            'prompt': prompt,
            'n': 1,
            'size': "512x512",
            'response_format': "b64_json",
        }
        if model is not None:
            data['model'] = model
        return self.__request_json('POST', '/v1/images/generations', data)

    def __request_json(self, method, path, data=None):
        """Send a request to ``api_base_url + path`` and wrap the JSON reply.

        ``data`` is sent as the JSON body.  ``None`` sends no body.
        """
        response = self.__api_request(method, self.api_base_url + path, data)
        return self.__json_allow_dict_attrs(response.json())

    def __check_setup(self):
        """Return True if the config file exists, otherwise announce setup."""
        if os.path.isfile(self.__xatbot_config_path):
            return True
        print(f"File {self.__xatbot_config_path} not found, running setup.")
        return False

    def __setup(self):
        """Prompt for the API URL and token and write them to the config file.

        Returns:
            ``(api_base_url, access_token)`` as entered by the user.
        """
        config_path = self.__xatbot_config_path
        config_dir = os.path.dirname(config_path)
        if config_dir:
            os.makedirs(config_dir, exist_ok=True)
        api_base_url = input("Openai API url, in ex. 'https://api.openai.com': ").strip()
        access_token = input("Openai access token: ").strip()
        is_new_file = not os.path.exists(config_path)
        # NOTE(review): the token is stored in plain text with the process's
        # default file permissions; consider restricting them.
        with open(config_path, 'w', encoding='utf-8') as config_file:
            if is_new_file:
                print(f"{config_path} created!")
            print("Writing xatbot parameters to " + config_path)
            config_file.write(f'api_base_url: {api_base_url}\n' + f'access_token: {access_token}\n')
        return api_base_url, access_token

    @staticmethod
    def __get_parameter(parameter, file_path):
        """Return the value of ``parameter`` from a ``key: value`` config file.

        Only the first ``:`` on a line separates key from value, so values
        such as URLs may contain colons.  If the parameter is absent, a
        message is printed and the process exits with status 1.
        """
        with open(file_path, encoding='utf-8') as config_file:
            for line in config_file:
                key, sep, value = line.partition(':')
                if sep and key.strip() == parameter:
                    return value.strip()
        print(f'{file_path} Missing parameter {parameter}')
        sys.exit(1)

    def __api_request(self, method, endpoint, data):
        """Send one HTTP request and map HTTP failures to Openai* exceptions.

        Args:
            method: HTTP method, e.g. ``'GET'`` or ``'POST'``.
            endpoint: Full URL.
            data: JSON body, or ``None`` for no body.

        Returns:
            The response object.  Besides successful responses, responses
            with status 400, 422 or 503 are also returned (not raised) so the
            caller can read their body.

        Raises:
            OpenaiNetworkError: the request could not be sent or completed.
            OpenaiIllegalArgumentError: the session returned no response.
            OpenaiNotFoundError, OpenaiUnauthorizedError, OpenaiServerError
            (and its 500/502/504 subclasses), OpenaiAPIError: for other
            error statuses.
        """
        try:
            response = self.session.request(method, url=endpoint, headers=self.headers, json=data)
        except requests.exceptions.RequestException as e:
            raise OpenaiNetworkError(f"Could not complete request: {e}") from e
        if response is None:
            raise OpenaiIllegalArgumentError("Illegal request.")
        if response.ok:
            return response

        status = response.status_code
        # These statuses are deliberately handed back to the caller instead
        # of being raised.
        if status in (400, 422, 503):
            return response

        error_msg = self.__error_message(response)
        if status == 404:
            ex_type = OpenaiNotFoundError
            if not error_msg:
                # Kept for compatibility with older versions, which raised
                # OpenaiAPIError('Endpoint not found.') on any 404.
                error_msg = 'Endpoint not found.'
        elif status == 401:
            ex_type = OpenaiUnauthorizedError
        elif status == 500:
            ex_type = OpenaiInternalServerError
        elif status == 502:
            ex_type = OpenaiBadGatewayError
        elif status == 504:
            ex_type = OpenaiGatewayTimeoutError
        elif 500 <= status <= 511:
            ex_type = OpenaiServerError
        else:
            ex_type = OpenaiAPIError
        raise ex_type('Openai API returned error', status, response.reason, error_msg)

    @staticmethod
    def __error_message(response):
        """Extract an error description from a failed response, or None.

        If the body is a JSON object with an ``error`` entry, that entry is
        returned.  When the entry is itself an object with a ``message``
        field, only the message is returned.  If the body is a JSON string,
        the string is returned.  A non-JSON body yields None.
        """
        try:
            body = response.json()
        except ValueError:
            return None
        if isinstance(body, dict) and 'error' in body:
            error = body['error']
            if isinstance(error, dict) and 'message' in error:
                return error['message']
            return error
        if isinstance(body, str):
            return body
        return None

    @staticmethod
    def __json_allow_dict_attrs(json_object):
        """
        Makes it possible to use attribute notation to access a dict's
        elements (at any nesting depth, including dicts inside lists), while
        still allowing each dict to act as a dict.
        """
        if isinstance(json_object, dict):
            return AttribAccessDict(
                {key: Xat.__json_allow_dict_attrs(value) for key, value in json_object.items()}
            )
        if isinstance(json_object, list):
            return [Xat.__json_allow_dict_attrs(item) for item in json_object]
        return json_object
##
# Exceptions
##
class OpenaiError(Exception):
    """Base class for all exceptions raised by this OpenAI API wrapper."""


class OpenaiIOError(IOError, OpenaiError):
    """Base class for I/O errors raised by this wrapper."""


class OpenaiNetworkError(OpenaiIOError):
    """Raised when network communication with the server fails."""


class OpenaiIllegalArgumentError(ValueError, OpenaiError):
    """Raised when a request is illegal, e.g. no response could be obtained for it."""


class OpenaiAPIError(OpenaiError):
    """Raised when the OpenAI API generates a response that cannot be handled."""


class OpenaiServerError(OpenaiAPIError):
    """Raised when the server is misconfigured and returns a 5xx error code."""


class OpenaiInternalServerError(OpenaiServerError):
    """Raised when the server returns a 500 error."""


class OpenaiBadGatewayError(OpenaiServerError):
    """Raised when the server returns a 502 error."""


class OpenaiServiceUnavailableError(OpenaiServerError):
    """Raised when the server returns a 503 error."""


class OpenaiGatewayTimeoutError(OpenaiServerError):
    """Raised when the server returns a 504 error."""


class OpenaiNotFoundError(OpenaiAPIError):
    """Raised when the OpenAI API returns a 404 Not Found error."""


class OpenaiUnauthorizedError(OpenaiAPIError):
    """Raised when the OpenAI API returns a 401 Unauthorized error.

    This happens when the access token is invalid or has been revoked, or
    when an endpoint that requires authentication is called without
    providing credentials.
    """