Sorting by Schwartzian transform in Scala

The term “Schwartzian transform” was named in honor of Randal Schwartz by Tom Christiansen. (If you recognize either name, you might be thinking, “What would Perl have to do with Scala?” I’ll simply observe that the idea of a multi-paradigm language is an old one.)

At JPR 2008, Dianne Marsh mentioned an article that included a description of the `orderby` keyword in LINQ. The resulting conversation explored the strategy described below.

The example

Let’s use a minimal `Employee` class for our discussion:

```
class Employee(
  val id:    Int,
  val dept:  String,
  val first: String,
  val last:  String
) {
  override def toString() =
    "(" + id + ") " + first + " " + last + " in " + dept
}
```

Here’s a list of sample data, used as the basis for a list of employees:

```
val data = List(
  (314, "Payroll", "Ursula", "Ficus"),
  (159, "Sales", "Ville", "Eave"),
  (265, "Development", "Wilbur", "Dim"),
  (358, "Sales", "Xeris", "Confabulation"),
  (979, "Development", "Yolanda", "Borg"),
  (323, "Development", "Ziggy", "Ableton")
)
```

```
val employees =
  for ((id, dp, fn, ln) <- data) yield new Employee(id, dp, fn, ln)
```

Note that we could mix in the `Ordered` trait, implementing `compare` and `equals` methods, to create a standard ordering for `Employee`. We won’t use that approach here because we intend to sort employee lists in a variety of ways: alphabetically by name, by employee ID, or even by length of name.
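For comparison, here is a minimal sketch of that road not taken. The class name `OrderedEmployee` and the choice of `id` as the one built-in ordering are my own assumptions for illustration, and the code assumes a current Scala release, where `sorted` picks up the ordering defined by `compare`:

```scala
// Hypothetical variant of Employee with a single built-in ordering (by id).
class OrderedEmployee(val id: Int, val dept: String, val first: String, val last: String)
    extends Ordered[OrderedEmployee] {

  // Negative, zero, or positive, as Ordered requires.
  def compare(that: OrderedEmployee): Int = this.id.compare(that.id)

  // Keep equals (and hashCode) consistent with compare.
  override def equals(other: Any): Boolean = other match {
    case that: OrderedEmployee => this.id == that.id
    case _                     => false
  }
  override def hashCode: Int = id.hashCode
}

val staff = List(
  new OrderedEmployee(314, "Payroll", "Ursula", "Ficus"),
  new OrderedEmployee(159, "Sales", "Ville", "Eave"),
  new OrderedEmployee(265, "Development", "Wilbur", "Dim")
)

// sorted uses the ordering supplied by compare.
val byId = staff.sorted
```

This works, but it bakes exactly one ordering into the class, which is the limitation we want to avoid.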

The goal

Our immediate objective is to create a function to sort a list of objects by identifying the sort criteria. The result of the discussion is the `sortBy` function, which allows us to sort our employees by first name and employee ID by writing as little as:

```
sortBy ((e: Employee) => (e.first, e.id)) (employees)
```

The remainder of this post describes how we compose such a function. There’s certainly room for further enhancement, but getting this far was enough for one conversation (and more than enough for one article).

The standard `sort` method

Scala’s `List` class has a pre-defined `sort` method which uses a caller-provided function to determine whether a pair of elements is in order. To obtain a list sorted by a single `Employee` attribute, such as `id`, we can write the expression:

```
employees.sort((a, b) => a.id < b.id)
```

Because `sort` takes only a single argument, the call can also be written in infix (operator) notation…

```
employees sort ((a, b) => a.id < b.id)
```

…which is the style I’ll use for the remainder of this article.

Scala also doesn’t require us to specify the type for arguments `a` and `b`; the compiler can infer that they have type `Employee` because they are coming from `employees`, which has type `List[Employee]`. (In this case, we could go further and use placeholder syntax to abbreviate the above to…

```
employees sort (_.id < _.id)
```

…but we can’t use that abbreviation for functions that refer to the same parameter more than once.)

As our sorting criteria become more complex (e.g., sorting by length of name, with last name and then first name as tie-breakers), the in-line syntax becomes cumbersome. Instead, we can define and use a method which returns `Boolean`, as in:

```
def lessByNameLength(a: Employee, b: Employee): Boolean = {
  val bylength = (a.first + a.last).length - (b.first + b.last).length
  if (bylength == 0) {
    val bylastname = a.last compare b.last
    if (bylastname == 0) {
      a.first < b.first
    } else {
      bylastname < 0
    }
  } else {
    bylength < 0
  }
}
// ...
val employeesByNameLength = employees sort lessByNameLength
```

That approach provides the per-call flexibility we want, but it has two drawbacks:

1. it is “boilerplate-heavy”, and
2. it re-derives the ordering data for each operand, regardless of how many times the operand has been examined previously.

With respect to point #1, the multiple-key comparison requires much more typing, hides our intention in a mound of implementation detail, and is therefore much more error-prone. (For example, if `a.first < b.first` had been mistyped as `a.first > b.first`, how long would it take to find the typo-bug?)

We’ll address the boilerplate issue after we discuss the other drawback.

Schwartzian transform

A typical, well-chosen sorting algorithm requires O(n log n) comparisons. The key idea (pun intended) of the Schwartzian transform is to:

• compute the sort key for each value in the collection and attach it to that value,
• sort the key/value combinations based on the pre-computed keys, and
• extract the sorted values from the key/value combinations.

This is a nice application of the time-honored mathematical strategy of transforming a problem to a representation which makes work easier, doing the work, and then transforming the result back (f = g · h · g⁻¹). Each stage can be written and examined individually, a core idea in functional programming. Regarding drawback #2 above, this approach performs the key computation exactly once per original element, during the setup stage, rather than repeating it every time a pair of elements is compared.
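To make the difference concrete, here is a small sketch that counts key computations under both approaches. It assumes a current Scala release (where `sortWith` plays the role of the `sort` method used in this article); the helper `countingKeyCalls` and the sample words are made up for the demonstration:

```scala
val words = List("ukulele", "my", "fleas", "dog", "has")

// Hypothetical helper: runs `body` with a key function (string length)
// that counts its own invocations, and returns the result with the count.
def countingKeyCalls[A](body: (String => Int) => A): (A, Int) = {
  var calls = 0
  val key = (s: String) => { calls += 1; s.length }
  val result = body(key)
  (result, calls)
}

// Comparator approach: the key is re-derived inside every comparison.
val (direct, directCalls) = countingKeyCalls { key =>
  words.sortWith((a, b) => key(a) < key(b))
}

// Schwartzian transform: decorate, sort on the cached key, undecorate.
val (transformed, transformCalls) = countingKeyCalls { key =>
  words.map(w => (key(w), w))          // compute each key exactly once
    .sortWith((x, y) => x._1 < y._1)   // compare only the cached keys
    .map(_._2)                         // drop the keys again
}
```

Both lists come out in the same order, but `transformCalls` equals the number of words, while `directCalls` grows with the number of comparisons the sort performs.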

So what do we do with the keys, and how do we use them in the sorting process?

Tuples as keys

A nice feature of Scala tuples is the implicit conversion to `Ordered` for a tuple whose fields have `Ordered` types. For example…

```
val data = List(
  ("my", 9),
  ("ukulele", 8),
  ("has", 7),
  ("my", 6),
  ("dog", 5),
  ("has", 4),
  ("fleas", 3)
)
println(data sort (_ < _))
```

…produces output of…

```
List((dog,5), (fleas,3), (has,4), (has,7), (my,6), (my,9), (ukulele,8))
```

…in which the values are ordered by the first element of each tuple, with the second element serving to break ties. Because `String` and `Int` can both be viewed as `Ordered` (via implicit conversions), Scala can treat `Tuple2[String,Int]` as `Ordered` as well.
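In more recent Scala releases the same idea is usually expressed through `Ordering` rather than `Ordered`. The following sketch assumes such a release; `pairs`, `viaOrdering`, and `viaLessThan` are names I chose for the example:

```scala
// Brings <, <=, >, >= into scope for any type that has an implicit Ordering.
import scala.math.Ordering.Implicits._

val pairs = List(
  ("my", 9), ("ukulele", 8), ("has", 7), ("my", 6),
  ("dog", 5), ("has", 4), ("fleas", 3)
)

// Uses the implicit Ordering[(String, Int)] derived from the element orderings.
val viaOrdering = pairs.sorted

// Same result with an explicit less-than test on the tuples.
val viaLessThan = pairs.sortWith(_ < _)
```

Either way, the tuples are compared element by element, left to right, just as in the output shown above.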

It’s easy to write a function that produces a key tuple from an instance of `Employee`, especially if the key is simply made up of fields! Examples of this include:

```
def byName       (e: Employee) = (e.last, e.first)
def byDeptId     (e: Employee) = (e.dept, e.id)
def byNameLength (e: Employee) = ((e.first + e.last).length, e.last, e.first)
```

That idea, along with `map`, `zip`, and the `sort` method described earlier, provides all the raw materials we need.

Bolting it together

After reviewing the `map` and `zip` methods, we’re ready to start assembling a solution.

If `Visitor` is the standard illustration for OO design patterns, then `map` probably fills that role for functional programming. It’s the first syllable in “mapReduce”, and also makes an occasional guest appearance when closures are being discussed.

The expression

```
employees map byDeptId
```

applies `byDeptId` to each element of `employees` and returns the results as a `List` (which I’ve wrapped to fit on the page):

```
List((Payroll,314), (Sales,159), (Development,265),
     (Sales,358), (Development,979), (Development,323))
```

Associating each key with the corresponding employee is the job of `zip`, which combines two sequences (as if matching the two sides of a zipper). Given two lists, `aList` and `bList`, the expression `aList zip bList` returns a list of pairs (instances of `Tuple2`), each of which contains an element of `aList` and the corresponding element of `bList`. In our case, that means that…

```
employees map byDeptId zip employees
```

…gives us the list of key/value pairs we need to sort our employees by department and then by ID. Because we want to sort only by the key part of each pair, we need to use `sort` with a function that compares just the keys. The `_1` method retrieves the first element of a `Tuple2` instance, so we’ll write:

```
employees map byDeptId zip employees sort (_._1 < _._1)
```

With key/value pairs in the desired order, we are ready to extract the values to form the result, which we can do by adding another map to the pipeline:

```
employees map byDeptId zip employees sort (_._1 < _._1) map (_._2)
```

Evaluating that expression gives us a sorted list. Printed one element per line (via the `toString` method) it contains:

```
(265) Wilbur Dim in Development
(323) Ziggy Ableton in Development
(979) Yolanda Borg in Development
(314) Ursula Ficus in Payroll
(159) Ville Eave in Sales
(358) Xeris Confabulation in Sales
```
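If you want to try the pipeline on a current Scala release, the following sketch is one way to do it. It assumes `sortWith` in place of the older `sort` method, and it imports `Ordering.Implicits` so that `<` is available on the tuple keys; the stage names `keyed`, `ordered`, and `result` are mine:

```scala
// Makes < available on tuples (and anything else with an implicit Ordering).
import scala.math.Ordering.Implicits._

class Employee(val id: Int, val dept: String, val first: String, val last: String) {
  override def toString = s"($id) $first $last in $dept"
}

val employees = List(
  (314, "Payroll", "Ursula", "Ficus"),
  (159, "Sales", "Ville", "Eave"),
  (265, "Development", "Wilbur", "Dim"),
  (358, "Sales", "Xeris", "Confabulation"),
  (979, "Development", "Yolanda", "Borg"),
  (323, "Development", "Ziggy", "Ableton")
).map { case (id, dp, fn, ln) => new Employee(id, dp, fn, ln) }

def byDeptId(e: Employee) = (e.dept, e.id)

val keyed   = employees.map(byDeptId).zip(employees) // attach a key to each employee
val ordered = keyed.sortWith(_._1 < _._1)            // compare the keys only
val result  = ordered.map(_._2)                      // keep just the employees
```

Naming the intermediate stages is optional, but it makes each step of the transform easy to inspect on its own.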

Getting functional

Now that we can write an expression that does what we want in a specific instance, we want a general function that captures the core pattern. Looking at the last expression above, it’s clear that the key-extraction function and the list of values are the two parts that need to vary as we re-use the idea. The general template is…

```
employees map ? zip ? sort (_._1 < _._1) map (_._2)
```

…so we just need to fill in the blanks in the following function definition:

```
def sortBy ? (f: ?) (vs: ?): ? = {
  (vs map f zip vs) sort (_._1 < _._1) map (_._2)
}
```

If the values (the `vs` parameter) are provided in a list of some type, then the result should be a list of the same type. Because that type can vary, we will use type parameters. Choosing `V` for our value type, we get:

```
def sortBy[V, ?] (f: V => ?) (vs: List[V]): List[V] = {
  (vs map f zip vs) sort (_._1 < _._1) map (_._2)
}
```

Specifying the remaining type parameter is a bit tricky; we need `f` to map from type `V` to type `K` (for key)…

```
def sortBy[V, K ?] (f: V => K) (vs: List[V]): List[V] = {
  (vs map f zip vs) sort (_._1 < _._1) map (_._2)
}
```

…however, `K` must be a type we can sort. As mentioned in the “Tuples as keys” section, Scala has an implicit conversion to `Ordered` for tuples (for `Tuple2` through `Tuple9`, actually). That means that we must specify that `K` can be converted implicitly to `Ordered[K]`, which gives us the last bit of type incantation:

```
def sortBy[V, K <% Ordered[K]] (f: V => K) (vs: List[V]): List[V] = {
  (vs map f zip vs) sort (_._1 < _._1) map (_._2)
}
```

While that last detail is not the sort of thing you’d likely want to write throughout your application code, you don’t have to! That last definition is the sort of thing (pun intended 😉) that you’d put in a library (utility jar, etc.) and just use throughout your application. With that function in hand, we’re now free to write…

```
val sortedEmps = sortBy ((e: Employee) => (e.first, e.id)) (employees)
```

…or…

```
def byNameLength (e: Employee) = ((e.first + e.last).length, e.last, e.first)
// ...
val sortedEmps = sortBy (byNameLength)(employees)
```

…or even…

```
def byNameLength (e: Employee) = ((e.first + e.last).length, e.last, e.first)
def sortEmployeesByNameLength = sortBy (byNameLength) _
// ...
val sortedEmps = sortEmployeesByNameLength(employees)
```

…if you’ve read the post on currying, that is.
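For readers on a current Scala release, where view bounds (`<%`) have fallen out of favor, here is a hedged sketch of how the same function might look with an `Ordering` context bound. The name `sortByKey` is hypothetical, chosen to avoid confusion with the `sortBy` method that the standard collections now provide:

```scala
// Hypothetical modern counterpart of the article's sortBy:
// K needs an implicit Ordering[K] instead of a view to Ordered[K].
def sortByKey[V, K: Ordering](f: V => K)(vs: List[V]): List[V] = {
  val ord = implicitly[Ordering[K]]
  vs.map(v => (f(v), v))                     // decorate: compute each key once
    .sortWith((a, b) => ord.lt(a._1, b._1))  // sort on the cached keys only
    .map(_._2)                               // undecorate
}

class Employee(val id: Int, val dept: String, val first: String, val last: String)

val employees = List(
  new Employee(314, "Payroll", "Ursula", "Ficus"),
  new Employee(159, "Sales", "Ville", "Eave"),
  new Employee(265, "Development", "Wilbur", "Dim"),
  new Employee(358, "Sales", "Xeris", "Confabulation"),
  new Employee(979, "Development", "Yolanda", "Borg"),
  new Employee(323, "Development", "Ziggy", "Ableton")
)

// Key function written as a function value so it can be passed directly.
val byNameLength = (e: Employee) => ((e.first + e.last).length, e.last, e.first)

val byFirstAndId = sortByKey((e: Employee) => (e.first, e.id))(employees)
val byLength     = sortByKey(byNameLength)(employees)
```

The two parameter lists are kept separate, as in the original, so the key function can still be supplied first and the list later.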
