- 必须记住下一步还可以走哪些点——OPEN表(记录还没有扩展的点)
- 必须记住哪些点走过了——CLOSED表(记录已经扩展的点)
广度优先搜索
在应用BFS算法进行八数码问题搜索时需要open和closed两个表。首先将初始状态加入open队列,然后进行出队操作并放入closed中,对出队的状态进行扩展(所谓扩展也就是找出其上下左右移动后的状态),将扩展出的状态加入队列,然后继续循环出队-扩展-入队的操作,直到找到解为止。
import copy
class grid:
    """One 3x3 eight-puzzle board together with its search bookkeeping.

    Attributes:
        stat:   the board as a list of three rows; 0 marks the blank.
        target: the goal board every instance is compared against.
        pre:    the node this one was generated from, or None for the root.
        zero:   [row, col] of the blank in ``stat``.
        G:      number of moves from the root (0 when ``pre`` is None).
        H:      sum of Manhattan distances between every value in ``target``
                (the blank included) and its position in ``stat``; it is 0
                exactly when ``stat`` equals ``target``.
    """

    def __init__(self, stat):
        self.pre = None
        self.target = [[1, 2, 3], [8, 0, 4], [7, 6, 5]]
        self.stat = stat
        self.find0()
        self.update()

    def update(self):
        """Recompute H and G; call again after ``pre`` has been assigned."""
        self.fH()
        self.fG()

    def fG(self):
        """Set G to the parent's depth plus one, or 0 for a root node."""
        self.G = 0 if self.pre is None else self.pre.G + 1

    def fH(self):
        """Set H to the total Manhattan distance to the target layout."""
        self.H = 0
        for goal_row, wanted in enumerate(self.target):
            for goal_col, tile in enumerate(wanted):
                row, col = self.findx(tile)
                self.H += abs(row - goal_row) + abs(col - goal_col)

    def see(self):
        """Print this node's depth and its three rows, then a separator."""
        print("depth:", self.G)
        for row_idx in range(3):
            print(self.stat[row_idx])
        print("-" * 10)

    def seeAns(self):
        """Print every node on the path from the root down to this node."""
        chain = []
        node = self
        while node:
            chain.append(node)
            node = node.pre
        for step in reversed(chain):
            step.see()

    def findx(self, x):
        """Return [row, col] of value ``x`` in the board, or None if absent."""
        for row_idx in range(3):
            row = self.stat[row_idx]
            if x in row:
                return [row_idx, row.index(x)]
        return None

    def find0(self):
        """Cache the blank's position in ``self.zero``."""
        self.zero = self.findx(0)

    def expand(self):
        """Return the boards reachable in one move.

        The result is a list of raw boards (not ``grid`` objects), always in
        the order left, up, down, right, skipping moves that would leave the
        board.
        """
        row, col = self.zero
        candidates = (
            (col in (1, 2), self.left),
            (row in (1, 2), self.up),
            (row in (0, 1), self.down),
            (col in (0, 1), self.right),
        )
        return [slide() for allowed, slide in candidates if allowed]

    def move(self, row, col):
        """Return a copy of the board with the blank shifted by (row, col)."""
        blank_row, blank_col = self.zero
        dest_row, dest_col = blank_row + row, blank_col + col
        board = copy.deepcopy(self.stat)
        board[blank_row][blank_col] = self.stat[dest_row][dest_col]
        board[dest_row][dest_col] = 0
        return board

    def up(self):
        return self.move(-1, 0)

    def down(self):
        return self.move(1, 0)

    def left(self):
        return self.move(0, -1)

    def right(self):
        return self.move(0, 1)
