双向A*算法-python

2023-12-21 17:35:13

GitHub - LittleFox99/B_A_star: Bidirectional A Star

其中a,b分别为双向A*搜索的权重?

#-*- coding:utf-8 -*-
# @Time    : 2020/11/11 1:21 下午
# @Author  : LittleFox99
# File : a_star.py
# 参考:
# https://blog.csdn.net/lustyoung/article/details/105027607
# https://www.jianshu.com/p/5704e67f40aa
import numpy as np
import queue
from queue import PriorityQueue
from pylab import *
import matplotlib.pyplot as plt


class A_star:
    def __init__(self, a, b, start, end, max_row, max_col, barrier_map):
        """
        A_star 初始化
        :param a:
        :param b:
        :param start:
        :param end:
        :param max_row:
        :param max_col:
        :param barrier_map:
        """

        # 权重参数
        self.a = a
        self.b = b
        # 起始点、终点
        self.start = start
        self.end = end
        # 相遇点
        self.stop_point = None
        self.stop_point_back = None
        #采用优先队列实现的小顶堆,用于存放待扩展结点,同时利用f值作为排序指标;
        self.opened_list1 = PriorityQueue()
        self.opened_list2 = PriorityQueue()
        self.open_all_list = []
        #储存记录所有走过的openlist
        self.open_list = set()
        #采用set(红黑树)实现,便于快速查找当前point是否存在于closed_list中;
        self.closed_list1 = set()
        self.closed_list2 = set()
        # 地图边界
        self.max_row = max_row
        self.max_col = max_col
        # 论文中设定为四个移动方向
        self.direction = [[0, 1], [0, -1], [1, 0], [-1, 0]]
        # 障碍地图
        self.barrier_map = barrier_map
        #遍历结点数
        self.travel_num = 0



    def bound(self, row, col):
        """
        边界条件检测:
        1.如果该点坐标超出地图区域,则非法
        2.如果该点是障碍,则非法
        :param row:
        :param col:
        :return:
        """
        if row < 0 or row > self.max_row:
            return True
        if col < 0 or col > self.max_col:
            return True
        if [row, col] in self.barrier_map:
            return True
        return False

    def euclidean_distance(self, point1, point2):
        """
        欧式距离计算
        :param point1:
        :param point2:
        :return:
        """
        x_ = abs(point1.getx() - point2.getx())
        y_ = abs(point1.gety() - point2.gety())
        return ((x_**2)+(y_**2))**(0.5)*1.0


    def find_path(self, point):
        """
        用栈回溯查找B-A*给出的路径
        :param point:
        :return:
        """
        stack = []
        father_point = point
        while father_point is not None:
            stack.append(father_point.get_coordinate())
            father_point = father_point.father
        return stack

    def heuristic_h(self, end_point, point1, point2):
        """
        启发函数h值计算
        h:正向搜索当前点n1到终点的欧式距离
        h_:正向搜索当前点n1到反向搜索当前点n2的距离
        :param end_point:
        :param point1:
        :param point2:
        :return:
        """
        h = self.euclidean_distance(point1, end_point)
        h_ = self.euclidean_distance(point1, point2)
        return (1 - self.b) * (h*1.0 / (1 + self.a) + self.a * h_*1.0 / (1 + self.a))

    def heuristic_g(self, point):
        """
        启发函数g值计算
        g:当前点父亲的g值
        g_:当前点到其父亲的欧式距离
        :param point:
        :return:
        """
        g = point.father.g
        g_ = self.euclidean_distance(point, point.father)
        return self.b * (g + g_)*1.0

    def heuristic_f(self, h, g):
        # 计算启发式函数f值
        return h + g


    def compute_child_node(self, current_point, back_current_point, search_forward):
        """
        遍历当前点的子节点
        :param current_point: 当前点
        :param back_current_point: 相反方向搜索的当前点
        :param search_forward: 搜索方向设置,True为正向,反之亦然
        :return:
        """
        # 四个方向的遍历
        for direct in self.direction:
            col_index = current_point.gety() + direct[0]
            row_index = current_point.getx() + direct[1]
            # 创建子节点,将当前点设置为子节点的父节点
            child_point = Point(row_index, col_index, current_point)
            # 查找open_list中是否存在子节点中,用于备份节点
            old_child_point = None
            #前向搜索
            if search_forward == True:

                #计算子节点各个启发函数值
                child_point.h = self.heuristic_h(self.end, current_point, back_current_point)
                child_point.g = self.heuristic_g(child_point)
                child_point.f = self.heuristic_f(child_point.h, child_point.g)

                # 边界检测, 如果它不可通过或者已经在关闭列表中, 跳过
                if (row_index, col_index) in self.closed_list1 or self.bound(row_index, col_index)==True:
                    continue
                else:
                    self.travel_num = self.travel_num + 1
                    # 通过检测则将子节点加入openlist(记录所有加入过openlist的点)
                    self.open_list.add(child_point.get_coordinate())
                    # 找到最短路径
                    if (row_index, col_index) in self.closed_list2:
                        self.stop_point_back = self.search_back_point(child_point, self.open_all_list)

                        self.stop_point = child_point
                        if self.stop_point_back is None:
                            self.stop_point_back = self.stop_point
                        print("forward!")
                        return True
                    """
                    如果可到达的节点存在于OPEN1列表中,称该节点为x点,计算经过n1点到达x点的g1(x)值,
                    如果该g1(x)值小于原g1(x)值,则将n1点作为x点的父节点,更新OPEN1列表中的f1(x)和g1(x)。
                    """
                    tmp_queue = [] # opened_list的备份
                    while not self.opened_list1.empty():
                        tmp_point = self.opened_list1.get()
                        if child_point.get_coordinate() == tmp_point.get_coordinate(): #找到x点,跳过
                            old_child_point = tmp_point
                            continue
                        tmp_queue.append(tmp_point)
                    while len(tmp_queue) != 0:
                        self.opened_list1.put(tmp_queue.pop())
                    if old_child_point is None: #如果没找到,直接加入子节点
                        self.opened_list1.put(child_point)
                    else:
                        # 找到x点
                        # 用g值为参考检查新的路径是否更好。更低的g值意味着更好的路径。
                        if old_child_point.g > child_point.g:
                            self.opened_list1.put(child_point)
                        else:
                            # print(2)
                            self.opened_list1.put(old_child_point)
                            # print(2)

            # 反向搜索,同理如上
            else:
                # 边界检测, 如果它不可通过或者已经在关闭列表中, 跳过
                child_point.h = self.heuristic_h(self.start, current_point, back_current_point)
                child_point.g = self.heuristic_g(child_point)
                child_point.f = self.heuristic_f(child_point.h, child_point.g)
                if (row_index, col_index) in self.closed_list2 or self.bound(row_index, col_index):
                    continue
                else:
                    self.travel_num = self.travel_num + 1
                    self.open_list.add(child_point.get_coordinate())
                    # 找到最短路径
                    if (row_index, col_index) in self.closed_list1:
                        # self.stop_point_back = self.search_father_point(child_point,self.opened_list2)
                        self.stop_point_back = self.search_back_point(child_point, self.open_all_list)

                        self.stop_point = child_point
                        if self.stop_point_back is None:
                            self.stop_point_back = self.stop_point
                        print("backward!")
                        return True
                    tmp_queue = []
                    while not self.opened_list2.empty():
                        tmp_point = self.opened_list2.get()
                        if child_point.get_coordinate() == tmp_point.get_coordinate():
                            old_child_point = tmp_point
                            continue
                        tmp_queue.append(tmp_point)
                    while len(tmp_queue) != 0:
                        self.opened_list2.put(tmp_queue.pop())
                    if old_child_point is None: #open_list没有找到子节点,则将
                        self.opened_list2.put(child_point)
                    else:
                        # 用g值为参考检查新的路径是否更好。更低的G值意味着更好的路径。
                        if old_child_point.g > child_point.g:
                            self.opened_list2.put(child_point)
                        else:
                            self.opened_list2.put(old_child_point)
                            # print(3)
        return False

    def search_back_point(self, point, opened_list):
        back_point = None
        while len(opened_list)!=0:
            tmp_point = opened_list.pop()
            if point.get_coordinate() == tmp_point.get_coordinate():
                return tmp_point
        return back_point




    def search(self):
        # 将起始点s设置为正向当前结点n1、终点e设置为反向当前结点n2
        current_point1 = self.start
        current_point2 = self.end
        # 并加入open1、open2
        self.opened_list1.put(current_point1)
        self.opened_list2.put(current_point2)
        forward_path, backward_path =None, None
        # opened_list1与opened_list2全部非空,输出寻路提示失败
        find_stop = False
        min_f_point1 = self.opened_list1.get()
        self.closed_list1.add((min_f_point1.getx(), min_f_point1.gety()))
        min_f_point2 = self.opened_list2.get()
        self.closed_list2.add((min_f_point2.getx(), min_f_point2.gety()))



        while True:
            # # 取出open_list1中f值最小的点,加入closed_list1
            # 将其作为当前结点,遍历寻找它的子节点
            # min_f_point1 = self.opened_list1.get()
            # self.closed_list1.add((min_f_point1.getx(), min_f_point1.gety()))
            self.open_all_list.append(current_point1)
            self.open_all_list.append(current_point2)
            find_stop = self.compute_child_node(current_point1, current_point2, True)
            if find_stop:
                # forward_path = self.find_path(current_point1)
                forward_path = self.find_path(self.stop_point)
                backward_path = self.find_path(self.stop_point_back)
                break
            # min_f_point1 = self.opened_list1.get()
            # print(1)
            min_f_point1 = self.opened_list1.get()
            # self.open_all_list.append(min_f_point1)
            self.closed_list1.add((min_f_point1.getx(), min_f_point1.gety()))
            current_point1 = min_f_point1
            self.open_all_list.append(current_point1)


            # 取出open_list1中f值最小的点,加入closed_list1
            # current_point1 = self.opened_list1.get()

            find_stop = self.compute_child_node(current_point2, current_point1, False)
            if find_stop:
                forward_path = self.find_path(self.stop_point)
                backward_path = self.find_path(self.stop_point_back)
                # backward_path = self.find_path(self.stop_point)

                break
            # min_f_point2 = self.opened_list2.get()
            min_f_point2 = self.opened_list2.get()
            # self.open_all_list.append(min_f_point1)
            self.closed_list2.add((min_f_point2.getx(), min_f_point2.gety()))
            current_point2 = min_f_point2
            if self.opened_list1.qsize() == 0 or self.opened_list2.qsize() == 0:
                break


            # current_point2 = self.opened_list2.get()
        if backward_path==None and forward_path==None:
            print("Fail to find the path!")
            return None
        else:
            forward_path = forward_path + backward_path
        return forward_path, self.open_list, self.stop_point, self.travel_num

