Purely Functional Heap Sort in Scala

Inspired by the blog post "Purely functional Heap Sort in OCaml, F# and Haskell" on The Flying Frog Blog, I decided to implement the same data structure and operations in Scala to get one small step closer to Scala mastery. The code is not idiomatic Scala, but I think it is decent anyway. If you have any comments, feel free to contact me as usual.

Challenge: Rewrite the code to be more object-oriented. (Disclaimer: I don’t think it’s possible to write toList as an instance method on Heap and still have the tail-call optimization kick in, but I’d love to be proven wrong.)

Don’t forget to take a look at the OCaml, F# and Haskell versions and compare them. Note how many types Scala can’t infer.

Update: The code has been revised with some improvements suggested by Robbert.

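```scala
// A leftist max-heap. Each node caches its rank (the length of its right
// spine), and the left child's rank is never smaller than the right child's.
sealed abstract class Heap[+A] { def rank: Int }
case object EmptyHeap extends Heap[Nothing] { def rank = 0 }
case class NonEmptyHeap[A](rank: Int, element: A, left: Heap[A], right: Heap[A]) extends Heap[A]

object Heap {
  // A heap containing a single element.
  def apply[A](x: A): Heap[A] =
    apply(x, EmptyHeap, EmptyHeap)

  // Build a node, putting the higher-ranked sub-heap on the left so the
  // leftist property holds; the new rank is the right child's rank + 1.
  def apply[A](x: A, a: Heap[A], b: Heap[A]): Heap[A] =
    if (a.rank > b.rank)
      NonEmptyHeap(b.rank + 1, x, a, b)
    else
      NonEmptyHeap(a.rank + 1, x, b, a)

  // Merge two heaps: the larger root stays on top and the other heap is
  // merged into its right spine, so the recursion depth is O(log n).
  def merge[A <% Ordered[A]](a: Heap[A], b: Heap[A]): Heap[A] =
    (a, b) match {
      case (x, EmptyHeap) => x
      case (EmptyHeap, x) => x
      case (x: NonEmptyHeap[A], y: NonEmptyHeap[A]) =>
        if (x.element >= y.element)
          Heap(x.element, x.left, merge(x.right, y))
        else
          Heap(y.element, y.left, merge(x, y.right))
    }

  // Convert a heap to a list in ascending order.
  def toList[A <% Ordered[A]](heap: Heap[A]): List[A] =
    toListWithMemory(Nil, heap)

  // Repeatedly remove the maximum and prepend it to memo. Because the largest
  // element is prepended first, memo ends up sorted in ascending order.
  @annotation.tailrec
  def toListWithMemory[A <% Ordered[A]](memo: List[A], heap: Heap[A]): List[A] =
    heap match {
      case EmptyHeap => memo
      case x: NonEmptyHeap[A] =>
        toListWithMemory(x.element :: memo, merge(x.left, x.right))
    }

  // Insert every element into a heap, then drain it into a sorted list.
  def heapSort[A <% Ordered[A]](xs: Seq[A]): Seq[A] =
    toList(xs.foldLeft(EmptyHeap: Heap[A])((memo, x) => merge(Heap(x), memo)))
}

object HeapSortTest {
  def main(args: Array[String]): Unit = {
    val input = Range(1, 1000000)
    // Sort 999,999 elements and check that the result is in ascending order.
    assert(Heap.heapSort(input) == input)
    println("Done!")
  }
}
```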
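For comparison, here is a sketch of the same leftist-heap sort for newer Scala versions. The view bounds (`<%`) used above were later deprecated and are not accepted by Scala 3, so this version takes an implicit `Ordering[A]` instead. The names `LeftistHeapSort`, `Empty`, `Node`, `make` and `drain` are hypothetical and chosen only for this example; the algorithm (merge along right spines, then repeatedly pop the maximum into an accumulator) is the same as above.

```scala
import scala.annotation.tailrec

// Hypothetical object name; a sketch of the leftist-heap sort using Ordering.
object LeftistHeapSort {
  sealed abstract class Heap[+A] { def rank: Int }
  case object Empty extends Heap[Nothing] { val rank = 0 }
  // Covariant so that pattern matching on Heap[A] yields elements of type A.
  final case class Node[+A](rank: Int, elem: A, left: Heap[A], right: Heap[A]) extends Heap[A]

  // Keep the higher-ranked child on the left (leftist property).
  private def make[A](x: A, a: Heap[A], b: Heap[A]): Heap[A] =
    if (a.rank > b.rank) Node(b.rank + 1, x, a, b)
    else Node(a.rank + 1, x, b, a)

  // Merge two max-heaps by recursing down the right spines.
  def merge[A](a: Heap[A], b: Heap[A])(implicit ord: Ordering[A]): Heap[A] =
    (a, b) match {
      case (h, Empty) => h
      case (Empty, h) => h
      case (Node(_, x, lx, rx), Node(_, y, ly, ry)) =>
        if (ord.gteq(x, y)) make(x, lx, merge(rx, b))
        else make(y, ly, merge(a, ry))
    }

  // Pop the maximum repeatedly, prepending it to acc; the result is ascending.
  @tailrec
  private def drain[A](acc: List[A], h: Heap[A])(implicit ord: Ordering[A]): List[A] =
    h match {
      case Empty            => acc
      case Node(_, x, l, r) => drain(x :: acc, merge(l, r))
    }

  def heapSort[A](xs: Seq[A])(implicit ord: Ordering[A]): List[A] =
    drain(Nil, xs.foldLeft(Empty: Heap[A])((h, x) => merge(make(x, Empty, Empty), h)))

  def main(args: Array[String]): Unit = {
    println(heapSort(Seq(5, 3, 9, 1, 3)))                         // List(1, 3, 3, 5, 9)
    // A reversed Ordering turns the max-heap into a min-heap: descending output.
    println(heapSort(Seq(5, 3, 9, 1, 3))(Ordering[Int].reverse)) // List(9, 5, 3, 3, 1)
  }
}
```

Because the ordering is an ordinary implicit parameter rather than a conversion to `Ordered[A]`, the sort works for any type with an `Ordering` in implicit scope, and a different order can be passed explicitly, as the second call in `main` does.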