def N(nums):
    """Count the inversions among the non-blank tiles of a board.

    An inversion is a pair of positions (j, i) with j < i where the value at
    j is greater than the value at i; positions holding the blank (0) are
    never used as the right-hand element.

    Args:
        nums: the board, either as a flat sequence of tiles or as a sequence
            of rows (lists/tuples). Rows are flattened in row-major order
            first, so inversions are counted between tiles rather than
            between whole rows.

    Returns:
        The number of inversions as an int.
    """
    # Flatten one level so a 3x3 nested board is handled tile by tile.
    tiles = []
    for item in nums:
        if isinstance(item, (list, tuple)):
            tiles.extend(item)
        else:
            tiles.append(item)

    count = 0
    for pos, tile in enumerate(tiles):
        if tile == 0:
            continue
        count += sum(1 for earlier in tiles[:pos] if earlier > tile)
    return count
def judge(src, target):
    """Return True when ``src`` and ``target`` have the same inversion parity.

    Both boards are passed to ``N``; the result is True if the two counts are
    both even or both odd, and False otherwise.
    """
    src_parity = N(src) % 2
    target_parity = N(target) % 2
    return src_parity == target_parity
import sys
from collections import deque

# Breadth-first search driver for the eight-puzzle.
# 0 marks the blank square of the starting board.
startStat = [[2, 8, 3], [1, 0, 4], [7, 6, 5]]
g = grid(startStat)
if not judge(startStat, g.target):
    print("所给八数码无解,请检查输入")
    sys.exit(1)

# Every board that has ever been put on the queue, stored as a tuple of
# tuples so membership tests are O(1). Marking boards when they are enqueued
# (rather than when they are dequeued) keeps duplicates out of the queue.
visited = {tuple(map(tuple, g.stat))}
queue = deque([g])
time = 0  # number of nodes taken off the queue so far
while queue:
    time += 1
    v = queue.popleft()
    if v.H == 0:
        # H is zero only when the board equals the target.
        print("found and times:", time, "moves:", v.G)
        v.seeAns()
        break
    for stat in v.expand():
        key = tuple(map(tuple, stat))
        if key in visited:
            continue
        visited.add(key)
        tmpG = grid(stat)
        tmpG.pre = v
        tmpG.update()  # refresh G now that the parent link is set
        queue.append(tmpG)
else:
    # The queue ran dry without reaching the target.
    print("所给八数码无解,请检查输入")
深度优先搜索
'''
深度优先搜索实现
'''
import copy
class grid:
    """A search node for the eight-puzzle: a 3x3 board plus path information.

    ``stat`` holds the board (0 is the blank), ``target`` the goal board,
    ``pre`` the parent node (None for the root), ``zero`` the blank's
    [row, col], ``G`` the depth from the root and ``H`` the summed Manhattan
    distance of every target value (blank included) from its goal cell.
    """

    def __init__(self, stat):
        self.pre = None
        self.target = [[1, 2, 3], [8, 0, 4], [7, 6, 5]]
        self.stat = stat
        self.find0()
        self.update()

    def update(self):
        """Refresh H and then G; needed again once ``pre`` is assigned."""
        self.fH()
        self.fG()

    def fG(self):
        """Depth is one more than the parent's, or 0 without a parent."""
        if self.pre is None:
            self.G = 0
        else:
            self.G = self.pre.G + 1

    def fH(self):
        """Accumulate the Manhattan distance of each target value into H."""
        self.H = 0
        for goal_r in range(3):
            for goal_c in range(3):
                here = self.findx(self.target[goal_r][goal_c])
                self.H += abs(here[0] - goal_r) + abs(here[1] - goal_c)

    def see(self):
        """Print the depth, the three board rows and a dashed separator."""
        print("depth:", self.G)
        for row_idx in range(3):
            print(self.stat[row_idx])
        print("-" * 10)

    def seeAns(self):
        """Print the chain of nodes from the root to this node, root first."""
        path = [self]
        ancestor = self.pre
        while ancestor:
            path.append(ancestor)
            ancestor = ancestor.pre
        for node in path[::-1]:
            node.see()

    def findx(self, x):
        """Locate ``x`` on the board; returns [row, col] or None."""
        for row_idx in range(3):
            row = self.stat[row_idx]
            if x in row:
                return [row_idx, row.index(x)]
        return None

    def find0(self):
        """Remember where the blank currently is."""
        self.zero = self.findx(0)

    def expand(self):
        """List the successor boards in the order left, up, down, right.

        Only raw boards are returned; the caller wraps them in ``grid``.
        """
        blank_r, blank_c = self.zero
        successors = []
        if blank_c in (1, 2):
            successors.append(self.left())
        if blank_r in (1, 2):
            successors.append(self.up())
        if blank_r in (0, 1):
            successors.append(self.down())
        if blank_c in (0, 1):
            successors.append(self.right())
        return successors

    def move(self, row, col):
        """Return a new board where the blank has moved by (row, col)."""
        src_r, src_c = self.zero[0], self.zero[1]
        dst_r, dst_c = src_r + row, src_c + col
        moved_tile = self.stat[dst_r][dst_c]
        board = copy.deepcopy(self.stat)
        board[src_r][src_c] = moved_tile
        board[dst_r][dst_c] = 0
        return board

    def up(self):
        return self.move(-1, 0)

    def down(self):
        return self.move(1, 0)

    def left(self):
        return self.move(0, -1)

    def right(self):
        return self.move(0, 1)