class Point:
    """
    Point——地图上的格子,或者理解为点
    1.坐标
    2.g,h,f,father
    """
    def __init__(self, x, y, father):
        self.x = x
        self.y = y
        self.g = 0
        self.h = 0
        self.f = 0
        self.father = father

    # 用于优先队列中f值的排序
    def __lt__(self, other):
        return self.f < other.f

    # 获取x坐标
    def getx(self):
        return self.x

    # 获取y坐标
    def gety(self):
        return self.y

    # 获取x, y坐标
    def get_coordinate(self):
        return self.x, self.y


def draw_path(max_row, max_col, barrier_map, path=None, openlist=None, stop_point=None, start=None ,end=None):
    """
    画出B-A*算法的结果模拟地图
    :param max_row:地图x最大距离
    :param max_col:地图y最大距离
    :param barrier_map:障碍地图
    :param path:B-A*给出的较优路径
    :param openlist:B-A*中的openlist
    :param stop_point:B-A*中的相遇结点
    :param start:起始点
    :param end:终点
    :return:
    """
    """划分数组的x,y坐标"""
    barrier_map_x = [i[0] for i in barrier_map]
    barrier_map_y = [i[1] for i in barrier_map]
    path_x = [i[0] for i in path]
    path_y = [i[1] for i in path]
    open_list_x = [i[0] for i in openlist]
    open_list_y = [i[1] for i in openlist]
    """对画布进行属性设置"""

    # plt.subplot(2, 2, subplot)
    plt.figure(figsize=(15, 15))  # 为了防止x,y轴间隔不一样长,影响最后的表现效果,所以手动设定等长
    plt.xlim(-1, max_row)
    plt.ylim(-1, max_col)
    my_x_ticks = np.arange(0, max_row, 1)
    my_y_ticks = np.arange(0, max_col, 1)
    plt.xticks(my_x_ticks)  # 竖线的位置与间隔
    plt.yticks(my_y_ticks)
    plt.grid(True)  # 开启栅格
    """画图"""
    plt.scatter(barrier_map_x, barrier_map_y, s=500, c='k', marker='s')
    plt.scatter(open_list_x, open_list_y, s=500, c='cyan', marker='s')
    plt.scatter(path_x, path_y,s=500, c='r', marker='s')
    plt.scatter(stop_point.getx(), stop_point.gety(), s=500, c='g', marker='s')
    plt.scatter(start.getx(),start.gety(),s=500, c='b', marker='s')
    plt.scatter(end.getx(), end.gety(), s=500, c='b', marker='s')
    plt.title("Bidirectiional A Star , a = {}, b = {}".format(a, b))
    print(path)
    plt.savefig("result_pic/a_{},b_{}.png".format(a, b))
    plt.show()


