清华开源ChatGPT自动编程ChatDev项目codes.py代码解读

2023-12-31 08:41:30
  • 这段代码定义了一个Codes类,这个类是用于管理生成的代码的类,它可以根据LLM的回复来提取、格式化、更新和保存代码。
  • Codes类的__init__()方法是类的构造函数,它接受一个参数generated_content,表示LLM的回复内容。它首先初始化了以下几个属性:
    • self.directory: 一个字符串,表示代码保存的目录。
    • self.version: 一个浮点数,表示代码的版本号。
    • self.generated_content: 一个字符串,表示LLM的回复内容。
    • self.codebooks: 一个字典,表示代码的集合,以文件名为键,以代码内容为值。
  • 然后它定义了两个内部函数extract_filename_from_line()和extract_filename_from_code(),用于从LLM的回复中提取文件名。这两个函数都接受一个参数lines或code,表示LLM的回复中的一部分内容。这两个函数都使用re模块来进行正则表达式匹配,并返回匹配到的文件名。如果没有匹配到文件名,则返回空字符串。
  • 接下来,如果generated_content参数不为空,则它使用re模块来遍历LLM的回复中包含在“`符号中的代码段,并将其保存在code变量中。然后它判断code变量是否包含”CODE”字符串,如果是,则跳过这个代码段,因为这是一个占位符。然后它调用extract_filename_from_line()函数来从LLM的回复中提取文件名,并将其保存在filename变量中。如果filename变量为空,则它判断code变量是否包含”__main__“字符串,如果是,则将filename变量赋值为”main.py”,因为这是主程序文件。如果filename变量仍然为空,则它调用extract_filename_from_code()函数来从code变量中提取文件名,并将其保存在filename变量中。最后它断言filename变量不为空,并将其作为键,将经过_format_code()方法格式化后的code变量作为值,添加到self.codebooks字典中。
  • _format_code()方法用于对代码进行格式化,例如去除多余的空行等。它接受一个参数code,表示代码内容。它首先使用splitlines()方法和join()方法来去除空行,并返回格式化后的代码。
  • _update_codes()方法用于更新生成的代码,根据新的LLM的回复来比较、修改和保存代码。它接受一个参数generated_content,表示新的LLM的回复内容。它首先创建一个新的Codes对象new_codes,并将generated_content作为参数传递给其构造函数。然后它导入difflib模块,用于进行文本比较。接着它遍历new_codes对象中的self.codebooks字典,对于每个键值对(即文件名和代码内容),它判断是否存在于self.codebooks字典中,或者是否与self.codebooks字典中相同键对应的值不同。如果是,则表示需要更新代码,并执行以下操作:
    • 创建一个字符串update_codes_content,并赋值为”[Update Codes]\n\n”,表示开始更新代码。
    • 在update_codes_content字符串后面追加”{} updated.\n”.format(key),表示更新了哪个文件。
    • 创建两个字符串old_codes_content和new_codes_content,并分别赋值为self.codebooks字典中相同键对应的值(即旧代码)和new_codes对象中相同键对应的值(即新代码)。如果self.codebooks字典中不存在相同键,则将old_codes_content赋值为”# None”。
    • 使用splitlines()方法将old_codes_content和new_codes_content分割成行列表,并分别赋值给lines_old和lines_new。
    • 使用difflib.unified_diff()函数来生成两个行列表之间的差异,并返回一个生成器对象unified_diff。
    • 使用join()方法将unified_diff生成器对象转换为一个字符串,并赋值给unified_diff。
    • 在update_codes_content字符串后面追加”\n\n” + “””“` ‘’’

