% Nearest Neighbour using a kd-tree

%
% A k-dimensional point is represented by a list of k integers.
%

% list_kdtree(+Points, -Tree) holds when Tree is a kd-tree built from the list
%   of k-dimensional Points. Construction starts at depth 0, so the root node
%   splits the points on axis 0 (the first coordinate).
% e.g. list_kdtree([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]], X) gives
%   X=kd([7,2],kd([5,4],kd([2,3],void,void),
%                       kd([4,7],void,void)),
%              kd([9,6],kd([8,1],void,void),
%                       void))
list_kdtree(Points, Tree):-
  RootDepth = 0,
  list_kdtree(Points, RootDepth, Tree).

% list_kdtree(+Points, +Depth, -Tree): builds the (sub)tree for Points whose
%   root sits at Depth. The split axis is Depth mod 2, so only coordinates 0
%   and 1 are ever used to split. The points are sorted on that axis and the
%   element at 1-based position (Length+2)//2 becomes the node (the upper
%   median when Length is even). The points before it form the left subtree
%   and the points after it form the right subtree.
list_kdtree([], _, void):-!.
list_kdtree(Points, Depth, kd(Median,Left,Right)):-
  Axis is Depth mod 2,
  ChildDepth is Depth + 1,
  merge_sort(Points, Axis, Sorted),
  length(Sorted, Count),
  Pivot is (Count + 2) // 2,
  partition(Pivot, Sorted, Smaller, Median, Larger),
  list_kdtree(Smaller, ChildDepth, Left),
  list_kdtree(Larger, ChildDepth, Right).

% nearest(+Tree, +Target, -DistSqu, -Best) holds when Best is a point of the
%   kd-tree Tree that is nearest to the point Target, and DistSqu is the
%   squared distance between Target and Best.
% A first guess comes from near/4, which makes a single descent of the tree.
%   nearest_1/5 then searches the tree and uses that guess to prune subtrees.
% e.g. list_kdtree([[4,9,6],[8,1,3],[6,6,5],[7,1,5]],T), nearest(T,[1,2,9],D,X)
%   gives D=53, X=[7,1,5]
nearest(Tree, Target, DistSqu, Best):-
  % The search splits on axes 0 and 1, so Target needs at least 2 coordinates.
  Target = [_, _|_],
  near(Tree, Target, GuessDist, Guess),
  Initial = s(Guess, GuessDist),
  nearest_1(Tree, Target, 0, Initial, s(Best, DistSqu)).

% nearest_1(+Tree, +Target, +Depth, +Solution0, -Solution): searches Tree,
%   whose root is at Depth. Solution0 is the best-so-far s(Point,DistSqu);
%   the root may improve it. The squared distance from Target to the root's
%   splitting plane is then handed to nearest_2/6, which decides which
%   subtrees to search.
nearest_1(void, _, _, Solution, Solution):-!.
nearest_1(kd(Point,Left,Right), Target, Depth, Solution0, Solution):-
  distance_squared(Point, Target, PointDist),
  update_if_nearer(PointDist, Point, Solution0, Solution1),
  Axis is Depth mod 2,
  subtract(Axis, Point, Target, AxisGap),
  PlaneDist is AxisGap * AxisGap,
  nearest_2(PlaneDist, kd(Point,Left,Right), Target, Depth, Solution1, Solution).

% nearest_2(+Dsqu, +Tree, +Target, +Depth, +Solution1, -Solution): chooses
%   which children of the node Tree to search. Dsqu is the squared distance
%   from Target to the node's splitting plane.
%   - If Dsqu is smaller than the best squared distance found so far, a nearer
%     point could lie on either side, so both subtrees are searched (left
%     first).
%   - Otherwise only the side containing Target is searched: the left subtree
%     if Target is strictly below the node on the current axis, the right one
%     otherwise.
nearest_2(Dsqu, kd(Point,Left,Right), Target, Depth, Solution1, Solution):-
  Depth1 is Depth + 1,
  (   Solution1 = s(_, BestDist),
      Dsqu < BestDist
  ->  nearest_1(Left, Target, Depth1, Solution1, Solution2),
      nearest_1(Right, Target, Depth1, Solution2, Solution)
  ;   Axis is Depth mod 2,
      lt(Axis, Target, Point)
  ->  nearest_1(Left, Target, Depth1, Solution1, Solution)
  ;   nearest_1(Right, Target, Depth1, Solution1, Solution)
  ).