def isin(g, gList):
    """Look for a node in ``gList`` whose board equals ``g.stat``.

    Returns:
        ``[True, index]`` with the position of the first matching node, or
        ``[False, 0]`` when no node in ``gList`` has the same board.
    """
    wanted = g.stat
    boards = [node.stat for node in gList]
    try:
        return [True, boards.index(wanted)]
    except ValueError:
        return [False, 0]
def N(nums):
    """Count the inversions among the non-blank tiles of a board.

    An inversion is a pair of positions (j, i) with j < i where the value at
    j is greater than the value at i; positions holding the blank (0) are
    never used as the right-hand element.

    Args:
        nums: the board, either as a flat sequence of tiles or as a sequence
            of rows (lists/tuples). Rows are flattened in row-major order
            first, so inversions are counted between tiles rather than
            between whole rows.

    Returns:
        The number of inversions as an int.
    """
    # Flatten one level so a 3x3 nested board is handled tile by tile.
    tiles = []
    for item in nums:
        if isinstance(item, (list, tuple)):
            tiles.extend(item)
        else:
            tiles.append(item)

    count = 0
    for pos, tile in enumerate(tiles):
        if tile == 0:
            continue
        count += sum(1 for earlier in tiles[:pos] if earlier > tile)
    return count
def judge(src, target):
    """Tell whether ``src`` and ``target`` share the same inversion parity.

    ``N`` is evaluated for each board; True is returned when both counts are
    even or both are odd, False otherwise.
    """
    return N(src) % 2 == N(target) % 2
# Starting board for the depth-limited DFS run; 0 marks the blank square.
startStat = [
    [2, 8, 3],
    [1, 0, 4],
    [7, 6, 5],
]
g = grid(startStat)
# Stop right away when the parity check reports the target as unreachable.
if not judge(startStat, g.target):
    print("所给八数码无解,请检查输入")
    exit(1)
# `visited` holds the boards on the current search path; `time` is the node
# counter that DFSUtil increments through `global time`.
visited, time = [], 0
def DFSUtil(v, visited, max_depth=4):
    """Depth-limited depth-first search from node ``v``.

    On reaching a node whose ``H`` is 0 the function prints the statistics
    and the path, then terminates the program with exit status 0.

    Args:
        v: the node to explore; must provide ``G``, ``H``, ``stat``,
            ``expand()`` and ``seeAns()``.
        visited: list of boards on the current path. ``v.stat`` is pushed
            while ``v`` is being explored and popped again before returning,
            so the list is left as it was found.
        max_depth: nodes with ``G`` greater than this value are not explored.
            Defaults to 4, the limit that used to be hard-coded.

    Returns:
        None when the subtree under ``v`` holds no solution within the depth
        limit.

    Side effects:
        Increments the module-level counter ``time`` once per explored node.
    """
    global time
    import sys  # local import: the file's top-level imports are not touched

    if v.G > max_depth:
        return
    time += 1
    if v.H == 0:
        print("found and times", time, "moves:", v.G)
        v.seeAns()
        # A solution was found, so report success to the caller's shell.
        sys.exit(0)
    visited.append(v.stat)
    for stat in v.expand():
        # Skip boards already on the current path before building a node.
        if stat in visited:
            continue
        child = grid(stat)
        child.pre = v
        child.update()  # refresh G now that the parent link is set
        DFSUtil(child, visited, max_depth)
    visited.pop()
