Java Lambda从入门到精通二十 自定义收集器Collectors

我们在前面的章节,将流中的数据经过处理后,收集最终结果数据的时候,一般都是用stream的collect()方法。

collect()方法中我们之前用的最多的就是toList()方法,还可以使用如下的映射方法:

  • toCollection(Supplier collectionFactory)
  • toConcurrentMap(Function<? super T,? extends K> keyMapper, Function<? super T,? extends U> valueMapper)
  • toList()
  • toMap(Function<? super T,? extends K> keyMapper, Function<? super T,? extends U> valueMapper)
  • toSet()
  • mapping(Function<? super T,? extends U> mapper, Collector<? super U,A,R> downstream)

这是Java API为我们定义好的方法,如果这些方法不满足,或者收集器的逻辑特别复杂,需要提升性能,那么可以自定义收集器。

要实现自定义收集器,需要实现Collector接口。

我们看一下Collector接口需要实现的方法:

  • 创建新的结果容器( supplier() )
  • 将新的数据元素并入结果容器( accumulator() )
  • 将两个结果容器组合成一个( combiner() )
  • 在容器上执行可选的最终变换( finisher() )

之所以要实现这些方法,是因为收集器内部,是依次调用这些方法,将数据处理并收集起来的。收集器先使用supplier()方法,创建一个存放结果的容器;再使用accumulator()去依次处理流中元素,这个操作其实就是规约,并将结果存放到上一步创建的容器中;然后再使用finisher() 方法将前两步的结果转换为最终要输出的结果类型;combiner()只是定义了并发处理时的逻辑,即并发如何合并各个子集数据的逻辑。

还有一个characteristics()方法,定义了收集器的行为,包含三种行为:UNORDERED-规约不受顺序影响,CONCURRENT-支持并发,IDENTITY_FINISH-表示累加器方法的结果能否直接作为规约结果。

我们来看一个例子,还是用上一节学生成绩分区统计优秀成绩的例子。

回顾上一节代码:

List<Student> list = new ArrayList<>();
String[] names = {"小明","小华","小志","小东","小李","小张","小王","小周","小吴","小郑"};
for(int i=0; i<10; i++) {
	Student student = Student.builder()
			.name(names[i])
			.age(12+i/5)
			.classId(i%3+1)
			.no(i+1)
			.math(85d+i).build();
	list.add(student);
}

//分区,数学成绩大于等于90则为优秀
Map<Boolean, List<Student>> partitioningBy = list.stream().collect(partitioningBy(x -> x.getMath() >= 90));
String collect3 = partitioningBy.get(true).stream().map(Student::getName).collect(Collectors.joining(","));
System.out.println("数学成绩优秀的学生:"+collect3);
String collect4 = partitioningBy.get(false).stream().map(Student::getName).collect(Collectors.joining(","));
System.out.println("数学成绩非优秀的学生:"+collect4);

输出结果:

数学成绩优秀的学生:小张,小王,小周,小吴,小郑
数学成绩非优秀的学生:小明,小华,小志,小东,小李

上面就是分区求学生数学成绩大于90和小于90的两组人数。

下面我们实现一个自定义的收集器,代码:

import java.util.*;
import java.util.function.*;
import java.util.stream.Collector;

public class MapToStudentCollector implements Collector<Student, HashMap<Boolean, List<Student>>,HashMap<Boolean, List<Student>>> {
    @Override
    public Supplier<HashMap<Boolean, List<Student>>> supplier() {
        return ()->new HashMap<Boolean, List<Student>>() {{
            put(true, new ArrayList<>());
            put(false, new ArrayList<>());
        }};
    }

    @Override
    public BiConsumer<HashMap<Boolean, List<Student>>, Student> accumulator() {
        return (map,t)->{
            if (t.getMath() >= 90) {
                map.get(true).add(t);
            } else {
                map.get(false).add(t);
            }
        };
    }

    @Override
    public BinaryOperator<HashMap<Boolean, List<Student>>> combiner() {
        BinaryOperator<HashMap<Boolean, List<Student>>> combiner =
                (HashMap<Boolean, List<Student>> m1,
                 HashMap<Boolean, List<Student>> m2) -> {
                    m1.get(true).addAll(m2.get(true));
                    m1.get(false).addAll(m2.get(false));
            return m1;
        };
        return combiner;
    }

    @Override
    public Function<HashMap<Boolean, List<Student>>, HashMap<Boolean, List<Student>>> finisher() {
        return Function.identity();
    }

    @Override
    public Set<Characteristics> characteristics() {
        return Collections.unmodifiableSet(EnumSet.of(Characteristics.IDENTITY_FINISH));
    }
}

测试代码:

List<Student> list = new ArrayList<>();
String[] names = {"小明","小华","小志","小东","小李","小张","小王","小周","小吴","小郑"};
for(int i=0; i<10; i++) {
	Student student = Student.builder()
			.name(names[i])
			.age(12+i/5)
			.classId(i%3+1)
			.no(i+1)
			.math(85d+i).build();
	list.add(student);
}


HashMap<Boolean, List<Student>> mapToStudentCollector = list.stream().collect(new MapToStudentCollector());
String collector1 = mapToStudentCollector.get(true).stream().map(Student::getName).collect(Collectors.joining(","));
System.out.println("自定义收集器-数学成绩优秀的学生:"+collector1);
String collector2 = mapToStudentCollector.get(false).stream().map(Student::getName).collect(Collectors.joining(","));
System.out.println("自定义收集器-数学成绩非优秀的学生:"+collector2);

输出结果:

自定义收集器-数学成绩优秀的学生:小张,小王,小周,小吴,小郑
自定义收集器-数学成绩非优秀的学生:小明,小华,小志,小东,小李

上面就是自定义收集器的一个demo,在stream中还定义了一个collect并附带三个参数的方法,能够直接实现相当于自定义收集器的效果。

collect(Supplier<R> supplier, BiConsumer<R,? super T> accumulator, BiConsumer<R,R> combiner)

我们上面写的自定义收集器等同于下面的代码:

HashMap<Boolean, List<Student>> collect = list.stream().collect(() ->
				new HashMap<Boolean, List<Student>>() {{
					put(true, new ArrayList<>());
					put(false, new ArrayList<>());
				}}
		,
		(a, b) -> {
			if (b.getMath() >= 90) {
				a.get(true).add(b);
			} else {
				a.get(false).add(b);
			}
		},
		(m1, m2) -> {
			m1.get(true).addAll(m2.get(true));
			m1.get(false).addAll(m2.get(false));
		}
);
String collect1 = collect.get(true).stream().map(Student::getName).collect(Collectors.joining(","));
System.out.println("数学成绩优秀的学生:"+collect1);
String collect2 = collect.get(false).stream().map(Student::getName).collect(Collectors.joining(","));
System.out.println("数学成绩非优秀的学生:"+collect2);

输出结果:

数学成绩优秀的学生:小张,小王,小周,小吴,小郑
数学成绩非优秀的学生:小明,小华,小志,小东,小李