清华开源ChatGPT自动编程ChatDev项目codes.py代码解读
2023-12-31 08:41:30
- 这段代码定义了一个Codes类,这个类是用于管理生成的代码的类,它可以根据LLM的回复来提取、格式化、更新和保存代码。
- Codes类的__init__()方法是类的构造函数,它接受一个参数generated_content,表示LLM的回复内容。它首先初始化了以下几个属性:
- self.directory: 一个字符串,表示代码保存的目录。
- self.version: 一个浮点数,表示代码的版本号。
- self.generated_content: 一个字符串,表示LLM的回复内容。
- self.codebooks: 一个字典,表示代码的集合,以文件名为键,以代码内容为值。
- 然后它定义了两个内部函数extract_filename_from_line()和extract_filename_from_code(),用于从LLM的回复中提取文件名。这两个函数都接受一个参数lines或code,表示LLM的回复中的一部分内容。这两个函数都使用re模块来进行正则表达式匹配,并返回匹配到的文件名。如果没有匹配到文件名,则返回空字符串。
- 接下来,如果generated_content参数不为空,则它使用re模块来遍历LLM的回复中包含在“`符号中的代码段,并将其保存在code变量中。然后它判断code变量是否包含”CODE”字符串,如果是,则跳过这个代码段,因为这是一个占位符。然后它调用extract_filename_from_line()函数来从LLM的回复中提取文件名,并将其保存在filename变量中。如果filename变量为空,则它判断code变量是否包含”__main__“字符串,如果是,则将filename变量赋值为”main.py”,因为这是主程序文件。如果filename变量仍然为空,则它调用extract_filename_from_code()函数来从code变量中提取文件名,并将其保存在filename变量中。最后它断言filename变量不为空,并将其作为键,将经过_format_code()方法格式化后的code变量作为值,添加到self.codebooks字典中。
- _format_code()方法用于对代码进行格式化,例如去除多余的空行等。它接受一个参数code,表示代码内容。它首先使用splitlines()方法和join()方法来去除空行,并返回格式化后的代码。
- _update_codes()方法用于更新生成的代码,根据新的LLM的回复来比较、修改和保存代码。它接受一个参数generated_content,表示新的LLM的回复内容。它首先创建一个新的Codes对象new_codes,并将generated_content作为参数传递给其构造函数。然后它导入difflib模块,用于进行文本比较。接着它遍历new_codes对象中的self.codebooks字典,对于每个键值对(即文件名和代码内容),它判断是否存在于self.codebooks字典中,或者是否与self.codebooks字典中相同键对应的值不同。如果是,则表示需要更新代码,并执行以下操作:
- 创建一个字符串update_codes_content,并赋值为”[Update Codes]\n\n”,表示开始更新代码。
- 在update_codes_content字符串后面追加”{} updated.\n”.format(key),表示更新了哪个文件。
- 创建两个字符串old_codes_content和new_codes_content,并分别赋值为self.codebooks字典中相同键对应的值(即旧代码)和new_codes对象中相同键对应的值(即新代码)。如果self.codebooks字典中不存在相同键,则将old_codes_content赋值为”# None”。
- 使用splitlines()方法将old_codes_content和new_codes_content分割成行列表,并分别赋值给lines_old和lines_new。
- 使用difflib.unified_diff()函数来生成两个行列表之间的差异,并返回一个生成器对象unified_diff。
- 使用join()方法将unified_diff生成器对象转换为一个字符串,并赋值给unified_diff。
- 在update_codes_content字符串后面追加”\n\n” + “””“` ‘’’
‘’’\n”“” + unified_diff + “\n“`”,表示显示代码的差异。 – 调用utils.log_and_print_online()函数来将update_codes_content字符串记录到日志文件中,并打印出来。 – 将new_codes对象中相同键对应的值赋值给self.codebooks字典中相同键对应的值,表示更新代码。
- _rewrite_codes()方法用于重写代码,根据self.codebooks字典中的内容来修改和保存代码。它接受一个参数git_management,表示是否进行Git管理。它首先获取self.directory属性,表示代码保存的目录,并创建一个字符串rewrite_codes_content,用于记录重写过程。然后它判断目录是否存在并且不为空,如果是,则将self.version属性加一,表示代码的版本号增加。如果目录不存在,则使用os模块的mkdir()函数来创建目录,并在rewrite_codes_content字符串后面追加”{} Created\n”.format(directory),表示创建了目录。
- 接着它遍历self.codebooks字典中的键值对(即文件名和代码内容),对于每个键值对,它使用os模块的join()函数来拼接目录和文件名,得到文件路径,并将其保存在filepath变量中。然后它使用open()函数和write()方法来打开并写入文件,并在rewrite_codes_content字符串后面追加os.path.join(directory, filename) + ” Wrote\n”,表示写入了文件。
- 如果git_management参数为真,则表示需要进行Git管理,它会使用os模块的system()函数来执行一些Git命令,例如初始化仓库、添加文件、提交更改等,并将self.version属性作为提交信息。
- 最后它调用utils.log_and_print_online()函数来将rewrite_codes_content字符串记录到日志文件中,并打印出来。
- _get_codes()方法用于获取代码,根据self.codebooks字典中的内容来生成一个字符串,表示代码的集合。它首先创建一个空字符串content,然后遍历self.codebooks字典中的键值对(即文件名和代码内容),对于每个键值对,它在content字符串后面追加”{}\n
{}\n{}\n
\n\n”.format(filename, “python” if filename.endswith(“.py”) else filename.split(“.”)[-1], self.codebooks[filename]),表示显示文件名和代码内容,并根据文件扩展名来指定语言类型。最后它返回content字符串。 - _load_from_hardware()方法用于从硬盘中加载代码,根据给定的目录来读取并保存代码。它接受一个参数directory,表示代码所在的目录。它首先断言目录中存在以.py结尾的文件,然后使用os模块的walk()函数来遍历目录中的所有文件。对于每个文件,如果文件以.py结尾,则使用open()函数和read()方法来读取文件内容,并将其保存在code变量中。然后将经过_format_code()方法格式化后的code变量作为值,将文件名作为键,添加到self.codebooks字典中。最后调用utils.log_and_print_online()函数来记录并打印”{} files read from {}”.format(len(self.codebooks.keys()), directory),表示从目录中读取了多少个文件。
codes.py的代码如下:
import os
import re
from chatdev.utils import log_and_print_online
import difflib
class Codes:
def __init__(self, generated_content=""):
self.directory: str = None
self.version: float = 1.0
self.generated_content: str = generated_content
self.codebooks = {}
def extract_filename_from_line(lines):
file_name = ""
for candidate in re.finditer(r"(\w+\.\w+)", lines, re.DOTALL):
file_name = candidate.group()
file_name = file_name.lower()
return file_name
def extract_filename_from_code(code):
file_name = ""
regex_extract = r"class (\S+?):\n"
matches_extract = re.finditer(regex_extract, code, re.DOTALL)
for match_extract in matches_extract:
file_name = match_extract.group(1)
file_name = file_name.lower().split("(")[0] + ".py"
return file_name
if generated_content != "":
regex = r"(.+?)\n```.*?\n(.*?)```"
matches = re.finditer(regex, self.generated_content, re.DOTALL)
for match in matches:
code = match.group(2)
if "CODE" in code:
continue
group1 = match.group(1)
filename = extract_filename_from_line(group1)
if "__main__" in code:
filename = "main.py"
if filename == "": # post-processing
filename = extract_filename_from_code(code)
assert filename != ""
if filename is not None and code is not None and len(filename) > 0 and len(code) > 0:
self.codebooks[filename] = self._format_code(code)
def _format_code(self, code):
code = "\n".join([line for line in code.split("\n") if len(line.strip()) > 0])
return code
def _update_codes(self, generated_content):
new_codes = Codes(generated_content)
differ = difflib.Differ()
for key in new_codes.codebooks.keys():
if key not in self.codebooks.keys() or self.codebooks[key] != new_codes.codebooks[key]:
update_codes_content = "**[Update Codes]**\n\n"
update_codes_content += "{} updated.\n".format(key)
old_codes_content = self.codebooks[key] if key in self.codebooks.keys() else "# None"
new_codes_content = new_codes.codebooks[key]
lines_old = old_codes_content.splitlines()
lines_new = new_codes_content.splitlines()
unified_diff = difflib.unified_diff(lines_old, lines_new, lineterm='', fromfile='Old', tofile='New')
unified_diff = '\n'.join(unified_diff)
update_codes_content = update_codes_content + "\n\n" + """```
'''
'''\n""" + unified_diff + "\n```"
log_and_print_online(update_codes_content)
self.codebooks[key] = new_codes.codebooks[key]
def _rewrite_codes(self, git_management) -> None:
directory = self.directory
rewrite_codes_content = "**[Rewrite Codes]**\n\n"
if os.path.exists(directory) and len(os.listdir(directory)) > 0:
self.version += 1.0
if not os.path.exists(directory):
os.mkdir(self.directory)
rewrite_codes_content += "{} Created\n".format(directory)
for filename in self.codebooks.keys():
filepath = os.path.join(directory, filename)
with open(filepath, "w", encoding="utf-8") as writer:
writer.write(self.codebooks[filename])
rewrite_codes_content += os.path.join(directory, filename) + " Wrote\n"
if git_management:
if self.version == 1.0:
os.system("cd {}; git init".format(self.directory))
os.system("cd {}; git add .".format(self.directory))
os.system("cd {}; git commit -m \"{}\"".format(self.directory, self.version))
log_and_print_online(rewrite_codes_content)
def _get_codes(self) -> str:
content = ""
for filename in self.codebooks.keys():
content += "{}\n```{}\n{}\n```\n\n".format(filename,
"python" if filename.endswith(".py") else filename.split(".")[
-1], self.codebooks[filename])
return content
def _load_from_hardware(self, directory) -> None:
assert len([filename for filename in os.listdir(directory) if filename.endswith(".py")]) > 0
for root, directories, filenames in os.walk(directory):
for filename in filenames:
if filename.endswith(".py"):
code = open(os.path.join(directory, filename), "r", encoding="utf-8").read()
self.codebooks[filename] = self._format_code(code)
log_and_print_online("{} files read from {}".format(len(self.codebooks.keys()), directory))
文章来源:https://blog.csdn.net/linweidong/article/details/135312960
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!