Sorting by Schwartzian transform in Scala

The term “Schwartzian transform” was named in honor of Randall Schwartz by Tom Christiansen. (If you recognize either name, you might be thinking, “What would Perl have to do with Scala?” I’ll simply observe that the idea of a multi-paradigm language is an old one.)

At JPR 2008, Dianne Marsh mentioned an article that included a description of the orderby keyword in LINQ. The resulting conversation explored the strategy described below.

The example

Let’s use a minimal Employee class for our discussion:

  class Employee(
      val id:    Int,
      val dept:  String,
      val first: String,
      val last:  String
  ) {
    override def toString() =
      "(" + id + ") " + first + " " + last + " in " + dept
  }

Here’s a list of sample data, used as the basis for a list of employees:

  val data = List(
      (314, "Payroll", "Ursula", "Ficus"),
      (159, "Sales", "Ville", "Eave"),
      (265, "Development", "Wilbur", "Dim"),
      (358, "Sales", "Xeris", "Confabulation"),
      (979, "Development", "Yolanda", "Borg"),
      (323, "Development", "Ziggy", "Ableton")
  )
  val employees =
    for ((id, dp, fn, ln) <- data) yield new Employee(id, dp, fn, ln)

Note that we could mix in the Ordered trait, implementing compare and equals methods, to create a standard ordering for Employee. We won’t use that approach here because we intend to sort employee lists in a variety of ways: alphabetically by name, by employee ID, or even by length of name.
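For comparison, here is a minimal sketch of the Ordered approach we are setting aside. OrderedEmployee is a hypothetical name, not part of the article’s code. I’ve made it a case class so that equals comes for free, and I’m assuming a current Scala (2.13 or 3), where sorted picks up the ordering. Notice that it bakes in exactly one ordering (by id), which is the limitation that motivates the rest of this post.

```scala
// Hypothetical variant of Employee with a single, built-in ordering by id.
case class OrderedEmployee(id: Int, dept: String, first: String, last: String)
    extends Ordered[OrderedEmployee] {
  // Negative, zero, or positive, as Ordered.compare requires.
  def compare(that: OrderedEmployee): Int = this.id.compare(that.id)
}

val staff = List(
  OrderedEmployee(314, "Payroll", "Ursula", "Ficus"),
  OrderedEmployee(159, "Sales", "Ville", "Eave"),
  OrderedEmployee(265, "Development", "Wilbur", "Dim")
)

// sorted uses the ordering supplied by the Ordered mix-in.
val staffById = staff.sorted
// staffById.map(_.id) is List(159, 265, 314)
```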

The goal

Our immediate objective is to create a function to sort a list of objects by identifying the sort criteria. The result of the discussion is the sortBy function, which allows us to sort our employees by first name and employee ID by writing as little as:

  sortBy ((e: Employee) => (e.first, e.id)) (employees)

The remainder of this post describes how we compose such a function. There’s certainly room for further enhancement, but getting this far was enough for one conversation (and more than enough for one article).

The standard sort method

Scala’s List class has a pre-defined sort method which uses a caller-provided function to determine whether a pair of elements is in order. To obtain a list sorted by a single Employee attribute, such as id, we can write the expression:

  employees.sort((a, b) => a.id < b.id)

Because sort takes only a single argument, the call can also be written in infix form…

  employees sort ((a, b) => a.id < b.id)

…which is the style I’ll use for the remainder of this article.

Scala also doesn’t require us to specify the type for arguments a and b; the compiler can infer that they have type Employee because they are coming from employees, which has type List[Employee]. (In this case, we could go further and use placeholder syntax to abbreviate the above to…

  employees sort (_.id < _.id)

…but we can’t use that abbreviation for functions that refer to the same parameter more than once.)

As our sorting criteria become more complex (e.g. sorting by length of name, with last name and then first name as tie-breakers), the in-line syntax becomes cumbersome. Instead, we can define a method that returns Boolean and pass it to sort, as in:

  def lessByNameLength(a: Employee, b: Employee): Boolean = {
    val bylength = (a.first + a.last).length - (b.first + b.last).length
    if (bylength == 0) {
      val bylastname = a.last compare b.last
      if (bylastname == 0) {
        a.first < b.first
      } else {
        bylastname < 0
      }
    } else {
      bylength < 0
    }
  }
  // ...
  val employeesByNameLength = employees sort lessByNameLength

That approach provides the per-call flexibility we want, but it has two drawbacks:

  1. it is “boilerplate-heavy”, and
  2. it re-derives the ordering data for each operand, regardless of how many times the operand has been examined previously.

With respect to point #1, the multiple-key comparison requires much more typing, hides our intention in a mound of implementation detail, and is therefore much more error-prone. (For example, if a.first < b.first had been mistyped as a.first > b.first, how long would it take to find the typo-bug?)

We’ll address the boilerplate issue after we discuss the other drawback.

Schwartzian transform

A typical, well-chosen sorting algorithm requires O(n log n) comparisons. The key idea (pun intended) of the Schwartzian transform is to:

  • compute the sort key for each value in the collection and attach it to that value,
  • sort the key/value combinations based on the pre-computed keys, and
  • extract the sorted values from the key/value combinations.

This is a nice application of the time-honored mathematical strategy of transforming a problem to a representation which makes work easier, doing the work, and then transforming the result back (f = g · h · g⁻¹). Each stage can be written and examined individually, a core idea in functional programming. Re #2 above, this approach performs the key computation exactly once per original element, during the setup stage, rather than repeating it every time a pair of elements is compared.
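The three stages can be sketched on a plain list of strings. Everything here is illustrative: expensiveKey and the keyCalls counter are hypothetical names that exist only to show how often the key function runs, and sortWith is the name current Scala collections use for the sort method discussed in this article.

```scala
// Illustrative only: count how many times the key function is evaluated.
var keyCalls = 0
def expensiveKey(s: String): Int = { keyCalls += 1; s.length }

val words = List("pear", "fig", "banana", "kiwi", "apple")

// 1. Decorate: compute each key once and attach it to its value.
val decorated = words.map(w => (expensiveKey(w), w))
// 2. Sort the pairs, looking only at the pre-computed keys.
val sortedPairs = decorated.sortWith(_._1 < _._1)
// 3. Undecorate: drop the keys and keep the values.
val sortedWords = sortedPairs.map(_._2)

// sortedWords is List(fig, pear, kiwi, apple, banana)
// keyCalls is 5: one evaluation per element, however many comparisons ran.
```

Had we instead called expensiveKey inside the comparison function, the counter would grow with the number of comparisons rather than the number of elements.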

So what do we do with the keys, and how do we use them in the sorting process?

Tuples as keys

A nice feature of Scala tuples is the implicit conversion to Ordered for a tuple whose fields have Ordered types. For example…

  val data = List(
    ("my", 9),
    ("ukulele", 8),
    ("has", 7),
    ("my", 6),
    ("dog", 5),
    ("has", 4),
    ("fleas", 3)
  )
  println(data sort (_ < _))

…produces output of…

  List((dog,5), (fleas,3), (has,4), (has,7), (my,6), (my,9), (ukulele,8))

…in which the values are ordered by the first element of each tuple, with the second element serving to break ties. Because String and Int are both Ordered, Scala can treat Tuple2[String,Int] as Ordered as well.
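If you are on a current Scala (I’m assuming 2.13 or 3 here), the same lexicographic behavior is available through the Ordering type class. The sketch below is an assumption-laden restatement of the example above, not the code this article was written against: it uses sorted and an explicitly requested Ordering instead of sort and the Ordered conversion.

```scala
val pairs = List(
  ("my", 9), ("ukulele", 8), ("has", 7), ("my", 6),
  ("dog", 5), ("has", 4), ("fleas", 3)
)

// The standard library derives an Ordering for the tuple from its fields.
val pairOrdering = Ordering[(String, Int)]

// First elements decide; second elements break ties.
val hasFourBeforeHasSeven = pairOrdering.lt(("has", 4), ("has", 7)) // true

val sortedTuples = pairs.sorted
// List((dog,5), (fleas,3), (has,4), (has,7), (my,6), (my,9), (ukulele,8))
```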

It’s easy to write a function that produces a key tuple from an instance of Employee, especially if the key is simply made up of fields! Examples of this include:

  def byName       (e: Employee) = (e.last, e.first)
  def byDeptId     (e: Employee) = (e.dept, e.id)
  def byNameLength (e: Employee) = ((e.first + e.last).length, e.last, e.first)

That idea, along with map, zip, and the sort method described earlier, provides all the raw materials we need.

Bolting it together

After reviewing the map and zip methods, we’re ready to start assembling a solution.

If Visitor is the standard illustration for OO design patterns, then map probably fills that role for functional programming. It’s the first syllable in “MapReduce”, and also makes an occasional guest appearance when closures are being discussed.

The expression

  employees map byDeptId 

applies byDeptId to each element of employees and returns the results as a List (which I’ve wrapped to fit on the page):

  List((Payroll,314), (Sales,159), (Development,265),
       (Sales,358), (Development,979), (Development,323))

Associating each key with the corresponding employee is the job of zip, which combines two sequences (as if matching the two sides of a zipper). Given two lists, aList and bList, the expression aList zip bList returns a list of pairs (instances of Tuple2), each of which contains an element of aList and the corresponding element of bList. In our case, that means that…

  employees map byDeptId zip employees

…gives us the list of key/value pairs we need to sort our employees by department and then by ID. Because we want to sort only by the key part of each pair, we need to use sort with a function that compares just the keys. The _1 method retrieves the first element of a Tuple2 instance, so we’ll write:

  employees map byDeptId zip employees sort (_._1 < _._1)

With key/value pairs in the desired order, we are ready to extract the values to form the result, which we can do by adding another map to the pipeline:

  employees map byDeptId zip employees sort (_._1 < _._1) map (_._2)

Evaluating that expression gives us a sorted list. Printed one element per line (via the toString method) it contains:

  (265) Wilbur Dim in Development
  (323) Ziggy Ableton in Development
  (979) Yolanda Borg in Development
  (314) Ursula Ficus in Payroll
  (159) Ville Eave in Sales
  (358) Xeris Confabulation in Sales
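For readers who want to run the whole pipeline in one piece, here is a self-contained sketch. It makes a few assumptions beyond the text: Emp is a hypothetical case-class stand-in for Employee, sortWith replaces sort (its name in current Scala collections), and the key comparison goes through an explicitly requested Ordering rather than the < operator on tuples.

```scala
// Hypothetical stand-in for the article's Employee class.
case class Emp(id: Int, dept: String, first: String, last: String)

val emps = List(
  Emp(314, "Payroll", "Ursula", "Ficus"),
  Emp(159, "Sales", "Ville", "Eave"),
  Emp(265, "Development", "Wilbur", "Dim"),
  Emp(358, "Sales", "Xeris", "Confabulation"),
  Emp(979, "Development", "Yolanda", "Borg"),
  Emp(323, "Development", "Ziggy", "Ableton")
)

def byDeptId(e: Emp) = (e.dept, e.id)

// Ordering for the (dept, id) keys, derived from String and Int.
val keyOrdering = Ordering[(String, Int)]

val empsByDeptId =
  emps.map(byDeptId)                               // compute the keys
    .zip(emps)                                     // pair each key with its value
    .sortWith((a, b) => keyOrdering.lt(a._1, b._1)) // compare keys only
    .map(_._2)                                     // keep the values

// empsByDeptId.map(_.id) is List(265, 323, 979, 314, 159, 358)
```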

Getting functional

Now that we can write an expression that does what we want in a specific instance, we want a general function that captures the core pattern. Looking at the last expression above, it’s clear that the key-extraction function and the list of values are the two parts that need to vary as we re-use the idea. The general template is…

  employees map ? zip ? sort (_._1 < _._1) map (_._2)

…so we just need to fill in the blanks in the following function definition:

  def sortBy ? (f: ?) (vs: ?): ? = {
    (vs map f zip vs) sort (_._1 < _._1) map (_._2)
  }

If the values (the vs parameter) are provided in a list of some type, then the result should be a list of the same type. Because that type can vary, we will use type parameters. Choosing V for our value type, we get:

  def sortBy[V, ?] (f: V => ?) (vs: List[V]): List[V] = {
    (vs map f zip vs) sort (_._1 < _._1) map (_._2)
  }

Specifying the remaining type parameter is a bit tricky; we need f to map from type V to type K (for key)…

  def sortBy[V, K ?] (f: V => K) (vs: List[V]): List[V] = {
    (vs map f zip vs) sort (_._1 < _._1) map (_._2)
  }

…however, K must be a type we can sort. As mentioned in the “Tuples as keys” section, Scala has an implicit conversion to Ordered for tuples (for Tuple2 through Tuple9, actually). That means that we must specify that K can be converted implicitly to Ordered[K], which gives us the last bit of type incantation:

  def sortBy[V, K <% Ordered[K]] (f: V => K) (vs: List[V]): List[V] = {
    (vs map f zip vs) sort (_._1 < _._1) map (_._2)
  }

While that last detail is not the sort of thing you’d likely want to write throughout your application code, you don’t have to! That last definition is the sort of thing (pun intended ;-) that you’d put in a library (utility jar, etc.) and just use throughout your application. With that function in hand, we’re now free to write…

  val sortedEmps = sortBy ((e: Employee) => (e.first, e.id)) (employees)

…or…

  def byNameLength (e: Employee) = ((e.first + e.last).length, e.last, e.first)
  // ...
  val sortedEmps = sortBy (byNameLength)(employees)

…or even…

  def byNameLength (e: Employee) = ((e.first + e.last).length, e.last, e.first)
  def sortEmployeesByNameLength = sortBy (byNameLength) _
  // ...
  val sortedEmps = sortEmployeesByNameLength(employees)

…if you’ve read the post on currying, that is.
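As a closing aside, here is how I would sketch the same function against a current Scala (2.13 or 3 is my assumption). The names sortByKey and Person are hypothetical, chosen to avoid clashing with anything above. The view bound on K becomes a context bound on Ordering, and sortWith stands in for sort.

```scala
// Same decorate / sort / undecorate shape, with K constrained by Ordering.
def sortByKey[V, K: Ordering](f: V => K)(vs: List[V]): List[V] = {
  val ord = implicitly[Ordering[K]]
  vs.map(f).zip(vs).sortWith((a, b) => ord.lt(a._1, b._1)).map(_._2)
}

// Hypothetical sample type and data for the usage example.
case class Person(first: String, last: String, id: Int)

val people = List(
  Person("Ville", "Eave", 159),
  Person("Ursula", "Ficus", 314),
  Person("Ville", "Dim", 101)
)

val byFirstThenId = sortByKey((p: Person) => (p.first, p.id))(people)
// byFirstThenId.map(_.id) is List(314, 101, 159)

// Partial application still works, as in the curried examples above.
val sortPeopleByLast = sortByKey((p: Person) => p.last) _

// The standard collections also offer a sortBy method, which yields the same order here.
val viaLibrary = people.sortBy(p => (p.first, p.id))
```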
