import ast
import importlib
import inspect
import sys
import os
import re
from argparse import ArgumentParser
from typing import List
import traits.api

# Signature line that pybind11-style docstrings start with, e.g. "func(a, b) -> int".
function_signature_pattern = re.compile(r'^\w+\(.*\).*\n*')
# :py:role:`!Name` -- reference with Sphinx's "!" (no-link) marker.
bang_object_reference = re.compile(r':py:(\w+):`!([\w.]+)`')
# :py:role:`~ovito.pkg.Name` -- fully qualified reference shortened with "~".
tilde_object_reference = re.compile(r':py:(\w+):`\~ovito\.(?:\w+\.)+(\w+)`')
# :ref:`text <target>`. NOTE(review): the character class also admits a literal
# backslash ("\\"); possibly a typo for "\(\)" -- left unchanged.
named_reference = re.compile(r':ref:`([\w\s\.\-\(\\)]+)\s<[\w\.\-_:]+>`')
# :py:role:`text <target>`.
named_object_reference = re.compile(r':py:\w+:`([\w\s\.\(\)]+)\s<[\w\.\-]+>`')
# ".. literalinclude:: path" followed by a blank line or the end of the docstring.
literal_include_directive_no_lines = re.compile(r'\.\.\ literalinclude::\ ([\./\w]+)(\s*\n\s*\n|$)')
# ".. literalinclude:: path" with a ":lines: a-b" option.
literal_include_directive_with_lines = re.compile(r'\.\.\ literalinclude::\ ([\./\w]+)\s*\n\s+:lines:\s*(\d+)\-(\d+)(\s*|$)')
# ".. literalinclude:: path" with an open-ended ":lines: a-" option.
literal_include_directive_with_lines_1 = re.compile(r'\.\.\ literalinclude::\ ([\./\w]+)\s*\n\s+:lines:\s*(\d+)\-(\s*|$)')
bold_formatting = re.compile(r'\*\*([\w\s\-:]+)\*\*')
# ".. _anchor-name:" label lines. The leading dots are escaped so that ordinary text
# such as "the attribute _foo:" is not mistaken for an anchor.
rst_anchors = re.compile(r'\.\. _[\w\.\-]+:\s*\n')
version_added_directive = re.compile(r'\.\.\ versionadded::.+\n*')
default_value_field = re.compile(r':Default:\s+')
autodoc_skip = re.compile(r'AUTODOC_SKIP_MEMBER')
note_directive = re.compile(r'\.\.\ note::(\s+\n)*')


def _replace_all(pattern, doc: str, replacement) -> str:
    """Repeatedly replace the first match of *pattern* in *doc* until none is left.

    *replacement* receives the match object and returns the text to splice in.
    Searching restarts at the beginning of *doc* after every substitution.
    """
    m = pattern.search(doc)
    while m:
        doc = doc[:m.start()] + replacement(m) + doc[m.end():]
        m = pattern.search(doc)
    return doc


def _read_include_lines(base_include_dir: str, relative_path: str) -> List[str]:
    """Return the lines (with line endings) of a file referenced by a literalinclude directive."""
    with open(os.path.join(base_include_dir, relative_path), 'r', encoding='utf-8') as fin:
        return fin.readlines()


def _format_code_block(lines: List[str]) -> str:
    """Render source lines as a fenced ```python block, each line indented by two spaces."""
    # Make sure the closing fence starts on its own line even if the last included
    # line has no newline (e.g. the end of a file without a trailing newline).
    if lines and not lines[-1].endswith('\n'):
        lines = lines[:-1] + [lines[-1] + '\n']
    return '```python\n  ' + '  '.join(lines) + '```\n\n'


