almost done implementing persistent-vector.sml's 'prependJoin' helper function. (BRANCH case is fully implemented, but I need to implement the LEAF case next.)

This commit is contained in:
2025-12-08 22:23:01 +00:00
parent b6cad904f9
commit 0c668c9f14

View File

@@ -623,9 +623,9 @@ struct
end
| _ => raise Fail "PersistentVector.joinSameDepth: one is BRANCH and other is LEAF"
datatype join_result =
JOIN_APPEND of t
| JOIN_UPDATE of t
datatype append_join_result =
APPEND_JOIN_APPEND of t
| APPEND_JOIN_UPDATE of t
fun appendJoin (left, right, joinDepth, rightLength) =
case left of
@@ -633,7 +633,7 @@ struct
if joinDepth = 0 then
(* base case: should join at this depth *)
if Vector.length nodes + rightLength > maxSize then
JOIN_APPEND right
APPEND_JOIN_APPEND right
else
(case right of
BRANCH (rightNodes, rightSizes) =>
@@ -650,7 +650,7 @@ struct
+ lastLeftSize
)
in
JOIN_UPDATE (BRANCH (nodes, sizes))
APPEND_JOIN_UPDATE (BRANCH (nodes, sizes))
end
| LEAF _ =>
raise Fail
@@ -665,7 +665,7 @@ struct
val lastNode = Vector.sub (nodes, lastIdx)
in
case appendJoin (lastNode, right, joinDepth - 1, rightLength) of
JOIN_UPDATE newLast =>
APPEND_JOIN_UPDATE newLast =>
let
val prevSize =
if lastIdx > 0 then Vector.sub (sizes, lastIdx - 1)
@@ -674,17 +674,13 @@ struct
val sizes = Vector.update (sizes, lastIdx, newLastSize)
val nodes = Vector.update (nodes, lastIdx, newLast)
in
JOIN_UPDATE (BRANCH (nodes, sizes))
APPEND_JOIN_UPDATE (BRANCH (nodes, sizes))
end
| JOIN_APPEND newNode =>
if Vector.length nodes + rightLength > maxSize then
| APPEND_JOIN_APPEND newNode =>
if Vector.length nodes = maxSize then
(* parent has to append insead as this node
* would exceed capacity if appended here *)
(* todo: I would prefer to take some nodes from right,
* then append those nodes to this one to make it reach
* max capacity, then return update left * right. *)
JOIN_APPEND (BRANCH (#[newNode], #[getMaxSize newNode]))
APPEND_JOIN_APPEND (BRANCH (#[newNode], #[getMaxSize newNode]))
else
let
val prevSize =
@@ -695,13 +691,13 @@ struct
val newNode = #[newNode]
val nodes = Vector.concat [nodes, newNode]
in
JOIN_UPDATE (BRANCH (nodes, sizes))
APPEND_JOIN_UPDATE (BRANCH (nodes, sizes))
end
end
| LEAF (items, sizes) =>
(* joinDepth should = 0, and we assume it is *)
if Vector.length items + rightLength > maxSize then
JOIN_APPEND right
APPEND_JOIN_APPEND right
else
(case right of
LEAF (rightItems, rightSizes) =>
@@ -729,7 +725,7 @@ struct
+ leftMaxSize
)
in
JOIN_UPDATE (LEAF (items, sizes))
APPEND_JOIN_UPDATE (LEAF (items, sizes))
end
| BRANCH _ =>
raise Fail
@@ -738,6 +734,77 @@ struct
\but right is BRANCH"
)
datatype prepend_join_result =
PREPEND_JOIN_PREPEND of t
| PREPEND_JOIN_UPDATE of t
fun prependJoin (left, right, joinDepth, leftLength) =
case right of
BRANCH (rightNodes, rightSizes) =>
if joinDepth = 0 then
(* base case: join *)
if Vector.length rightNodes + leftLength > maxSize then
PREPEND_JOIN_PREPEND left
else
(case left of
BRANCH (leftNodes, leftSizes) =>
let
val nodes = Vector.concat [leftNodes, rightNodes]
val maxLeftSize =
Vector.sub (leftSizes, Vector.length leftSizes - 1)
val sizes = Vector.tabulate (Vector.length nodes,
fn i =>
if i < Vector.length leftSizes then
Vector.sub (leftSizes, i)
else
Vector.sub (rightSizes, i - Vector.length leftSizes)
+ maxLeftSize
)
in
PREPEND_JOIN_UPDATE (BRANCH (nodes, sizes))
end
| LEAF _ =>
raise Fail
"persistent-vector.sml prependJoin: \
\expected left and right to be BRANCH \
\but right is BRANCH while left is LEAF"
)
else
(* recursive case *)
let
val firstRightNode = Vector.sub (rightNodes, 0)
val firstRightSize = getMaxSize firstRightNode
in
(case prependJoin (left, firstRightNode, joinDepth - 1, leftLength) of
PREPEND_JOIN_UPDATE newFirst =>
let
val newFirstSize = getMaxSize newFirst
val sizeDiff = newFirstSize - firstRightSize
val sizes = Vector.map (fn el => el + sizeDiff) rightSizes
val nodes = Vector.update (rightNodes, 0, newFirst)
in
PREPEND_JOIN_UPDATE (BRANCH (nodes, sizes))
end
| PREPEND_JOIN_PREPEND newFirst =>
if Vector.length rightSizes = maxSize then
PREPEND_JOIN_PREPEND (BRANCH (#[newFirst], #[getMaxSize newFirst]))
else
let
val nodes = Vector.concat [#[newFirst], rightNodes]
val newFirstSize = getMaxSize newFirst
val sizes = Vector.tabulate (Vector.length nodes,
fn i =>
if i = 0 then
newFirstSize
else
Vector.sub (rightSizes, i - 1) + newFirstSize
)
in
PREPEND_JOIN_UPDATE (BRANCH (nodes, sizes))
end)
end
fun join (left, right) =
if isEmpty left then
right
@@ -766,8 +833,8 @@ struct
val rightLength = getRootVecLength right
in
case appendJoin (left, right, joinDepth, rightLength) of
JOIN_UPDATE t => t
| JOIN_APPEND newRight =>
APPEND_JOIN_UPDATE t => t
| APPEND_JOIN_APPEND newRight =>
let
val ls = getMaxSize left
val rs = getMaxSize right + ls
@@ -778,6 +845,21 @@ struct
end
end
else
raise Fail "unimplemented"
let
val joinDepth = rightDepth - leftDepth
val leftLength = getRootVecLength left
in
case prependJoin (left, right, joinDepth, leftLength) of
PREPEND_JOIN_UPDATE t => t
| PREPEND_JOIN_PREPEND newLeft =>
let
val ls = getMaxSize newLeft
val rs = getMaxSize right + ls
val sizes = #[ls, rs]
val nodes = #[newLeft, right]
in
BRANCH (nodes, sizes)
end
end
end
end