% update_if_nearer(+D, +Point, +Solution0, -Solution): Solution is s(Point,D)
%   when the squared distance D is strictly smaller than the one recorded in
%   Solution0 = s(_,DistSqu0). Otherwise Solution is Solution0 unchanged, so on
%   a tie the point already recorded is kept.
update_if_nearer(D, Point, Solution0, Solution):-
  (   Solution0 = s(_, Current),
      Solution = s(Point, D),
      D < Current
  ->  true
  ;   Solution = Solution0
  ).

% near(+Tree, +Target, -DistSqu, -Near) holds when Near is the point nearest to
%   Target among the nodes of the kd-tree Tree that are visited while
%   descending to the node to which Target could be attached. DistSqu is the
%   squared distance between Near and Target. Only one root-to-leaf path is
%   examined, so Near is not necessarily the true nearest point.
% The descent is seeded with the root node itself rather than with a fixed
%   "infinite" distance of 2147483647. This means a result is found even when
%   every squared distance is 2147483647 or more. For an empty tree (void),
%   Near is [] and DistSqu is 2147483647.
% e.g. list_kdtree([[2,3,3],[5,4,1],[9,6,4],[4,7,1]],T), near(T,[9,5,5],D,P)
%   gives D=2, P=[9,6,4]
near(void, _, 2147483647, []).
near(kd(Root,Left,Right), Target, DistSqu, Near):-
  distance_squared(Root, Target, RootDistSqu),
  near(kd(Root,Left,Right), Target, 0, s(Root,RootDistSqu), s(Near,DistSqu)).

% near(+Tree, +Target, +Depth, +Solution0, -Solution): follows a single path
%   from the root of Tree (at Depth) down to a void leaf. At each node it goes
%   left when Target is strictly below the node on the current axis (Depth
%   mod 2) and right otherwise. The best-so-far s(Point,DistSqu) is replaced
%   whenever a visited node is strictly nearer to Target.
near(void, _, _, Solution, Solution).
near(kd(Point,Left,Right), Target, Depth, Solution0, Solution):-
  Axis is Depth mod 2,
  (   lt(Axis, Target, Point)
  ->  Subtree = Left
  ;   Subtree = Right
  ),
  (   Solution0 = s(_, DistSqu0),
      distance_squared(Point, Target, DistSqu1),
      DistSqu1 < DistSqu0
  ->  Solution1 = s(Point, DistSqu1)
  ;   Solution1 = Solution0
  ),
  Depth1 is Depth + 1,
  near(Subtree, Target, Depth1, Solution1, Solution).

% merge_sort(List, Axis, SortedList) is true if SortedList is the result of
%   sorting the List of k-dimensional points into ascending order of their
%   Axis-th (0-based) coordinate. The sort is stable: points with equal
%   coordinates keep their original relative order.
% e.g. merge_sort([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]], 0, X) gives
%        X=[[2,3],[4,7],[5,4],[7,2],[8,1],[9,6]]
%      merge_sort([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]], 1, X) gives
%        X=[[8,1],[7,2],[2,3],[5,4],[9,6],[4,7]]
merge_sort(List, Axis, SortedList):-
  length(List, Count),
  NothingLeft = [],
  merge_sort_1(Count, Axis, List, NothingLeft, SortedList).

