Stochastic Nonsense

Put something smart here.

Java8 Improvements

Java 8 has a bunch of nice improvements, and over the holidays I’ve had time to play with them a bit.

First, say goodbye to requiring Apache Commons for really simple functionality, like joining a string!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import static java.util.stream.Collectors.joining;

import java.util.Arrays;
import java.util.List;

/**
 * Small demo of joining the elements of a list into a single string, first by
 * rendering an array with {@code Arrays.toString} and then with the Java 8
 * {@code Collectors.joining} collector.
 */
public class StringUtils {

  /** Prints {@code s} to standard output followed by a line break. */
  public static void puts(String s){
    System.out.println(s);
  }

  /** Formats {@code args} into {@code format} with {@link String#format} and prints the result. */
  public static void puts(String format, Object... args){
    String line = String.format(format, args);
    puts(line);
  }

  /** Prints the same word list twice, once per joining technique. */
  public static void main(String[] args){
    List<String> words = Arrays.asList("a", "b", "a", "a", "b", "c", "a1", "a1", "a1");

    // pre-Java-8 approach: copy the list into an array and let Arrays.toString render it
    String viaArray = Arrays.toString(words.toArray());
    puts("java6 style %s", viaArray);

    // Java 8 approach: stream the elements and join them with a separator;
    // the surrounding brackets come from the format string
    String viaStream = words.stream().collect(joining(", "));
    puts("java8 style [%s]", viaStream);
  }
}

Java 8 also massively cleans up some common operations. A common interview question is: given an array or list of words, print them in descending order by count, or return the top n sorted by count descending. A standard program to do this may go like this: create a map from string to count; reverse the map to go from count to the list of words with that count; then descend through the counts to the correct depth.

The dummy data provided has these counts:

1
2
 a a1  b  c
 3  3  2  1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * Get the n highest frequency words from an array of words.
 */
public class WordCounts {

  /**
   * Demo driver: prints the most frequent words of a fixed sample for depths 0 through 3.
   */
  public static void main(String[] args){
    String[] words = new String[]{"a", "b", "a", "a", "b", "c", "a1", "a1", "a1"};

    for(int depth = 0; depth < 4; depth++){
      List<String> result = getMostFrequentWords(words, depth);
      puts("depth %d -> %s", depth, Arrays.toString(result.toArray()));
      puts("");
    }
  }

  /**
   * Returns at most {@code depth} distinct words, ordered by descending number of
   * occurrences in {@code words}.
   *
   * <p>Each distinct word appears at most once in the result. When {@code depth}
   * is larger than the number of distinct words, every distinct word is returned
   * and the list is simply shorter than {@code depth}. The relative order of words
   * that share the same count follows the iteration order of the underlying
   * {@link HashMap} and is therefore unspecified.
   *
   * @param words the words to count; may be {@code null} or empty
   * @param depth the maximum number of words to return
   * @return the most frequent words, or an empty list when {@code words} is
   *         {@code null} or empty or {@code depth} is not positive
   */
  public static List<String> getMostFrequentWords(String[] words, int depth){
    if(words == null || words.length == 0 || depth <= 0)
      return Collections.emptyList();

    // word -> counts
    HashMap<String, Integer> counts = new HashMap<>();
    for(String word : words)
      counts.merge(word, 1, Integer::sum);

    // count -> list of words with that count
    TreeMap<Integer, ArrayList<String>> countmap = new TreeMap<>();
    for(Map.Entry<String, Integer> entry : counts.entrySet())
      countmap.computeIfAbsent(entry.getValue(), k -> new ArrayList<>()).add(entry.getKey());

    // Walk the counts from highest to lowest exactly once. A single pass is
    // essential: walking the map again after it has been exhausted would append
    // words that are already in the result.
    ArrayList<String> result = new ArrayList<>();
    for(Integer count : countmap.descendingKeySet()){
      for(String word : countmap.get(count)){
        result.add(word);
        if(result.size() == depth)
          return result;
      }
    }
    return result;
  }