def cleanup_docstring(doc: str, func_name: str, base_include_dir: str) -> str:
    """Convert a Sphinx/reST runtime docstring into plainer text for a type stub.

    Args:
        doc: Docstring taken from the runtime object.
        func_name: Name of the documented function or class. If *doc* starts with
            ``func_name(``, the single-line signature at its top is removed. Pass
            an empty string to skip this step.
        base_include_dir: Directory that ``.. literalinclude::`` paths are relative to.

    Returns:
        The cleaned docstring with surrounding whitespace stripped, or ``''`` if the
        docstring contains the ``AUTODOC_SKIP_MEMBER`` marker.

    Raises:
        OSError: If a file referenced by a literalinclude directive cannot be read.
    """
    # Discard docstrings containing the marker 'AUTODOC_SKIP_MEMBER':
    if autodoc_skip.search(doc):
        return ''

    # Remove the function signature in the first line of the docstring. Only a
    # signature that fits on one line is matched; a multi-line one is kept as-is
    # instead of crashing on a failed match.
    if func_name and doc.startswith(func_name + '('):
        m = function_signature_pattern.search(doc)
        if m:
            doc = doc[:m.start()] + doc[m.end():]

    # Drop rst anchors, versionadded directives and note directive headers.
    for pattern in (rst_anchors, version_added_directive, note_directive):
        doc = _replace_all(pattern, doc, lambda m: '')

    # Remove first colon from :Default: field.
    doc = _replace_all(default_value_field, doc, lambda m: 'Default: ')

    # Remove rst bold formatting, e.g. **Caption** -> Caption
    doc = _replace_all(bold_formatting, doc, lambda m: m[1])

    # Remove ! from references, e.g. :py:class:`!DataObject` -> :py:class:`DataObject`
    doc = _replace_all(bang_object_reference, doc, lambda m: ':py:{}:`{}`'.format(m[1], m[2]))

    # Shorten references, e.g. :py:class:`~ovito.data.ElementType` -> :py:class:`ElementType`
    doc = _replace_all(tilde_object_reference, doc, lambda m: ':py:{}:`{}`'.format(m[1], m[2]))

    # :ref:`particle property <particle-properties-list>` -> particle property
    doc = _replace_all(named_reference, doc, lambda m: m[1])

    # :py:attr:`bond_types <ovito.data.Bonds.bond_types>` -> `bond_types`
    doc = _replace_all(named_object_reference, doc, lambda m: '`{}`'.format(m[1]))

    # Inline literalinclude directives as code blocks. The ":lines: a-b" form must be
    # processed before the open-ended ":lines: a-" form, which would otherwise also
    # match its prefix.
    doc = _replace_all(
        literal_include_directive_no_lines, doc,
        lambda m: _format_code_block(_read_include_lines(base_include_dir, m[1])))
    doc = _replace_all(
        literal_include_directive_with_lines, doc,
        lambda m: _format_code_block(_read_include_lines(base_include_dir, m[1])[int(m[2]) - 1:int(m[3])]))
    doc = _replace_all(
        literal_include_directive_with_lines_1, doc,
        lambda m: _format_code_block(_read_include_lines(base_include_dir, m[1])[int(m[2]) - 1:]))

    return doc.strip()