# Run the depth-limited search from the start node. DFSUtil exits the program
# itself when it finds the goal, so the message below is only reached when
# the call returns without a solution.
DFSUtil(g, visited)
print("在当前深度下没有找到解,请尝试增加搜索深度")
3. 启发式搜索
特点:重排OPEN表,选择最有希望的节点加以扩展。
引入估价函数(evaluation function)来估计节点位于解路径上的“希望”,函数值越小“希望”越大。搜索过程中按照估价函数的大小对OPEN表排序,每次选择估价函数值最小的节点作为下一步考察的节点。
3.1 有序搜索
选择OPEN表上具有最小f 值的节点作为下一个要扩展的节点。
对应到八数码问题:
$f(x) = g(x) + h(x)$
- g (x):从初始状态到x需要进行的移动操作的次数
- h (x):所有棋子与目标位置的曼哈顿距离之和
- 曼哈顿距离:两点之间水平距离和垂直距离之和 仍满足估价函数的限制条件
'''
有序搜索
'''
import copy
import numpy as np
from datetime import datetime
def string_to_ls(str):
    """Parse a board string into a list of rows of tile strings.

    Rows are separated by ',' and tiles within a row by a single space, e.g.
    ``'4 0 1,6 8 5'`` becomes ``[['4', '0', '1'], ['6', '8', '5']]``.
    """
    rows = []
    for row_text in str.split(','):
        rows.append(row_text.split(' '))
    return rows
def get_loacl(arr, target):
    """Return ``(row, col)`` of the first cell equal to ``target``.

    The board is scanned row by row, left to right. None is returned when no
    cell matches.
    """
    for row_idx, row in enumerate(arr):
        for col_idx, cell in enumerate(row):
            if cell == target:
                return row_idx, col_idx
    return None
def get_elements(arr):
    """List the tiles adjacent to the blank ``'0'`` on a 3x3 board.

    Neighbours are reported in the order above, below, left, right; a
    direction is skipped when the blank sits on that edge of the board.
    """
    r, c = get_loacl(arr, '0')
    # (is this direction on the board?, (row, col) of the neighbour)
    directions = (
        (r > 0, (r - 1, c)),
        (r < 2, (r + 1, c)),
        (c > 0, (r, c - 1)),
        (c < 2, (r, c + 1)),
    )
    neighbours = []
    for on_board, (nr, nc) in directions:
        if on_board:
            neighbours.append(arr[nr][nc])
    return neighbours
def get_child(arr, e):
    """Return a deep copy of ``arr`` with tile ``e`` and the blank swapped.

    ``arr`` itself is left unmodified. Positions are looked up with
    ``get_loacl`` on the copy.
    """
    board = copy.deepcopy(arr)
    blank_r, blank_c = get_loacl(board, '0')
    tile_r, tile_c = get_loacl(board, e)
    moved = board[tile_r][tile_c]
    board[tile_r][tile_c] = board[blank_r][blank_c]
    board[blank_r][blank_c] = moved
    return board