  /** Prints {@code s} to standard output followed by a line break. */
  public static void puts(String s){ System.out.println(s); }

  /** Formats {@code args} into {@code format} with {@link String#format} and prints the result. */
  public static void puts(String format, Object... args){ puts(String.format(format, args)); }
}

this will produce output like:

1
2
3
4
5
6
7
depth 0 -> []

depth 1 -> [a1]

depth 2 -> [a1, a]

depth 3 -> [a1, a, b]

Using Java 8 streams, we can clean up much of this. For starters, creating the map from word -> word count is essentially built in.

1
2
3
  // tally occurrences: each distinct word maps to the number of times it appears
  Map<String, Long> counts =
    Arrays.stream(words).collect(
      Collectors.groupingBy(word -> word, Collectors.counting()));

Java8 also directly supports inverting or reversing a map, replacing the need to either do it by hand or use guava’s bi-directional map. In the common case, where values are unique, this will suffice:

1
2
3
4
  // invert word -> count into count -> word; this form only works when every count is distinct
  Map<Long, String> countmap = counts.entrySet()
    .stream()
    .collect(Collectors.toMap(entry -> entry.getValue(), entry -> entry.getKey()));
  puts("countmap: %s", countmap);

Unfortunately, in my case that throws an exception because there is more than one word with the same count. So it’s slightly more complicated:

1
2
3
  // count -> list of words: reverse a map with duplicate values, collecting duplicates in an ArrayList
  // groupingBy keys each entry by its count (the entry's value); the downstream
  // mapping collector keeps only the word (the entry's key) and gathers the words
  // that share a count into one ArrayList
  Map<Long, ArrayList<String>> countmap = counts.entrySet().stream()
  .collect(Collectors.groupingBy(Map.Entry<String, Long>::getValue, Collectors.mapping(Map.Entry<String, Long>::getKey, Collectors.toCollection(ArrayList::new))));

But I really want a treemap, so I can iterate over the keys in order. Fortunately, I can specify which type of map I want:

1
2
  // count -> list of words, as above, but the extra TreeMap::new argument makes
  // groupingBy build the result in a TreeMap instead of its default map type
  TreeMap<Long, ArrayList<String>> countmap = counts.entrySet().stream()
              .collect(Collectors.groupingBy(Map.Entry<String, Long>::getValue, TreeMap::new, Collectors.mapping(Map.Entry<String, Long>::getKey, Collectors.toCollection(ArrayList::new))));

It’s worth noting that the Python version is simpler still…

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from collections import defaultdict

def get_most_frequent_words(words, depth):
  """Return at most `depth` distinct words, most frequent first.

  Args:
    words: sequence of words to count; may be None or empty.
    depth: maximum number of words to return.

  Returns:
    A list of distinct words ordered by descending occurrence count. The list
    is empty when `words` is None or empty or `depth` is not positive, and it
    is shorter than `depth` when there are fewer distinct words than `depth`.
    The relative order of words that share a count follows dict iteration
    order.
  """
  if words is None or len(words) == 0 or depth <= 0:
    return []

  # word -> number of occurrences
  counts = defaultdict(int)
  for word in words:
    counts[word] += 1

  # count -> list of words with that count
  countmap = defaultdict(list)
  # items() is available on both Python 2 and Python 3; iteritems() is Python 2 only
  for word, count in counts.items():
    countmap[count].append(word)

  # walk the counts from highest to lowest, stopping once `depth` words are collected
  result = []
  for key in sorted(countmap.keys(), reverse=True):
    for w in countmap[key]:
      result.append(w)
      if len(result) == depth:
        return result

  return result


# demo: show the top words of a small sample for depths 0 through 3
words = ["a", "b", "a", "a", "b", "c", "a1", "a1", "a1"]
for depth in range(4):
  top_words = get_most_frequent_words(words, depth)
  joined = ', '.join(top_words)
  print('depth %d -> [%s]' % (depth, joined))
  # prints a line break of its own, leaving blank space between depths
  print('\n')