class RewriteNodes(ast.NodeTransformer):
    """Copy docstrings from a runtime object into the AST of a type stub.

    ``obj`` is the runtime counterpart of the scope being visited: the imported
    module for the stub's top level, or the runtime class while visiting a class
    body. Stub nodes that already carry a docstring are left unchanged; inserted
    docstrings are passed through ``cleanup_docstring`` first.
    """

    def __init__(self, obj, base_include_dir: str):
        self.obj = obj
        # Base directory for resolving ``.. literalinclude::`` paths in docstrings.
        self.base_include_dir = base_include_dir

    @staticmethod
    def _runtime_doc(owner, name: str):
        """Return ``inspect.getdoc`` of ``owner.<name>``, or None if *owner* has no such attribute.

        A stub may declare members the runtime object lacks; those stay undocumented
        instead of aborting the whole run.
        """
        try:
            value = getattr(owner, name)
        except AttributeError:
            return None
        return inspect.getdoc(value)

    @staticmethod
    def _field_doc(owner, field_name: str):
        """Return the runtime docstring for an annotated field of *owner*, or None."""
        if isinstance(owner, traits.api.MetaHasTraits) and field_name in owner.__class_traits__:
            # TODO: Find out how to access the docstring of a class traits field.
            return None
        try:
            value = getattr(owner, field_name)
        except AttributeError:
            return None
        # A plain default value (e.g. ``x: float = 0.5``) only exposes the docstring of
        # its type (here: float's), which says nothing about the field itself.
        if not isinstance(value, type) and getattr(value, '__doc__', None) is getattr(type(value), '__doc__', None):
            return None
        return inspect.getdoc(value)

    def _docstring_node(self, ds: str, name: str) -> ast.Expr:
        """Build a docstring statement from the cleaned-up runtime docstring *ds*."""
        return ast.Expr(ast.Constant(cleanup_docstring(ds, name, self.base_include_dir)))

    def visit_Module(self, node):
        # Look up the stub's own docstring before inserting anything: once another
        # statement is placed at index 0 the docstring would no longer be detected.
        existing_doc = ast.get_docstring(node)
        public_names = getattr(self.obj, "__all__", None)
        if public_names:
            # Place __all__ after the docstring and any `from __future__` imports,
            # both of which have to stay at the very top of the module.
            index = 0 if existing_doc is None else 1
            while (index < len(node.body)
                   and isinstance(node.body[index], ast.ImportFrom)
                   and node.body[index].module == '__future__'):
                index += 1
            names_list = ast.List(elts=[ast.Constant(value=name) for name in public_names], ctx=ast.Load())
            assign_node = ast.Assign(targets=[ast.Name(id='__all__', ctx=ast.Store())], value=names_list)
            node.body.insert(index, assign_node)
        if not existing_doc:
            ds = inspect.getdoc(self.obj)
            if ds:
                node.body.insert(0, self._docstring_node(ds, ''))
        self.generic_visit(node)
        return node

    def visit_FunctionDef(self, node):
        if not ast.get_docstring(node):
            ds = self._runtime_doc(self.obj, node.name)
            if ds:
                node.body.insert(0, self._docstring_node(ds, node.name))
        if node.name == "__init__":
            # Document annotated instance attributes (`self.x: T = ...`). Iterate
            # backwards so inserting after a statement doesn't shift pending indices.
            for i in reversed(range(len(node.body))):
                stmt = node.body[i]
                if (isinstance(stmt, ast.AnnAssign)
                        and isinstance(stmt.target, ast.Attribute)
                        and isinstance(stmt.target.value, ast.Name)
                        and stmt.target.value.id == "self"):
                    ds = self._field_doc(self.obj, stmt.target.attr)
                    if ds:
                        node.body.insert(i + 1, self._docstring_node(ds, ''))
        return node

    # Async functions have the same `name`/`body` fields and get the same treatment.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_ClassDef(self, node: ast.ClassDef):
        clazz = getattr(self.obj, node.name, None)
        if clazz is None:
            # Stub-only class without a runtime counterpart: nothing to copy from.
            return node

        if not ast.get_docstring(node):
            ds = inspect.getdoc(clazz)
            if ds:
                node.body.insert(0, self._docstring_node(ds, node.name))

        # Process annotated class-level fields (e.g. dataclass fields). `simple` is
        # set only for plain-name targets, so `target.id` is available.
        for i in reversed(range(len(node.body))):
            stmt = node.body[i]
            if isinstance(stmt, ast.AnnAssign) and stmt.simple:
                ds = self._field_doc(clazz, stmt.target.id)
                if ds:
                    node.body.insert(i + 1, self._docstring_node(ds, ''))

        # Recurse into methods and nested classes with the runtime class as context.
        RewriteNodes(clazz, self.base_include_dir).generic_visit(node)
        return node