def draw_arg(travel_num, path_num, a_list, b_list):
    markes = ['-o', '-s', '-^', '-p', '-^','-v']

    travel_num = np.array(travel_num)
    path_num = np.array(path_num)
    fig,ax = plt.subplots(2, 1, figsize=(15,12))
    for i in range(0, 6):
        ax[0].plot(a_list, travel_num[i*6:(i+1)*6], markes[i],label=b_list[i])
        ax[0].set_ylabel('Travel num')
        ax[1].plot(a_list, path_num[i*6:(i+1)*6], markes[i],label=b_list[i])
        ax[1].set_xlabel('The values of a')
        ax[1].set_ylabel('Path nun')
    ax[0].set_title('the two args of B-A*')
    box = ax[1].get_position()
    ax[1].set_position([box.x0, box.y0, box.width, box.height])
    ax[1].legend(loc='right', bbox_to_anchor=(1.1, 0.6), ncol=1, title="b")
    plt.savefig("ab.png")
    plt.show()


if __name__ == '__main__':

    """权重的值"""
    a_list = [2, 3, 4, 5, 6, 7]
    b_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
    """地图、障碍物的位置、起始点、终点"""
    max_row = 30
    max_col = 30
    # barrier_map = [[1, 2], [3, 5], [4,7],[6, 6], [11, 19], [11, 6],[11, 5],[12,5],[12,6],[11, 16], [11, 17], [11, 18], [7, 7]]
    barrier_map = [[6, 29], [7, 29], [8, 29], [6, 28], [7, 28], [8, 28], [6, 27], [7, 27], [8, 27], [6, 26], [7, 26],
                   [8, 26], [6, 25], [7, 25], [8, 25], [6, 24], [7, 24], [8, 24], [29, 25], [28, 25], [27, 25],
                   [26, 25], [25, 25], [24, 25], [23, 25], [22, 25], [21, 25], [29, 24], [28, 24], [27, 24], [26, 24],
                   [24, 24], [25, 24],
                   [23, 24], [22, 24], [21, 24], [20, 10], [21, 10], [22, 10], [23, 10], [24, 10], [20, 9], [21, 9],
                   [22, 9], [23, 9], [24, 9], [20, 8], [21, 8], [22, 8], [23, 8], [24, 8], [20, 7], [21, 7], [22, 7],
                   [23, 7], [24, 7], [20, 6], [21, 6], [22, 6], [23, 6], [24, 6], [20, 5], [21, 5], [22, 5], [23, 5],
                   [24, 5], [20, 4], [21, 4], [22, 4], [23, 4], [24, 4], [20, 3], [21, 3], [22, 3], [23, 3], [24, 3],
                   [20, 2], [21, 2], [22, 2], [23, 2], [24, 2], [20, 1], [21, 1], [22, 1], [23, 1], [24, 1], [20, 0],
                   [21, 0], [22, 0], [23, 0], [24, 0],
                   [16, 16], [16, 17], [16, 18], [17, 16], [17, 17], [17, 18], [18, 16], [18, 17], [18, 18], [19, 16],
                   [19, 17], [19, 18], [20, 16], [20, 17], [20, 18], [21, 16], [21, 17], [21, 18], [22, 16], [22, 17],
                   [22, 18], [23, 16], [23, 17], [23, 18], [24, 16], [24, 17], [24, 18], [25, 16], [25, 17], [25, 18],
                   [26, 16], [26, 17], [26, 18], [27, 16], [27, 17], [27, 18], [28, 16], [28, 17], [28, 18], [29, 16],
                   [29, 17], [29, 18]]
    # barrier_map = np.array(barrier_map)
    # start = Point(4, 5, None)
    # end = Point(18, 8, None)
    start = Point(27, 2, None)
    end = Point(0, 29, None)
    travel_point_num = []
    path_point_num = []

    """遍历权重的值,使用B-A*算法"""
    # for a, b in zip(a_list, b_list):
    for b in b_list:
        for a in a_list:
            print(a,b)
            path, open_list, stop_point, tp_num = A_star(a, b, start, end, max_row, max_col,barrier_map).search()
            if path is not None:
                # 找到路径
                print("Find the path!")
                travel_point_num.append(tp_num)
                path_point_num.append(len(path))
                draw_path(max_row, max_col, barrier_map, path, open_list, stop_point, start, end)
            else:
                print("Fail to find the path!")

    print(travel_point_num)
    print(path_point_num)
    draw_arg(travel_point_num, path_point_num, a_list, b_list)



文章来源:https://blog.csdn.net/ljjjjjjjjjjj/article/details/135133272
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。