"""
Support code for linetap.

This, in particular, contains bindings to the InChI library.
Use getILib() to obtain the bindings, and see the _InChIBinding
docstrings for what you can call.  It probably helps to be aware of
https://www.inchi-trust.org/wp/download/104/InChI_API_Reference.pdf
although the binding is substantially different.
"""

import ctypes
import functools
import warnings
from ctypes.util import find_library

from gavo import base
from gavo import utils


class InChIError(utils.Error):
	"""raised when a call into libinchi reports a failure.

	Messages are whatever libinchi (or this module) supplies and are often
	terse; where a numeric return code is included, the libinchi API
	reference explains its meaning.
	"""


# Sizes mirrored from libinchi's compile-time constants.  Ideally these
# would be read from /usr/include/inchi_api.h rather than hard-coded.
MAXVAL = 20  # maximal number of bonds per atom (length of per-atom bond arrays)
ATOM_EL_LEN = 6  # length of the per-atom element name field

class _inchi_Atom(ctypes.Structure):
	"""ctypes mirror of libinchi's per-atom structure.

	Instances are packed into a ctypes array and handed to the library,
	so the order and types in _fields_ define the memory layout and must
	not be changed casually.
	"""
	_fields_ = [
		("x", ctypes.c_double),
		("y", ctypes.c_double),
		("z", ctypes.c_double),
		("neighbor", ctypes.c_short*MAXVAL),
		("bond_type", ctypes.c_byte*MAXVAL),
		("bond_stereo", ctypes.c_char*MAXVAL),
		("elname", ctypes.c_char*ATOM_EL_LEN),
		("num_bonds", ctypes.c_short),
		("num_iso_H", ctypes.c_char*4),
		("isotopic_mass", ctypes.c_short),
		("radical", ctypes.c_char),
		("charge", ctypes.c_byte)]

	def connect(self, otherAtomIndex, bond_type):
		"""records a one-sided bond of bond_type to the atom at otherAtomIndex.

		Only this atom is changed; the partner atom is not updated.  Hence,
		user code will usually prefer _inchi_Input.connect, which knows the
		atom indexes and links both ends.

		Raises InChIError when all MAXVAL bond slots are already in use.
		"""
		slot = self.num_bonds
		if slot==MAXVAL:
			raise InChIError("Too many connections")

		# neighbor and bond_type are parallel arrays indexed by bond slot
		self.neighbor[slot] = otherAtomIndex
		self.bond_type[slot] = bond_type
		self.num_bonds = slot+1


class _inchi_Input(ctypes.Structure):
	"""ctypes mirror of libinchi's inchi_Input: a molecule to compute an
	InChI for.

	Obtain instances through _InChIBinding.getInput and then add bonds
	using connect.
	"""
	# In inchi_api.h, num_atoms and num_stereo0D are AT_NUM, the same type
	# as the num_bonds/neighbor members of inchi_Atom (mirrored as c_short
	# in _inchi_Atom).  Declaring them as c_int misplaces num_stereo0D and
	# only happens to work on little-endian machines without stereo data.
	_fields_ = [
		("atom", ctypes.POINTER(_inchi_Atom)),
		("stereo0D", ctypes.c_void_p),
		("szOptions", ctypes.c_char_p),
		("num_atoms", ctypes.c_short),
		("num_stereo0D", ctypes.c_short),]

	def connect(self, atomIndex1, atomIndex2, bond_type=1):
		"""adds a bond of bond_type between the atoms at atomIndex1 and
		atomIndex2.

		Both atoms record each other as neighbours.  Indexes must satisfy
		0 <= index < num_atoms, and both atoms must have a free bond slot;
		otherwise an InChIError is raised and neither atom is modified.
		"""
		for atomIndex in (atomIndex1, atomIndex2):
			if not 0<=atomIndex<self.num_atoms:
				raise InChIError("Atom indexes must be < {} here".format(
					self.num_atoms))

		# Check capacity on both ends before touching either atom so that
		# a failure cannot leave a bond recorded on one side only.
		for atomIndex in (atomIndex1, atomIndex2):
			if self.atom[atomIndex].num_bonds>=MAXVAL:
				raise InChIError("Too many connections")

		self.atom[atomIndex1].connect(atomIndex2, bond_type)
		self.atom[atomIndex2].connect(atomIndex1, bond_type)


class _inchi_Output(ctypes.Structure):
	"""ctypes mirror of libinchi's inchi_Output.

	All four members are C strings filled in by the library; a NULL
	pointer reads back as None.
	"""
	_fields_ = [(memberName, ctypes.c_char_p)
		for memberName in ("szInChI", "szAuxInfo", "szMessage", "szLog")]


class _inchi_InputINCHI(ctypes.Structure):
	"""ctypes mirror of libinchi's inchi_InputINCHI: an InChI string plus
	an optional options string, as consumed by GetINCHIfromINCHI.
	"""
	_fields_ = [("szInChI", ctypes.c_char_p), ("szOptions", ctypes.c_char_p)]


