python bad case边界不准确问题

2024-01-07 22:55:26

目录

问题描述

问题解决:


问题描述

针对bad case中,错误的主要原因是边界定位不准确问题,sub,obj抽取过短。

因此想要通过jieba分词,然后调用GPT4的api判断当前的新span是否符合条件。

问题解决:

import json
from pdb import set_trace as stop
import jieba

import openai

from tqdm import tqdm

openai.api_key = "api_key" # GPT4.0
openai.api_base = 'https://api.ngapi.top/v1'

def get_response(prompt, temperature=0.5, max_tokens=2048):
  print(prompt)
  completion = openai.ChatCompletion.create(
    # model="gpt-3.5-turbo",
    model="gpt-4",
    temperature=0,
    top_p=0,
    # max_tokens=max_tokens,
    messages=[
      {"role": "user", "content": f"{prompt}"}
    ]
  )
  return completion

llm_generated_path= "/public/home/hongy/qtxu/Qwen-main/results/Ele_lora/pred_20240101_instruction_0104.jsonl"
change_path = "/public/home/hongy/qtxu/Qwen-main/results/Ele_lora/pred_20240101_instruction_0104_post.txt"
 

po_dict = {"相等":'equal',
           "更好": 'better',
           '更差': 'worse',
           '不同': 'different'}

pad_word = '无'

chinese_punctuation = [',', '。', '?', '!', ':', ';', '‘', '’', '“', '”', '(', ')', '【', '】', '{', '}', '《', '》', '、', '——', '-', '……', '~', '·']

def get_previsous_word(cur_span, cur_sent):
    front_prompt = f"在输入语句({cur_sent})中,({cur_span})的前一个单词是什么?。直接给出答案即可。"
    front_result = get_response(front_prompt)['choices'][0]['message']['content']
    if front_result=='的':
        cur_span = front_result+cur_span
        front_prompt = f"在输入语句({cur_sent})中,({cur_span})的前一个单词是什么?直接给出答案即可。"
        front_result = get_response(front_prompt)['choices'][0]['message']['content']

    return front_result


def identify_nonu_phrase(front_result, cur_span, cur_sent):
    identify_prompt = f"在输入语句({cur_sent})中,({front_result}{cur_span})是一个可以表示物品名称、物品品牌的名词或名词短语吗?直接回答'yes'或'no'"
    # if '#' in identify_prompt:
    #     identify_prompt = identify_prompt.replace('#','')
    identify_result = get_response(identify_prompt)['choices'][0]['message']['content']

    return identify_result

def get_chinese_index(cur_span, cur_sent):
    index = cur_sent.find(cur_span) # 没发现的话, index = -1 
    return index 

def get_front_end_word(text, span):
 
    text_seg_list = jieba.cut(text, cut_all=False)
    span_seg_list = jieba.cut(span,cut_all=False )
    text_result = " ".join(text_seg_list)
    span_result = " ".join(span_seg_list)
    index = text_result.find(span_result) # 获取最后一个位置
    front_word =text_result[:index].split()[-1] # 获取前一个元素index
    if front_word == '的':
        front_front_word = text_result[:index-2].split()[-1] # 因为有一个空格,所以是-2
        front_word = front_front_word+front_word
 
    end_word = text_result[index + len(span_result):].split()[0] # 至于后面的0要不要添加,需要依据统计结果而定

    return front_word, end_word
 

def post_processing(cur_span, cur_sent, pad_word):
    if cur_span == pad_word: # 如果是空,则返回本身
        final_span = pad_word
    else:
        cur_span_index = get_chinese_index(cur_span, cur_sent)
        if cur_span_index == 0: # 如果当前给定的span已经位于句首,则保持不变
            final_span = cur_span
        else:
            front_result, end_result = get_front_end_word(cur_sent, cur_span)

            identify_result = identify_nonu_phrase(front_result, cur_span, cur_sent)
            print("identify_result结果是:", identify_result)

            if identify_result=='yes':
                final_span = front_result+cur_span
            else:
                final_span = cur_span

    return final_span
    

with open(llm_generated_path, 'r') as fr, open(change_path, 'w') as fw:
    for line in fr:
        cur_line = json.loads(line)
       
        cur_sent = cur_line['query'].split('\n\n')[1][7:-52].strip() # instruction2
        # cur_sent = cur_line['query'].split('\n\n')[-1][7:-57].strip() # instruction kaisong

        compar = cur_line['type'] # 是否是比较句
        if compar == 1:
            # cur_sent = cur_line['query'].split('\n\n')[1][7:-32].strip() 
            fw.write(cur_sent + "\n")
            result = cur_line['output'].strip().split('\n')
            gold = cur_line['truth'].strip().split('\n') # 

            # for j in range(0, len(gold), 2): # 如果是位置信息,则是 for j in range(0, len(gold), 2)
            #     gold_quintuple = gold[j][7:].strip()
            #     fw.write("gold:"+ gold_quintuple + "\n")
            

            for i in range(0, len(result), 2): # 同上 如果是位置信息,则是 for j in range(0, len(gold), 2)
                cur_quintuple = result[i][7:].strip() # 有几个特殊的,不能以逗号分隔
                # stop()
                # cur_quintuple_index = result[i+1][5:].strip() # '元组位置:(,17:18:19:20:21:22:23,12:13,24:25)'
                cur_quintuple_list = cur_quintuple[1:-1].split(',')
                # cur_quintuple_index_list = cur_quintuple_index[1:-1].split(',')
                sub, obj, asp, op, polarity = cur_quintuple_list[0].strip(), cur_quintuple_list[1].strip(), cur_quintuple_list[2].strip(), cur_quintuple_list[3].strip(), cur_quintuple_list[-1].strip()
                # sub_index, obj_index, asp_index, op_index = cur_quintuple_index_list[0].strip(),cur_quintuple_index_list[1].strip(),cur_quintuple_index_list[2].strip(),cur_quintuple_index_list[3].strip()
                sub = sub if sub else pad_word
                obj = obj if obj else pad_word
                asp = asp if asp else pad_word
                op = op if op else pad_word
                polarity = po_dict[polarity] if polarity else pad_word
                # 对产生的结果进行后处理
                # stop()
                post_sub = post_processing(sub, cur_sent, pad_word) # sub_index.split(";")[0]
                post_obj = post_processing(obj, cur_sent, pad_word)
                # post_asp = post_processing(asp, cur_sent, pad_word)
                # stop()
                final_quintuple = '('+sub +','+obj+','+ asp + ','+ op+','+polarity+')'
                post_final_quintuple = '('+post_sub +','+post_obj+','+ asp + ','+ op+','+polarity+')'
                # fw.write("final_quintuple"+final_quintuple +"\n")
                # fw.write("post_final_quintuple"+post_final_quintuple+"\n")
                fw.write(post_final_quintuple+"\n")

文章来源:https://blog.csdn.net/weixin_41862755/article/details/135396080
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。