def get_distance(arr1, arr2):
    """Sum the Manhattan distances between matching tiles of two boards.

    Every value found in ``arr1`` (the blank included) is located in both
    boards with ``get_loacl``; the row and column offsets are added up.
    """
    total = 0
    for row in arr1:
        for tile in row:
            r1, c1 = get_loacl(arr1, tile)
            r2, c2 = get_loacl(arr2, tile)
            total += abs(r1 - r2) + abs(c1 - c2)
    return total
def is_goal(arr, goal):
    """Return the result of comparing board ``arr`` with ``goal`` using ``==``."""
    reached = (arr == goal)
    return reached
class state:
    """A node of the ordered search.

    Attributes:
        state:    the board held by this node.
        deep:     number of moves made from the initial board.
        parent:   the node this one was generated from (None for the root).
        distance: the evaluation value used to order the open list.
    """

    def __init__(self, state, deep, parent, distance):
        self.state = state
        self.deep = deep
        self.parent = parent
        self.distance = distance

    def chidren(self):
        """Return the successor nodes of this node.

        One child is produced per tile returned by ``get_elements``. Each
        child's ``distance`` is its own depth plus the ``get_distance`` value
        of the child's own board against the module-level ``goal_arr``, so
        sibling nodes are ranked by their own boards rather than all
        inheriting the parent's estimate.
        """
        successors = []
        child_depth = self.deep + 1
        for tile in get_elements(self.state):
            board = get_child(self.state, tile)
            successors.append(
                state(
                    state=board,
                    deep=child_depth,
                    parent=self,
                    distance=child_depth + get_distance(board, goal_arr),
                )
            )
        return successors
def print_path(n):
    """Print the ancestors of ``n``, nearest first, each preceded by '↑'.

    The board of ``n`` itself is not printed; every ancestor board is shown
    as a NumPy array. Nothing is printed when ``n`` has no parent.
    """
    node = n
    while node.parent is not None:
        print('↑')
        print(np.array(node.parent.state))
        node = node.parent
if __name__ == '__main__':
    # Boards are written row by row: tiles are separated by spaces, rows by
    # commas, and '0' is the blank.
    initial = '4 0 1,6 8 5,7 3 2'
    goal = '5 8 2,1 0 4,6 3 7'
    initial_arr = string_to_ls(initial)
    goal_arr = string_to_ls(goal)  # also read as a global by state.chidren()
    root = state(initial_arr, deep=0, parent=None,
                 distance=get_distance(initial_arr, goal_arr))
    start = datetime.now()
    # Named open_list/closed_list so the builtin open() is not shadowed.
    open_list = [root]
    closed_list = []
    limit = 19  # nodes at this depth are not expanded any further
    while open_list:
        # Snapshot of every board already queued or expanded, taken before
        # the current node is removed from the open list.
        open_tb = [node.state for node in open_list]
        close_tb = [node.state for node in closed_list]
        n = open_list.pop(0)
        closed_list.append(n)
        if is_goal(n.state, goal_arr):
            print(np.array(n.state))
            print_path(n)
            print('--' * 20)
            print('成功搜索到路径,求解过程如上')
            break
        if n.deep < limit:
            for child in n.chidren():
                if child.state not in open_tb and child.state not in close_tb:
                    open_list.insert(0, child)
            # One stable sort after all insertions yields the same order as
            # re-sorting after each one: new nodes were placed at the front,
            # so they still precede older nodes with an equal value.
            open_list.sort(key=lambda node: node.distance)
    else:
        # NOTE(review): the original indentation was lost; this `else` is
        # taken to belong to the `while`, i.e. it runs only when the open
        # list empties without the goal being found — confirm.
        print('该深度下无解')
    end = datetime.now()
    print('--' * 20)
    print('限制深度为:{}\t搜寻深度为:{}\n启发式A算法搜索步数为:{}'.format(limit, closed_list[-1].deep, len(closed_list) - 2))
    print('--' * 20)
    print('搜索耗时:', end - start)
