import ast
import os

from sources.representation.ARParser import ast_parse_python_version_independent, open_file_with_encoding

# Suffix appended to a resolved module path when looking for a plain module file.
PY_EXTENSION = ".py"
# File name looked up inside a resolved module directory when looking for a package.
INIT_FILE = "__init__.py"


def get_imports(file_path):
    """Return every import found in the Python file at ``file_path``.

    The file is read and parsed, then walked with an ``ImportVisitor``.

    :param file_path: path of the source file to inspect.
    :return: the visitor's ``modules_to_level_and_names`` dict, which maps a
        module name to a list of ``(level, imported_name)`` tuples.
    """
    with open_file_with_encoding(file_path) as source_file:
        source_text = source_file.read()
        tree = ast_parse_python_version_independent(source_text, file_path)
        import_visitor = ImportVisitor()
        import_visitor.visit(tree)
        return import_visitor.modules_to_level_and_names


# noinspection PyPep8Naming
class ImportVisitor(ast.NodeVisitor):
    """AST visitor that records every ``import`` / ``from ... import`` it meets.

    Results accumulate in ``modules_to_level_and_names``: a dict mapping a
    module name to a list of ``(level, name)`` tuples, one per occurrence.
    ``level`` is the number of leading dots (0 for a plain ``import``) and
    ``name`` is the imported member, or ``""`` when the module itself is
    what gets imported.
    """

    def __init__(self):
        super(ImportVisitor, self).__init__()
        self.modules_to_level_and_names = {}

    def _record(self, module, entries):
        # Create the module's list on first sight, then add the new entries.
        self.modules_to_level_and_names.setdefault(module, []).extend(entries)

    def visit_Import(self, node):
        # ``import a, b.c`` -> one (0, "") entry per listed module.
        for alias in node.names:
            self._record(alias.name, [(0, "")])

    def visit_ImportFrom(self, node):
        if node.module is not None:
            # ``from [..]pkg import x, y`` -> entries keyed by ``pkg``.
            self._record(
                node.module,
                [(node.level, alias.name) for alias in node.names],
            )
            return
        # ``from . import x`` has no module part, so each imported name is
        # itself treated as the module.
        for alias in node.names:
            self._record(alias.name, [(node.level, "")])


def get_objects(root_file_path):
    """Collect the names made available by a module and its star imports.

    Starting from ``root_file_path``, each file is parsed and walked with an
    ``ObjectsVisitor``. Every ``from <module> import *`` the visitor reports
    is resolved to either ``<module>.py`` or ``<module>/__init__.py`` and, if
    such a file exists and was not processed yet, queued for the same
    treatment.

    Absolute star imports (level 0) are resolved against the directory of the
    root file. Relative star imports (level >= 1) are resolved against the
    directory of the file that contains them, which is what the leading dots
    refer to.

    :param root_file_path: path of the Python file to start from.
    :return: set of names gathered from all processed files.
    """
    root_dir = os.path.dirname(root_file_path)
    worklist = [root_file_path]
    visited = {root_file_path}
    objects = set()
    while worklist:
        file_path = worklist.pop()
        # Only the read needs the file handle; release it before parsing.
        with open_file_with_encoding(file_path) as f:
            content = f.read()
        file_ast = ast_parse_python_version_independent(content, file_path)
        visitor = ObjectsVisitor()
        visitor.visit(file_ast)
        objects.update(visitor.objects)

        current_dir = os.path.dirname(file_path)
        for module, level in visitor.recursive_imports:
            # Leading dots are relative to the importing file, not the root.
            base_dir = current_dir if level else root_dir
            module_path = get_real_path(base_dir, module, level)
            # Prefer a plain module file; fall back to a package directory.
            candidates = (
                module_path + PY_EXTENSION,
                os.path.join(module_path, INIT_FILE),
            )
            for candidate in candidates:
                if candidate not in visited and os.path.exists(candidate):
                    worklist.append(candidate)
                    visited.add(candidate)
                    break
    return objects


def get_real_path(path, import_module, level):
    """Resolve a dotted module name to a canonical filesystem path.

    The result carries no ``.py`` suffix; callers append what they need.

    :param path: directory the import is resolved against.
    :param import_module: dotted module name, e.g. ``"pkg.mod"``.
    :param level: number of leading dots of the import. Each dot beyond the
        first moves one directory up from ``path``; 0 and 1 both resolve
        directly inside ``path``.
    :return: ``os.path.realpath`` of the resolved location.
    """
    module_subpath = get_module_dir_structure(import_module)
    # One ".." per dot beyond the first; nothing to climb for level 0 or 1.
    parent_hops = [".."] * (level - 1) if level > 1 else []
    segments = parent_hops + [module_subpath]
    return os.path.realpath(os.path.join(path, *segments))


def get_module_dir_structure(module):
    """Convert a dotted module name into a relative filesystem path.

    :param module: dotted name such as ``"pkg.sub.mod"``.
    :return: the dot-separated components joined with ``os.path.join``.
    """
    return os.path.join(*module.split("."))


# noinspection PyPep8Naming
class ObjectsVisitor(ast.NodeVisitor):
    """AST visitor that gathers the names a module binds.

    After ``visit`` the instance exposes:

    * ``objects`` -- set of names bound by ``from ... import`` statements,
      function and class definitions, and assignments. Bodies of functions
      and classes are not descended into, so their locals and attributes
      are not reported.
    * ``recursive_imports`` -- list of ``[module, level]`` pairs, one per
      ``from <module> import *``, for the caller to follow.
    """

    def __init__(self):
        super(ObjectsVisitor, self).__init__()
        self.objects = set()
        self.recursive_imports = []

    def visit_ImportFrom(self, node):
        for alias in node.names:
            if alias.asname:
                # ``from m import x as y`` binds ``y``.
                self.objects.add(alias.asname)
            elif alias.name != "*":
                self.objects.add(alias.name)
            elif node.module:
                # Star import: the bound names are unknown here, so remember
                # the module. ``from . import *`` has no module and is skipped.
                self.recursive_imports.append([node.module, node.level])

    def visit_FunctionDef(self, node):
        # Record the name only; not visiting the body keeps locals out.
        self.objects.add(node.name)

    def visit_AsyncFunctionDef(self, node):
        # Same treatment as a regular ``def``. Without this handler the
        # generic visitor would walk the body and report its locals instead.
        self.objects.add(node.name)

    def visit_ClassDef(self, node):
        self.objects.add(node.name)

    def visit_Assign(self, node):
        for target in node.targets:
            self._add_target_names(target)

    def visit_AnnAssign(self, node):
        # A bare annotation (``x: int``) binds nothing; only record the
        # target when a value is actually assigned.
        if node.value is not None:
            self._add_target_names(node.target)

    def _add_target_names(self, target):
        """Record every plain name bound by an assignment target."""
        if isinstance(target, ast.Name):
            self.objects.add(target.id)
        elif isinstance(target, (ast.Tuple, ast.List)):
            # Unpacking: ``a, (b, c) = ...`` binds each element.
            for element in target.elts:
                self._add_target_names(element)
        elif isinstance(target, getattr(ast, "Starred", ())):
            # ``*rest`` inside an unpacking target; the getattr guard keeps
            # this working on interpreters whose ast has no Starred node.
            self._add_target_names(target.value)
        # Attribute and subscript targets do not bind a new name.