% merge_sort_1(+N, +Axis, +List, -Rest, -Sorted): Sorted holds the first N
%   elements of List sorted along Axis, and Rest is what remains of List
%   after those N elements. The list is halved by element count, not by
%   splitting it first.
merge_sort_1(0, _, Rest, Rest, []):-!.
merge_sort_1(1, _, [Only|Rest], Rest, [Only]):-!.
merge_sort_1(N, Axis, List, Rest, Sorted):-
  FrontCount is N // 2,
  BackCount is N - FrontCount,
  merge_sort_1(FrontCount, Axis, List, Middle, FrontSorted),
  merge_sort_1(BackCount, Axis, Middle, Rest, BackSorted),
  ordered_merge(FrontSorted, BackSorted, Axis, Sorted).

% ordered_merge(+Xs, +Ys, +Axis, -Zs): Zs merges the lists Xs and Ys, both
%   already sorted on Axis. An element of Ys is taken only when it is strictly
%   smaller on Axis than the current head of Xs, which keeps the merge stable.
ordered_merge([], Ys, _, Ys).
ordered_merge([X|Xs], Ys0, Axis, Zs):-
  (   Ys0 = [Y|Ys],
      Zs = [Y|Zs1],
      lt(Axis, Y, X)
  ->  ordered_merge([X|Xs], Ys, Axis, Zs1)
  ;   Zs = [X|Zs1],
      ordered_merge(Xs, Ys0, Axis, Zs1)
  ).

% partition(I, Points, Before, Chosen, After) is true if Chosen is the I-th
%   (1-based) element of the list Points. Before holds the elements preceding
%   it and After holds the elements following it. The goal fails when Points
%   has fewer than I elements.
partition(I, [Point|Points], Before, Chosen, After):-
  (   I = 1,
      Before = [],
      Chosen = Point,
      After = Points
  ->  true
  ;   Before = [Point|Before1],
      I1 is I - 1,
      partition(I1, Points, Before1, Chosen, After)
  ).

%  Comparisons and operations on a given Axis of two k-dimensional points.
% lt(+Axis, +Xs, +Ys): the Axis-th (0-based) coordinate of point Xs is
%   strictly less than that of point Ys. Fails if either point has no such
%   coordinate.
lt(Axis, Xs, Ys):-
  (   nth_member(Xs, Axis, XCoord),
      nth_member(Ys, Axis, YCoord)
  ->  XCoord < YCoord
  ).

% subtract(+Axis, +Xs, +Ys, -Z): Z is the Axis-th (0-based) coordinate of
%   point Xs minus that of point Ys. Fails if either point has no such
%   coordinate.
subtract(Axis, Xs, Ys, Z):-
  (   nth_member(Xs, Axis, XCoord),
      nth_member(Ys, Axis, YCoord)
  ->  Z is XCoord - YCoord
  ).

% nth_member(+Xs, ?N, ?X) is true if X is the N-th (base 0) element of the
%   list Xs. When N is unbound, backtracking enumerates the elements in list
%   order together with their indices.
nth_member(List, Index, Elem):-
  nth_member_1(List, Elem, 0, Index).

% nth_member_1(+Xs, ?X, +I0, ?I): X is the element of Xs at position I,
%   where the head of Xs is at position I0.
nth_member_1([Head|Tail], Elem, Here, Index):-
  (   Head = Elem,
      Index = Here
  ;   Next is Here + 1,
      nth_member_1(Tail, Elem, Next, Index)
  ).
  
% distance_squared(Xs, Ys, DistSqu) is true if DistSqu is the square of the
%   Euclidean distance between the points Xs and Ys, i.e. the sum of the
%   squared coordinate differences. It fails if the points have different
%   numbers of coordinates.
distance_squared(Xs, Ys, DistSqu):-
  distance_squared(Xs, Ys, 0, DistSqu).