3.2 A*算法
估价函数的定义 对节点n定义
$f^*(n) = g^*(n) + h^*(n)$,表示从S开始约束通过节点n的一条最佳路径的代价。希望估价函数 $f$ 是 $f^*$ 的一个估计,定义为:$f(n) = g(n) + h(n)$,其中 $g$ 是 $g^*$ 的估计,$h$ 是 $h^*$ 的估计。
A算法的定义 定义1 在GRAPHSEARCH过程中,如果第8步的重排OPEN表是依据 $f(x)=g(x)+h(x)$ 进行的,则称该过程为A算法。 定义2 在A算法中,如果对所有的x存在 $h(x) \le h^*(x)$,则称 $h(x)$ 为 $h^*(x)$ 的下界,它表示某种偏于保守的估计。 定义3 采用 $h^*(x)$ 的下界 $h(x)$ 为启发函数的A算法,称为A*算法。当 $h=0$ 时,A*算法就变为有序搜索算法。
import copy
class grid:
    """A* search node for the eight-puzzle.

    Attributes:
        stat:   the board as a list of three rows; 0 marks the blank.
        target: the goal board.
        pre:    the parent node, or None for the root.
        zero:   [row, col] of the blank in ``stat``.
        G:      number of moves from the root (0 when ``pre`` is None).
        H:      sum of the Manhattan distances of the eight numbered tiles
                from their goal cells.
        F:      ``G + H``, the value the open list is ordered by.
    """

    def __init__(self, stat):
        self.pre = None
        self.target = [[1, 2, 3], [8, 0, 4], [7, 6, 5]]
        self.stat = stat
        self.find0()
        self.update()

    def update(self):
        """Recompute H, G and F; call again after ``pre`` has been assigned."""
        self.fH()
        self.fG()
        self.fF()

    def fG(self):
        """Set G to the parent's depth plus one, or 0 for a root node."""
        self.G = 0 if self.pre is None else self.pre.G + 1

    def fH(self):
        """Set H to the Manhattan distance summed over the numbered tiles.

        The blank is left out on purpose: a single move shifts one tile and
        the blank by one cell each, so counting both would report 2 for a
        board that is one move from the goal. Without the blank H never
        exceeds the number of moves still required, and it is still 0 only
        for the goal board (eight tiles in place force the blank in place).
        """
        total = 0
        for goal_row, wanted in enumerate(self.target):
            for goal_col, tile in enumerate(wanted):
                if tile == 0:
                    continue
                row, col = self.findx(tile)
                total += abs(row - goal_row) + abs(col - goal_col)
        self.H = total

    def fF(self):
        """Set F to G + H."""
        self.F = self.G + self.H

    def see(self):
        """Print the three board rows, the F/G/H values and a separator."""
        for row_idx in range(3):
            print(self.stat[row_idx])
        print("F=", self.F, "G=", self.G, "H=", self.H)
        print("-" * 10)

    def seeAns(self):
        """Print every node on the path from the root down to this node."""
        chain = []
        node = self
        while node:
            chain.append(node)
            node = node.pre
        for step in reversed(chain):
            step.see()

    def findx(self, x):
        """Return [row, col] of value ``x`` in the board, or None if absent."""
        for row_idx in range(3):
            row = self.stat[row_idx]
            if x in row:
                return [row_idx, row.index(x)]
        return None

    def find0(self):
        """Cache the blank's position in ``self.zero``."""
        self.zero = self.findx(0)

    def expand(self):
        """Return the boards reachable in one move.

        The result is a list of raw boards (not ``grid`` objects), always in
        the order left, up, down, right, skipping moves that would leave the
        board.
        """
        row, col = self.zero
        successors = []
        if col in (1, 2):
            successors.append(self.left())
        if row in (1, 2):
            successors.append(self.up())
        if row in (0, 1):
            successors.append(self.down())
        if col in (0, 1):
            successors.append(self.right())
        return successors

    def move(self, row, col):
        """Return a copy of the board with the blank shifted by (row, col)."""
        blank_row, blank_col = self.zero
        dest_row, dest_col = blank_row + row, blank_col + col
        board = copy.deepcopy(self.stat)
        board[blank_row][blank_col] = self.stat[dest_row][dest_col]
        board[dest_row][dest_col] = 0
        return board

    def up(self):
        return self.move(-1, 0)

    def down(self):
        return self.move(1, 0)

    def left(self):
        return self.move(0, -1)

    def right(self):
        return self.move(0, 1)