‘’’\n”“” + unified_diff + “\n“`”,表示显示代码的差异。 – 调用utils.log_and_print_online()函数来将update_codes_content字符串记录到日志文件中,并打印出来。 – 将new_codes对象中相同键对应的值赋值给self.codebooks字典中相同键对应的值,表示更新代码。

  • _rewrite_codes()方法用于重写代码,根据self.codebooks字典中的内容来修改和保存代码。它接受一个参数git_management,表示是否进行Git管理。它首先获取self.directory属性,表示代码保存的目录,并创建一个字符串rewrite_codes_content,用于记录重写过程。然后它判断目录是否存在并且不为空,如果是,则将self.version属性加一,表示代码的版本号增加。如果目录不存在,则使用os模块的mkdir()函数来创建目录,并在rewrite_codes_content字符串后面追加”{} Created\n”.format(directory),表示创建了目录。
  • 接着它遍历self.codebooks字典中的键值对(即文件名和代码内容),对于每个键值对,它使用os模块的join()函数来拼接目录和文件名,得到文件路径,并将其保存在filepath变量中。然后它使用open()函数和write()方法来打开并写入文件,并在rewrite_codes_content字符串后面追加os.path.join(directory, filename) + ” Wrote\n”,表示写入了文件。
  • 如果git_management参数为真,则表示需要进行Git管理,它会使用os模块的system()函数来执行一些Git命令,例如初始化仓库、添加文件、提交更改等,并将self.version属性作为提交信息。
  • 最后它调用utils.log_and_print_online()函数来将rewrite_codes_content字符串记录到日志文件中,并打印出来。
  • _get_codes()方法用于获取代码,根据self.codebooks字典中的内容来生成一个字符串,表示代码的集合。它首先创建一个空字符串content,然后遍历self.codebooks字典中的键值对(即文件名和代码内容),对于每个键值对,它在content字符串后面追加”{}\n{}\n{}\n\n\n”.format(filename, “python” if filename.endswith(“.py”) else filename.split(“.”)[-1], self.codebooks[filename]),表示显示文件名和代码内容,并根据文件扩展名来指定语言类型。最后它返回content字符串。
  • _load_from_hardware()方法用于从硬盘中加载代码,根据给定的目录来读取并保存代码。它接受一个参数directory,表示代码所在的目录。它首先断言目录中存在以.py结尾的文件,然后使用os模块的walk()函数来遍历目录中的所有文件。对于每个文件,如果文件以.py结尾,则使用open()函数和read()方法来读取文件内容,并将其保存在code变量中。然后将经过_format_code()方法格式化后的code变量作为值,将文件名作为键,添加到self.codebooks字典中。最后调用utils.log_and_print_online()函数来记录并打印”{} files read from {}”.format(len(self.codebooks.keys()), directory),表示从目录中读取了多少个文件。

codes.py的代码如下:

import os
import re

from chatdev.utils import log_and_print_online
import difflib

class Codes:
    def __init__(self, generated_content=""):
        self.directory: str = None
        self.version: float = 1.0
        self.generated_content: str = generated_content
        self.codebooks = {}

        def extract_filename_from_line(lines):
            file_name = ""
            for candidate in re.finditer(r"(\w+\.\w+)", lines, re.DOTALL):
                file_name = candidate.group()
                file_name = file_name.lower()
            return file_name

        def extract_filename_from_code(code):
            file_name = ""
            regex_extract = r"class (\S+?):\n"
            matches_extract = re.finditer(regex_extract, code, re.DOTALL)
            for match_extract in matches_extract:
                file_name = match_extract.group(1)
            file_name = file_name.lower().split("(")[0] + ".py"
            return file_name

        if generated_content != "":
            regex = r"(.+?)\n```.*?\n(.*?)```"
            matches = re.finditer(regex, self.generated_content, re.DOTALL)
            for match in matches:
                code = match.group(2)
                if "CODE" in code:
                    continue
                group1 = match.group(1)
                filename = extract_filename_from_line(group1)
                if "__main__" in code:
                    filename = "main.py"
                if filename == "":  # post-processing
                    filename = extract_filename_from_code(code)
                assert filename != ""
                if filename is not None and code is not None and len(filename) > 0 and len(code) > 0:
                    self.codebooks[filename] = self._format_code(code)

    def _format_code(self, code):
        code = "\n".join([line for line in code.split("\n") if len(line.strip()) > 0])
        return code

    def _update_codes(self, generated_content):
        new_codes = Codes(generated_content)
        differ = difflib.Differ()
        for key in new_codes.codebooks.keys():
            if key not in self.codebooks.keys() or self.codebooks[key] != new_codes.codebooks[key]:
                update_codes_content = "**[Update Codes]**\n\n"
                update_codes_content += "{} updated.\n".format(key)
                old_codes_content = self.codebooks[key] if key in self.codebooks.keys() else "# None"
                new_codes_content = new_codes.codebooks[key]

                lines_old = old_codes_content.splitlines()
                lines_new = new_codes_content.splitlines()

                unified_diff = difflib.unified_diff(lines_old, lines_new, lineterm='', fromfile='Old', tofile='New')
                unified_diff = '\n'.join(unified_diff)
                update_codes_content = update_codes_content + "\n\n" + """```
'''

'''\n""" + unified_diff + "\n```"

                log_and_print_online(update_codes_content)
                self.codebooks[key] = new_codes.codebooks[key]

    def _rewrite_codes(self, git_management) -> None:
        directory = self.directory
        rewrite_codes_content = "**[Rewrite Codes]**\n\n"
        if os.path.exists(directory) and len(os.listdir(directory)) > 0:
            self.version += 1.0
        if not os.path.exists(directory):
            os.mkdir(self.directory)
            rewrite_codes_content += "{} Created\n".format(directory)

        for filename in self.codebooks.keys():
            filepath = os.path.join(directory, filename)
            with open(filepath, "w", encoding="utf-8") as writer:
                writer.write(self.codebooks[filename])
                rewrite_codes_content += os.path.join(directory, filename) + " Wrote\n"

        if git_management:
            if self.version == 1.0:
                os.system("cd {}; git init".format(self.directory))
            os.system("cd {}; git add .".format(self.directory))
            os.system("cd {}; git commit -m \"{}\"".format(self.directory, self.version))

        log_and_print_online(rewrite_codes_content)

    def _get_codes(self) -> str:
        content = ""
        for filename in self.codebooks.keys():
            content += "{}\n```{}\n{}\n```\n\n".format(filename,
                                                       "python" if filename.endswith(".py") else filename.split(".")[
                                                           -1], self.codebooks[filename])
        return content

    def _load_from_hardware(self, directory) -> None:
        assert len([filename for filename in os.listdir(directory) if filename.endswith(".py")]) > 0
        for root, directories, filenames in os.walk(directory):
            for filename in filenames:
                if filename.endswith(".py"):
                    code = open(os.path.join(directory, filename), "r", encoding="utf-8").read()
                    self.codebooks[filename] = self._format_code(code)
        log_and_print_online("{} files read from {}".format(len(self.codebooks.keys()), directory))

文章来源:https://blog.csdn.net/linweidong/article/details/135312960
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。