class _InChIBinding:
	"""A class encapsulating our binding to libinchi.

	This is intended to be a singleton; users ought to obtain it through
	the getILib function below rather than constructing it themselves.

	Construction locates libinchi via ctypes.util.find_library, loads it,
	and exposes the C functions used here as attributes named after the
	C function with a leading underscore (e.g., _GetStdINCHI), with their
	argtypes and restype declared.

	This will raise a ReportableError if the inchi C library cannot be found
	or does not provide one of the functions bound here.
	"""
	_libname = "inchi"

	def __init__(self):
		self.libPath = find_library(self._libname)
		if self.libPath is None:
			raise base.ReportableError("No libinchi found.",
				hint="On Debian systems, install libinchi1 to get it.")

		self.lib = ctypes.CDLL(self.libPath)

		inputPtr = ctypes.POINTER(_inchi_Input)
		outputPtr = ctypes.POINTER(_inchi_Output)
		# (C name, argtypes, restype).  Return types are spelled out rather
		# than left to ctypes' implicit int; FreeStdINCHI's result is never
		# used, so it is declared as returning nothing.
		for wrappedName, argtypes, restype in [
				("CheckINCHI", [ctypes.c_char_p, ctypes.c_int], ctypes.c_int),
				("GetStdINCHIKeyFromStdINCHI",
					[ctypes.c_char_p, ctypes.c_char_p], ctypes.c_int),
				("GetINCHI", [inputPtr, outputPtr], ctypes.c_int),
				("GetStdINCHI", [inputPtr, outputPtr], ctypes.c_int),
				("GetINCHIfromINCHI",
					[ctypes.POINTER(_inchi_InputINCHI), outputPtr], ctypes.c_int),
				("FreeStdINCHI", [outputPtr], None),
			]:
			try:
				cFunc = getattr(self.lib, wrappedName)
			except AttributeError as ex:
				raise base.ReportableError(
					f"{self.libPath} does not provide {wrappedName}.",
					hint="Your libinchi may be too old; try a newer version.") from ex
			cFunc.argtypes = argtypes
			cFunc.restype = restype
			setattr(self, "_"+wrappedName, cFunc)

	def checkInChI(self, inchi, strict=1):
		"""returns 0 if inchi is ok, something else (cf. API docs) otherwise.

		strict is passed on to libinchi's CheckINCHI unchanged.
		"""
		return self._CheckINCHI(utils.bytify(inchi), ctypes.c_int(strict))

	def getInChIKey(self, inChI):
		"""returns an InChIKey for a standard InChI.

		This raises an InChIError including libinchi's return code if no
		key could be generated.
		"""
		# presumably 27 key characters plus the terminating NUL
		buf = ctypes.create_string_buffer(28)
		retval = self._GetStdINCHIKeyFromStdINCHI(utils.bytify(inChI), buf)
		if retval!=0:
			raise InChIError(f"Could not generate InChIKey: {retval}")
		return utils.debytify(buf.value)

	def getInput(self, atoms):
		"""returns an _inchi_Input for a list of inchi_Atoms.

		This is what you need to call to generate an InChI(Key).  Essentially,
		you create atoms using getAtom, put them into a list, pass them
		to getInput and then call connect(index1, index2) to make the
		molecule connections.

		Note that the atoms are copied into a new ctypes array; later changes
		to the objects in the list passed in do not affect the result.
		"""
		atomsArray = (_inchi_Atom*len(atoms))(*atoms)
		return _inchi_Input(
			atom=atomsArray,
			stereo0D=None,
			szOptions=b"-DoNotAddH",
			num_atoms=len(atoms),
			num_stereo0D=0)

	def getAtom(self, elname, isotopic_mass=0, charge=0):
		"""returns a new, unconnected atom at the origin.

		Only pass isotopic_mass if you actually care about isotopism.
		Bonds are added later, usually through _inchi_Input.connect.
		"""
		return _inchi_Atom(
			x=0, y=0, z=0,
			# freshly created ctypes arrays are zero-filled
			neighbor=(ctypes.c_short*MAXVAL)(),
			bond_type=(ctypes.c_byte*MAXVAL)(), # (i.e., single, double... binding)
			bond_stereo=b"", # (let's see when we want it)
			elname=utils.bytify(elname),
			num_bonds=0,
			num_iso_H=b"\0\0\0\0", # we don't support H isotopes at this point, either
			isotopic_mass=isotopic_mass,
			radical=0,
			charge=ctypes.c_byte(charge))

	def _interpretICReturnCode(self, retval, output):
		"""returns the InChI from output after vetting libinchi's retval.

		A retval of 1 is turned into a Python warning, any other non-zero
		value into an InChIError.  In every case, output is passed to
		FreeStdINCHI before this returns or raises.
		"""
		try:
			if retval==1:
				# szMessage may be a NULL pointer (None here); don't let that
				# turn a mere warning into a TypeError.
				warnings.warn("InChI warning: "+utils.debytify(
					output.szMessage or "(no message)"))
			elif retval!=0:
				raise InChIError(utils.debytify(output.szMessage
						or "(Unclear problem)"),
					hint=utils.debytify(output.szLog))

			return utils.debytify(output.szInChI)
		finally:
			# NOTE(review): the API reference pairs GetINCHIfromINCHI output
			# with FreeINCHI rather than FreeStdINCHI -- confirm the two are
			# interchangeable in the libinchi versions we support.
			self._FreeStdINCHI(output)

	def getInChI(self, inchiInput):
		"""returns the InChI of the molecule described through inchiInput
		as a string.
		"""
		output = _inchi_Output()
		res = self._GetStdINCHI(inchiInput, output)
		return self._interpretICReturnCode(res, output)

	def normalizeInChI(self, inChI):
		"""returns a string inChI in normalised form.
		"""
		output = _inchi_Output()
		res = self._GetINCHIfromINCHI(
			_inchi_InputINCHI(utils.bytify(inChI)), output)
		return self._interpretICReturnCode(res, output)


@functools.lru_cache(maxsize=None)
def getILib():
	"""returns the process-wide _InChIBinding instance.

	The first successful call loads libinchi; subsequent calls hand back
	that very same object.  If the library is unavailable, a
	ReportableError propagates from the _InChIBinding constructor (and,
	as lru_cache only memoises return values, the next call tries again).
	"""
	binding = _InChIBinding()
	return binding