% distance_squared(+Xs, +Ys, +Acc, -DistSqu): DistSqu is Acc plus the sum of
%   the squared differences of the remaining coordinates.
distance_squared([], [], Acc, Acc).
distance_squared([X|Xs], [Y|Ys], Acc0, DistSqu):-
  Acc1 is Acc0 + (X - Y) * (X - Y),
  distance_squared(Xs, Ys, Acc1, DistSqu).

% naive_nearest(Points, Target, DistSqu, Nearest) is true if Nearest is a point
%   in the list of Points which is nearest to the point Target, and DistSqu is
%   the square of the distance between Target and Nearest. Every point is
%   examined. When several points are equally near, the last one is returned.
% The scan is seeded with the first point rather than with a fixed "infinite"
%   distance of 2147483647. This means a result is found even when every
%   squared distance is 2147483647 or more. For an empty list, Nearest is []
%   and DistSqu is 2147483647.
% e.g. naive_nearest([[4,9,6],[8,1,3],[6,6,5],[7,1,5]], [1,2,9], D, P)
%   gives D=53, P=[7,1,5]
naive_nearest([], _, 2147483647, []).
naive_nearest([First|Points], Target, DistSqu, Nearest):-
  distance_squared(First, Target, FirstDistSqu),
  naive_nearest_1(Points, Target, s(First,FirstDistSqu), s(Nearest,DistSqu)).

% naive_nearest_1(+Points, +Target, +Solution0, -Solution): scans Points and
%   replaces the best-so-far s(Point,DistSqu) whenever a point is at least as
%   near (=<). Among equally near points, the last one therefore wins.
naive_nearest_1([], _, Solution, Solution).
naive_nearest_1([Point|Points], Target, Solution0, Solution):-
  (   Solution0 = s(_, Best0),
      distance_squared(Point, Target, Dist),
      Dist =< Best0
  ->  Solution1 = s(Point, Dist)
  ;   Solution1 = Solution0
  ),
  naive_nearest_1(Points, Target, Solution1, Solution).

% rand_points(+N, +K, +Min, +Max, -Points): Points is a list of N random
%   K-dimensional points, each produced by rand_point/4 with coordinates
%   between Min and Max inclusive. Fails for negative N.
% e.g. rand_points(100,3,-99,99,Ps), list_kdtree(Ps,T), nearest(T,[0,0,0],D,X)
% e.g. rand_points(100,3,-99,99,Ps), naive_nearest(Ps,[0,0,0],D,X)
rand_points(0, _, _, _, []):-!.
rand_points(Count, Dims, Min, Max, [Point|Points]):-
  Count > 0,
  Remaining is Count - 1,
  rand_point(Dims, Min, Max, Point),
  rand_points(Remaining, Dims, Min, Max, Points).

% rand_point(+K, +Min, +Max, -Point): Point is a list of K pseudo-random
%   coordinates, each drawn with rand_int/3 between Min and Max inclusive.
% K must be a non-negative integer; a negative K fails rather than recursing
%   without end. Min may equal Max, in which case every coordinate is Min.
%   Min > Max fails.
rand_point(0, _, _, []):-!.
rand_point(K, Min, Max, [Coord|Coords]):-
  K > 0,
  Min =< Max,
  rand_int(Min, Max, Coord),
  K1 is K - 1,
  rand_point(K1, Min, Max, Coords).

% rand_int(I, J, K) is true if K is a pseudo-random integer in the range I..J.
%   A random number is drawn with rand(J - I + 1), converted with int/1 and
%   shifted up by I.
rand_int(I, J, K):-
  Span is J - I + 1,
  K is int(rand(Span)) + I.

% length(Xs, L) is true if L is the number of elements in the list Xs.
%length(Xs, L):-length_1(Xs, 0, L).

% length_1(Xs, L0, L) is true if L is equal to L0 plus the number of elements
%   in the list Xs.
%length_1([], L, L).
%length_1([_|Xs], L0, L):-L1 is L0 + 1, length_1(Xs, L1, L).

% LPA Index     Home Page

% Valid HTML 4.0!