def process_stub_file(stub_file_in: str, stub_file_out: str, module_name: str, base_include_dir: str) -> None:
    """Write *stub_file_out*: the stub *stub_file_in* with docstrings taken from module *module_name*.

    The runtime module is imported, the stub source is parsed, missing docstrings
    are filled in by ``RewriteNodes``, and the resulting AST is written back out as
    source code. literalinclude paths are resolved against *base_include_dir*.
    """
    print(f"Processing stub file {stub_file_in} (module name: {module_name})")
    runtime_module = importlib.import_module(module_name)
    with open(stub_file_in, 'r', encoding="utf-8") as fin:
        source = fin.read()
    stub_module = ast.parse(source, filename=stub_file_in, type_comments=True)
    RewriteNodes(runtime_module, base_include_dir).visit(stub_module)
    ast.fix_missing_locations(stub_module)
    # Generate the text before opening the output file, so a failure while unparsing
    # does not leave behind a truncated (empty) stub file.
    output = ast.unparse(stub_module)
    with open(stub_file_out, 'w', encoding="utf-8") as fout:
        fout.write(output)

def stub_module_name(dirpath: str, filename: str) -> str:
    """Return the dotted name of the module described by stub template *filename* in *dirpath*.

    The package path starts at the last path component beginning with ``ovito``.
    ``__init__.pyi.in`` maps to the package itself; ``foo.pyi.in`` maps to submodule
    ``foo``. Forward slashes and backslashes are both accepted as separators.

    Raises:
        ValueError: If *dirpath* contains no component beginning with ``ovito``.
    """
    # Prepend a separator so that a relative root such as "ovito" is recognized too,
    # and drop trailing separators so "ovito/" does not yield the name "ovito.".
    normalized = ('/' + dirpath.replace('\\', '/')).rstrip('/')
    base_index = normalized.rfind('/ovito')
    if base_index == -1:
        raise ValueError(f"Directory {dirpath!r} is not located inside an 'ovito' package tree")
    module_name = normalized[base_index + 1:].replace('/', '.')
    if filename != '__init__.pyi.in':
        module_name += '.' + filename[:-len('.pyi.in')]
    return module_name


# The annotation is quoted so it is not evaluated when the function is defined:
# `List[str] | None` needs Python 3.10 and would otherwise break importing this
# script on 3.9 before the version check below can run.
def main(args: 'List[str] | None' = None):
    """Command-line entry point: add docstrings to every ``*.pyi.in`` stub below the given roots.

    Each ``<name>.pyi.in`` template is rewritten to ``<name>.pyi`` in the same directory.

    Args:
        args: Command-line arguments without the program name; ``None`` means ``sys.argv[1:]``.
    """
    parser = ArgumentParser(prog='add_stub_docstrings', description="Copies docstrings from the runtime module to a stub file")
    parser.add_argument("rootdir", nargs="+", metavar="root_dirs", type=str, help="The root directory")
    # NOTE: accepted for command-line compatibility; the value is not read anywhere below.
    parser.add_argument("--log-level", default="INFO", help="Set output log level")
    parser.add_argument("--include-dir", metavar='PATH', dest="include_dir", type=str, default="", help="The Sphinx include base directory")
    # argparse falls back to sys.argv[1:] by itself when given None; an explicit
    # (even empty) list must be honored rather than replaced by sys.argv.
    sys_args = parser.parse_args(args)

    # ast.unparse(), used by process_stub_file(), exists only since Python 3.9.
    if sys.version_info < (3, 9):
        print("Skipping docstring integration, because Python interpreter is too old: ", sys.version)
        return

    for root_dir in sys_args.rootdir:
        root_dir = root_dir.replace('\\', '/')
        for dirpath, _dirnames, filenames in os.walk(root_dir):
            for filename in filenames:
                if not filename.endswith('.pyi.in'):
                    continue
                filename_in = os.path.join(dirpath, filename)
                filename_out = filename_in[:-len('.in')]
                process_stub_file(filename_in, filename_out, stub_module_name(dirpath, filename), sys_args.include_dir)
    print("Done")

if __name__ == "__main__":
    main()