Python Jupyter Notebook 自定义魔术命令（magic command）

2023-12-23 21:54:06
import pymysql
import json
from IPython.display import HTML, JSON
from IPython.core import magic_arguments
from IPython.core.magic import (Magics, magics_class, line_magic, cell_magic, line_cell_magic)


@magics_class
class MagicSql(Magics):
    """IPython magics that run MySQL statements from a notebook.

    Provides ``%mysql`` (line mode, SQL passed with ``-s``) and ``%%mysql``
    (cell mode, one statement per line). Each executed statement yields one
    ``{sql: {"status": ..., "result": ...}}`` entry; the list of entries is
    rendered as preformatted HTML.
    """

    def __init__(self, shell=None, host="127.0.0.1", user="root",
                 password="123456", database=None):
        """Create the magics object.

        :param shell: the IPython shell (IPython's usual ``Magics(shell)``
            convention); the default ``None`` keeps ``MagicSql()`` working.
        :param host: primary MySQL host.
        :param user: MySQL user for the primary host.
        :param password: MySQL password for the primary host.
        :param database: initial default database; ``-d`` on the magic
            overrides it.
        """
        super(MagicSql, self).__init__(shell=shell)
        self.host = host
        self.user = user
        self.password = password
        self.database = database
        # Common MySQL 8 statement keywords: in cell mode only lines whose
        # first word is one of these are executed.
        self.words = {'call', 'add', 'dump', 'order', 'begin', 'revive',
                      'stop', 'release', 'execute', 'repair', 'ignore', 'scan',
                      'lock', 'create', 'commit', 'compress', 'flush', 'handler', 'grant',
                      'truncate', 'select', 'show', 'merge', 'rename', 'disable', 'remove',
                      'search', 'use', 'write', 'change', 'checksum', 'replace', 'purge',
                      'backup', 'start', 'validate', 'alter', 'rebuild', 'enable', 'kill',
                      'resolve', 'join', 'prepare', 'compare', 'check', 'load', 'coalesce',
                      'set', 'optimize', 'verify', 'revoke', 'split', 'import', 'unlock',
                      'reset', 'copy', 'upgrade', 'insert', 'update', 'delete', 'modify',
                      'explain', 'rollback', 'end', 'group', 'drop', 'savepoint', 'analyze'
                      }
        # Statements starting with these keywords always have their rows
        # fetched and are never committed.
        self.fetchall_set = {"select", "show"}

    def get_conn(self):
        """Open a new pymysql connection.

        Tries the configured ``host``/``user``/``password`` first. If that
        fails with a pymysql error, it falls back to a hard-coded secondary
        server.

        :return: an open pymysql connection; the caller must close it.
        :raises pymysql.MySQLError: if the fallback connection also fails.
        """
        try:
            return pymysql.connect(host=self.host, user=self.user,
                                   password=self.password, database=self.database)
        except pymysql.MySQLError:
            # Narrowed from a bare ``except:``, which also swallowed
            # KeyboardInterrupt/SystemExit and programming errors.
            return pymysql.connect(host="192.168.2.117", user="root",
                                   password="123456", database=self.database)

    def get_new_cell(self, cell):
        """Drop lines that lie inside triple-quoted regions of *cell*.

        The line that opens a region is kept; the lines after it, up to and
        including the closing line, are removed. Regions may be delimited by
        three double quotes or three single quotes.

        :param cell: raw cell text.
        :return: the remaining lines joined with newlines.
        """
        new_cell = []
        in_quotes = False
        for line in cell.split('\n'):
            if not in_quotes:
                new_cell.append(line)
            # Only an odd number of markers opens/closes a region; a line
            # like '"""note"""' is self-contained and must not flip the state.
            markers = line.count('"""') + line.count("'''")
            if markers % 2:
                in_quotes = not in_quotes

        return "\n".join(new_cell)

    def anlise_setting_args(self, args):
        """Apply parsed magic options and extract the inline SQL.

        ``-d/--database`` updates ``self.database``. The new value persists
        for later calls until the kernel restarts; omitting ``-d`` keeps the
        previous value.

        :param args: namespace returned by ``magic_arguments.parse_argstring``.
        :return: the ``-s/--sql`` text with surrounding quotes stripped, or
            ``None`` when ``-s`` was not given.
        """
        if args.database is not None:
            self.database = args.database
        if args.sql is None:
            return None
        return args.sql.strip('"').strip("'")

    @staticmethod
    def _first_word(sql):
        """Return the lower-cased first whitespace-separated token of *sql* ('' if blank)."""
        parts = sql.split(None, 1)
        return parts[0].lower() if parts else ""

    def anlise_line_cell(self, line, cell):
        """Collect the SQL statements to execute.

        :param line: SQL from ``-s`` (or ``None``); it is used as-is apart
            from stripping newlines.
        :param cell: cell body (or ``None``). Triple-quoted regions are
            removed first. After that, only lines whose first word is a known
            keyword in ``self.words`` are kept; leading indentation is allowed.
        :return: list of SQL strings, the line statement first.
        """
        args = []
        if line:
            args.append(line.strip("\n"))
        if cell:
            cell = self.get_new_cell(cell)
            args.extend(i for i in cell.split("\n")
                        if self._first_word(i) in self.words)
        return args

    def run_sql(self, args):
        """Execute each statement in *args* on one fresh connection.

        Handling of each statement:

        - If it starts with a keyword in ``self.fetchall_set``, or it
          produces a result set (``cursor.description`` is set, e.g.
          ``EXPLAIN``), its rows are fetched.
        - Otherwise the result is ``[status]``, the value returned by
          ``cursor.execute``.
        - Statements not in ``fetchall_set`` are committed.
        - On an error the transaction is rolled back, the error is printed,
          and the result is a one-element list holding the error message.

        :param args: list of SQL strings, executed in order.
        :return: list of ``{sql: {"status": ..., "result": ...}}`` dicts.
        """
        data = []

        conn = self.get_conn()
        try:
            with conn.cursor() as cursor:
                for sql in args:
                    # Reset per statement: a failing execute() must not leave
                    # `status` unbound (first statement) or stale (later ones).
                    status = None
                    keyword = self._first_word(sql)
                    try:
                        status = cursor.execute(sql)
                        if keyword in self.fetchall_set or cursor.description is not None:
                            result = cursor.fetchall()
                        else:
                            result = [status]
                        if keyword not in self.fetchall_set:
                            conn.commit()
                    except Exception as e:
                        # Best-effort per statement: record the error and continue.
                        conn.rollback()
                        result = [f"异常:{e}"]
                        print(e)
                    data.append({sql: {"status": status, "result": result}})
        finally:
            # Always release the connection, even if rollback/commit raised.
            conn.close()
        return data

    def format_out(self, result):
        """Render *result* as preformatted HTML.

        The ``str()`` form is HTML-escaped so values containing ``<``, ``>``
        or ``&`` display literally instead of being parsed as markup. JSON
        output is not used because ``json`` cannot serialize ``datetime``
        values.

        :param result: the list produced by :meth:`run_sql`.
        :return: an ``IPython.display.HTML`` object.
        """
        import html

        return HTML('<pre>{}</pre>'.format(html.escape(str(result), quote=False)))

    @magic_arguments.magic_arguments()
    # -d/--database: default database for the connection (persists across calls)
    @magic_arguments.argument(
        "-d", "--database", dest="database", default=None
    )
    # -s/--sql: a single SQL statement given on the magic line
    @magic_arguments.argument(
        "-s", "--sql", dest="sql", default=None, type=str
    )
    # usable as both %mysql (line) and %%mysql (cell)
    @line_cell_magic
    def mysql(self, line=None, cell=None):
        """Run SQL from the magic line (``-s``) and/or the cell body.

        :param line: option string following ``%mysql`` / ``%%mysql``.
        :param cell: cell body in ``%%mysql`` mode, otherwise ``None``.
        :return: HTML rendering of the per-statement results.
        """
        setting_args = magic_arguments.parse_argstring(self.mysql, line)
        sql_line = self.anlise_setting_args(setting_args)
        statements = self.anlise_line_cell(sql_line, cell)
        result = self.run_sql(statements)

        return self.format_out(result)


# Register %mysql / %%mysql with the running IPython shell. get_ipython is
# imported explicitly instead of relying on the builtin IPython injects, so
# this file can also be imported outside IPython. There get_ipython() returns
# None and registration is skipped instead of raising NameError.
from IPython import get_ipython

ipy = get_ipython()
if ipy is not None:
    ipy.register_magics(MagicSql())

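将以上代码保存为一个 .py 文件，然后把文件移动到 ~/.ipython/profile_default/startup/ 目录下，重启 kernel 即可在每次启动时自动加载（必须是 IPython kernel）。加载后可用行模式 `%mysql -d 数据库名 -s "select 1"` 执行单条 SQL，或在单元格首行写 `%%mysql -d 数据库名`，之后每行写一条以 SQL 关键字开头的语句。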

文章来源:https://blog.csdn.net/CXY00000/article/details/135173631
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。