def isin(g, gList):
    """Search ``gList`` for a node whose board equals ``g.stat``.

    Returns:
        ``[True, index]`` where ``index`` is the position of the first node
        with an equal board, or ``[False, 0]`` if there is none.
    """
    wanted = g.stat
    boards = [node.stat for node in gList]
    if wanted not in boards:
        return [False, 0]
    return [True, boards.index(wanted)]
def N(nums):
    """Count the inversions among the non-blank tiles of a board.

    An inversion is a pair of positions (j, i) with j < i where the value at
    j is greater than the value at i; positions holding the blank (0) are
    never used as the right-hand element.

    Args:
        nums: the board, either as a flat sequence of tiles or as a sequence
            of rows (lists/tuples). Rows are flattened in row-major order
            first, so inversions are counted between tiles rather than
            between whole rows.

    Returns:
        The number of inversions as an int.
    """
    # Flatten one level so a 3x3 nested board is handled tile by tile.
    tiles = []
    for item in nums:
        if isinstance(item, (list, tuple)):
            tiles.extend(item)
        else:
            tiles.append(item)

    count = 0
    for pos, tile in enumerate(tiles):
        if tile == 0:
            continue
        count += sum(1 for earlier in tiles[:pos] if earlier > tile)
    return count
def judge(src, target):
    """Compare the inversion parity of two boards.

    Returns True when ``N(src)`` and ``N(target)`` are both even or both odd,
    and False otherwise.
    """
    same_parity = (N(src) % 2) == (N(target) % 2)
    return same_parity
def Astar(startStat):
    """Search from ``startStat`` to the target board, ordering nodes by F.

    Prints the update count, the move count and the path when a node with
    ``H == 0`` reaches the front of the open list. If ``judge`` rejects the
    start board, a message is printed and the program exits with status 1.

    Args:
        startStat: the starting board as a list of three rows.
    """
    frontier = []   # generated but not yet expanded nodes
    expanded = []   # nodes that have already been expanded
    root = grid(startStat)
    if not judge(startStat, root.target):
        print("所给八数码无解,请检查输入")
        exit(1)
    frontier.append(root)
    time = 0  # number of insertions into / replacements in the two lists
    while frontier:
        # The stable sort keeps earlier nodes ahead of later ones on ties.
        frontier.sort(key=lambda node: node.F)
        best = frontier[0]
        if best.H == 0:
            print("found and times:", time, "moves:", best.G)
            best.seeAns()
            break
        frontier.pop(0)
        expanded.append(best)
        for board in best.expand():
            child = grid(board)
            child.pre = best
            child.update()
            in_open, open_idx = isin(child, frontier)
            in_closed, closed_idx = isin(child, expanded)
            # The three rules below are independent checks, evaluated in this
            # order with the lookups made above.
            if in_closed and child.F < expanded[closed_idx].F:
                # Cheaper route to an expanded board: replace it and reopen.
                expanded[closed_idx] = child
                frontier.append(child)
                time += 1
            if in_open and child.F < frontier[open_idx].F:
                # Cheaper route to a queued board: replace the queued node.
                frontier[open_idx] = child
                time += 1
            if not (in_open or in_closed):
                # Board not seen before.
                frontier.append(child)
                time += 1
# Starting board for the A* demo (0 marks the blank); Astar prints the
# solution path itself.
stat=[[2, 8, 3], [1, 0 ,4], [7, 6, 5]]
Astar(stat)
https://github.com/roadwide/AI-Homework
https://blog.csdn.net/Juuunn/article/details/109439359
|