aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormat <git@matdoes.dev>2024-12-26 07:42:35 +0000
committermat <git@matdoes.dev>2024-12-26 07:42:35 +0000
commitadb56b7eb2c5b54a4dccc7b5f77dd0f7d2442993 (patch)
tree7fd35fc590f460604118a1e445f5205e8c6f9801
parent3c83e5b24a622062c490f90c7e5bde043438d517 (diff)
downloadazalea-drasl-adb56b7eb2c5b54a4dccc7b5f77dd0f7d2442993.tar.xz
make a_star function use an IndexMap like the pathfinding crate
-rw-r--r--Cargo.lock1
-rw-r--r--azalea/Cargo.toml1
-rw-r--r--azalea/src/pathfinder/astar.rs182
-rw-r--r--azalea/src/pathfinder/mod.rs2
4 files changed, 115 insertions, 71 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 69545dee..6c7b8d9e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -232,6 +232,7 @@ dependencies = [
"derive_more",
"futures",
"futures-lite",
+ "indexmap",
"nohash-hasher",
"num-format",
"num-traits",
diff --git a/azalea/Cargo.toml b/azalea/Cargo.toml
index 739eb5f7..82086bb6 100644
--- a/azalea/Cargo.toml
+++ b/azalea/Cargo.toml
@@ -34,6 +34,7 @@ bevy_tasks = { workspace = true, features = ["multi_threaded"] }
derive_more = { workspace = true, features = ["deref", "deref_mut"] }
futures = { workspace = true }
futures-lite = { workspace = true }
+indexmap = "2.7.0"
nohash-hasher = { workspace = true }
num-format = { workspace = true }
num-traits = { workspace = true }
diff --git a/azalea/src/pathfinder/astar.rs b/azalea/src/pathfinder/astar.rs
index 9948d315..0cd7c291 100644
--- a/azalea/src/pathfinder/astar.rs
+++ b/azalea/src/pathfinder/astar.rs
@@ -2,12 +2,13 @@ use std::{
cmp::{self},
collections::BinaryHeap,
fmt::Debug,
- hash::Hash,
+ hash::{BuildHasherDefault, Hash},
time::{Duration, Instant},
};
+use indexmap::IndexMap;
use num_format::ToFormattedString;
-use rustc_hash::FxHashMap;
+use rustc_hash::FxHasher;
use tracing::{debug, trace, warn};
pub struct Path<P, M>
@@ -37,6 +38,12 @@ pub enum PathfinderTimeout {
Nodes(usize),
}
+type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
+
+// Sources:
+// - https://en.wikipedia.org/wiki/A*_search_algorithm
+// - https://github.com/evenfurther/pathfinding/blob/main/src/directed/astar.rs
+// - https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java
pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
start: P,
heuristic: HeuristicFn,
@@ -52,77 +59,100 @@ where
{
let start_time = Instant::now();
- let mut open_set = BinaryHeap::<WeightedNode<P>>::new();
- open_set.push(WeightedNode(start, 0.));
- let mut nodes: FxHashMap<P, Node<P, M>> = FxHashMap::default();
+ let mut open_set = BinaryHeap::<WeightedNode>::new();
+ open_set.push(WeightedNode {
+ g_score: 0.,
+ f_score: 0.,
+ index: 0,
+ });
+ let mut nodes: FxIndexMap<P, Node<M>> = IndexMap::default();
nodes.insert(
start,
Node {
- position: start,
movement_data: None,
- came_from: None,
- g_score: f32::default(),
- f_score: f32::INFINITY,
+ came_from: usize::MAX,
+ g_score: 0.,
},
);
- let mut best_paths: [P; 7] = [start; 7];
+ let mut best_paths: [usize; 7] = [0; 7];
let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
let mut num_nodes = 0;
- while let Some(WeightedNode(current_node, _)) = open_set.pop() {
+ while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
num_nodes += 1;
- if success(current_node) {
+
+ let (&node, node_data) = nodes.get_index(index).unwrap();
+ if success(node) {
debug!("Nodes considered: {num_nodes}");
return Path {
- movements: reconstruct_path(nodes, current_node),
+ movements: reconstruct_path(nodes, index),
partial: false,
};
}
- let current_g_score = nodes
- .get(&current_node)
- .map(|n| n.g_score)
- .unwrap_or(f32::INFINITY);
-
- for neighbor in successors(current_node) {
- let tentative_g_score = current_g_score + neighbor.cost;
- let neighbor_g_score = nodes
- .get(&neighbor.movement.target)
- .map(|n| n.g_score)
- .unwrap_or(f32::INFINITY);
- if neighbor_g_score - tentative_g_score > MIN_IMPROVEMENT {
- let heuristic = heuristic(neighbor.movement.target);
- let f_score = tentative_g_score + heuristic;
- nodes.insert(
- neighbor.movement.target,
- Node {
- position: neighbor.movement.target,
+ if g_score > node_data.g_score {
+ continue;
+ }
+
+ for neighbor in successors(node) {
+ let tentative_g_score = g_score + neighbor.cost;
+ // let neighbor_heuristic = heuristic(neighbor.movement.target);
+ let neighbor_heuristic;
+ let neighbor_index;
+
+ // skip neighbors that don't result in a big enough improvement
+ if tentative_g_score - g_score < MIN_IMPROVEMENT {
+ continue;
+ }
+
+ match nodes.entry(neighbor.movement.target) {
+ indexmap::map::Entry::Occupied(mut e) => {
+ if e.get().g_score > tentative_g_score {
+ neighbor_heuristic = heuristic(*e.key());
+ neighbor_index = e.index();
+ e.insert(Node {
+ movement_data: Some(neighbor.movement.data),
+ came_from: index,
+ g_score: tentative_g_score,
+ });
+ } else {
+ continue;
+ }
+ }
+ indexmap::map::Entry::Vacant(e) => {
+ neighbor_heuristic = heuristic(*e.key());
+ neighbor_index = e.index();
+ e.insert(Node {
movement_data: Some(neighbor.movement.data),
- came_from: Some(current_node),
+ came_from: index,
g_score: tentative_g_score,
- f_score,
- },
- );
- open_set.push(WeightedNode(neighbor.movement.target, f_score));
-
- for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
- let node_score = heuristic + tentative_g_score / coefficient;
- if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
- best_paths[coefficient_i] = neighbor.movement.target;
- best_path_scores[coefficient_i] = node_score;
- }
+ });
+ }
+ }
+
+ open_set.push(WeightedNode {
+ index: neighbor_index,
+ g_score: tentative_g_score,
+ f_score: tentative_g_score + neighbor_heuristic,
+ });
+
+ for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
+ let node_score = neighbor_heuristic + tentative_g_score / coefficient;
+ if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
+ best_paths[coefficient_i] = neighbor_index;
+ best_path_scores[coefficient_i] = node_score;
}
}
}
- // check for timeout every ~20ms
+ // check for timeout every ~10ms
if num_nodes % 10000 == 0 {
let timed_out = match timeout {
- PathfinderTimeout::Time(max_duration) => start_time.elapsed() > max_duration,
- PathfinderTimeout::Nodes(max_nodes) => num_nodes > max_nodes,
+ PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
+ PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
};
if timed_out {
// timeout, just return the best path we have so far
@@ -132,7 +162,7 @@ where
}
}
- let best_path = determine_best_path(&best_paths, &start);
+ let best_path = determine_best_path(best_paths, 0);
debug!(
"A* ran at {} nodes per second",
@@ -146,48 +176,46 @@ where
}
}
-fn determine_best_path<P>(best_paths: &[P; 7], start: &P) -> P
-where
- P: Eq + Hash + Copy + Debug,
-{
+fn determine_best_path(best_paths: [usize; 7], start: usize) -> usize {
// this basically makes sure we don't create a path that's really short
- for node in best_paths.iter() {
+ for node in best_paths {
if node != start {
- return *node;
+ return node;
}
}
warn!("No best node found, returning first node");
best_paths[0]
}
-fn reconstruct_path<P, M>(mut nodes: FxHashMap<P, Node<P, M>>, current: P) -> Vec<Movement<P, M>>
+fn reconstruct_path<P, M>(
+ mut nodes: FxIndexMap<P, Node<M>>,
+ mut current_index: usize,
+) -> Vec<Movement<P, M>>
where
P: Eq + Hash + Copy + Debug,
{
let mut path = Vec::new();
- let mut current = current;
- while let Some(node) = nodes.remove(&current) {
- if let Some(came_from) = node.came_from {
- current = came_from;
- } else {
+ while let Some((&node_position, node)) = nodes.get_index_mut(current_index) {
+ if node.came_from == usize::MAX {
break;
}
+
+ current_index = node.came_from;
+
path.push(Movement {
- target: node.position,
- data: node.movement_data.unwrap(),
+ target: node_position,
+ data: node.movement_data.take().unwrap(),
});
}
path.reverse();
path
}
-pub struct Node<P, M> {
- pub position: P,
+pub struct Node<M> {
pub movement_data: Option<M>,
- pub came_from: Option<P>,
+ pub came_from: usize,
pub g_score: f32,
- pub f_score: f32,
}
pub struct Edge<P: Hash + Copy, M> {
@@ -218,16 +246,30 @@ impl<P: Hash + Copy + Clone, M: Clone> Clone for Movement<P, M> {
}
#[derive(PartialEq)]
-pub struct WeightedNode<P: PartialEq>(P, f32);
+pub struct WeightedNode {
+ index: usize,
+ g_score: f32,
+ f_score: f32,
+}
-impl<P: PartialEq> Ord for WeightedNode<P> {
+impl Ord for WeightedNode {
fn cmp(&self, other: &Self) -> cmp::Ordering {
// intentionally inverted to make the BinaryHeap a min-heap
- other.1.partial_cmp(&self.1).unwrap_or(cmp::Ordering::Equal)
+ match other
+ .f_score
+ .partial_cmp(&self.f_score)
+ .unwrap_or(cmp::Ordering::Equal)
+ {
+ cmp::Ordering::Equal => self
+ .g_score
+ .partial_cmp(&other.g_score)
+ .unwrap_or(cmp::Ordering::Equal),
+ s => s,
+ }
}
}
-impl<P: PartialEq> Eq for WeightedNode<P> {}
-impl<P: PartialEq> PartialOrd for WeightedNode<P> {
+impl Eq for WeightedNode {}
+impl PartialOrd for WeightedNode {
fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
Some(self.cmp(other))
}
diff --git a/azalea/src/pathfinder/mod.rs b/azalea/src/pathfinder/mod.rs
index 0b3b7591..76a1f79b 100644
--- a/azalea/src/pathfinder/mod.rs
+++ b/azalea/src/pathfinder/mod.rs
@@ -778,7 +778,7 @@ pub fn check_for_path_obstruction(
new_path
.extend(executing_path.path.iter().skip(patch_end_index).cloned());
is_patch_complete = true;
- debug!("the obstruction patch is not partial");
+ debug!("the obstruction patch is not partial :)");
} else {
debug!(
"the obstruction patch is partial, throwing away rest of path :("