4 Values whose Sum is 0

Условие

Дано четыре списка чисел A, B, C, D длины N. Найти количество четверок (a, b, c, d) из A x B x C x D, где a + b + c + d = 0.

Ограничения

Размер списков не превосходит 4000, элементы массива лежат в диапазоне от -228 до 228. Ограничение на время работы – 5 секунд, на память – 64 Мб.

Пример входного файла Пример выходного файла
6
-45 22 42 -16
-41 -27 56 30
-36 53 -37 77
-36 30 -75 -46
26 -38 -10 62
-32 -54 -6 45
5

Источник: Southwestern European 2005.

Решение

Анализ

Очевидное решение за O(N^4) не подходит. Можно сделать оптимизацию – перебирать тройку, а оставшийся элемент искать в хэш-таблице (O(N ^ 3)). Аналогично можно перебирать пару, а оставшуюся пару искать в хэш-таблице O(N^2). Но даже квадратичное решение с использованием хэш-таблиц не укладывается в указанные ограничения. Причем основной проблемой становится большая скрытая константа и большой объем памяти, занимаемый таблицей (4000 * 4000 * 4Б = 64МБ).

Очевидно, что от использования хэш-таблиц следует отказаться и подойти к задаче с другой стороны. Легко заметить, что для двух отсортированных массивов (даже очень больших) можно решить данную задачу за O(L), где L длина массивов. Но возникает проблема – отсортировать два массива состоящих из 4000 * 4000 = 16 млн. элементов. В данной задачи ее можно обойти.

Решение

Следует отметить, что наша задача более узкая – большой массив формируется из сумм всевозможных пар из двух небольших массивов, поэтому вместо явного создания массива и последующей сортировки можно формировать элементы в порядке возрастания. Сделать это позволяет использование структуры данных, позволяющей получать минимальный элемент, и обновлять элемент с заданным индексом. Эти операции могут выполнять многие структуры данных за приемлемое время. В данной задаче удобно использовать дерево отрезков (RMQ), выполняющее обновление за O(logN) и получение глобального минимума за O(1).

Теперь на базе RMQ можно создать структуру данных, которая по двум исходным отсортированным массивам A, B сможет перебрать суммы всевозможных пар элементов в порядке возрастания.

Пусть каждому элементу A поставлен в соответствие индекс в массиве B, который указывает на минимальный элемент в B, пара с которым элемента из A еще не рассматривалась. Увеличение индекса позволяет перейти к следующей по возрастанию паре из B с данным элементом из A. Пусть в RMQ изначально хранятся элементы A[i] + B[0], а значения всех индексов равны 1 (при индексации с 0). Теперь минимум в RMQ – сумма текущей пары элементов. Получив в RMQ номер элемента, нужно увеличить его индекс и прибавить к этому элементу разность между B[новый индекс] и B[старый индекс]. Таким образом, перебор будет осуществляться в порядке возрастания сумм и переход требует O(logN) времени.

Теперь следует применить полученную структуру к массивам A, B и массивам {-C[0], -C[1], …, -C[N - 1]}, {-D[0], -D[1], …, -D[N - 1]}. Дальнейшее решение очевидно. Если текущая сумма AB меньше суммы CD, совершается переход к следующей сумме в AB, если больше – в CD. Если достигнут конец хотя бы одного массива, алгоритм заканчивает работу. Иначе, пока суммы равны, необходимо подсчитать их количество в AB и CD, и прибавить произведение полученных значений к ответу. Временная сложность алгоритма O(N^2 log N).

Реализация

Приведем решение на Java:

0000: import java.io.*;
0001: import java.util.*;
0002:
0003: public class Main {
0004:     public static void main(String[] args) throws IOException {
0005:         new Main().run();
0006:     }
0007:
0008:     BufferedReader in;
0009:     PrintWriter out;
0010:     StringTokenizer st = new StringTokenizer("");
0011:     
0012:     int INF = Integer.MAX_VALUE / 2;
0013:
0014:     int N;
0015:     Enumerator e1;
0016:     Enumerator e2;
0017:     
0018:     void run() throws IOException {
0019:         in = new BufferedReader(new InputStreamReader(System.in));
0020:         out = new PrintWriter(System.out);
0021:         
0022:         N = Integer.parseInt(in.readLine());
0023:         int[] a = new int [N];
0024:         int[] b = new int [N];
0025:         int[] c = new int [N];
0026:         int[] d = new int [N];
0027:         for (int i = 0; i < N; i++) {
0028:             StringTokenizer tok = new StringTokenizer(in.readLine());
0029:             a[i] = Integer.parseInt(tok.nextToken());
0030:             b[i] = Integer.parseInt(tok.nextToken());
0031:             c[i] = -Integer.parseInt(tok.nextToken());
0032:             d[i] = -Integer.parseInt(tok.nextToken());
0033:         }
0034:         e1 = new Enumerator(a, b);
0035:         e2 = new Enumerator(c, d);
0036:         
0037:         long ans = 0L;
0038:         while (true) {
0039:             int cur1 = e1.rmq.val[1]; if (cur1 == INF) break;
0040:             int cur2 = e2.rmq.val[1]; if (cur2 == INF) break;
0041:             if (cur1 < cur2) {
0042:                 e1.next();
0043:             } else if (cur1 != cur2) {
0044:                 e2.next();
0045:             } else {
0046:                 int cnt1 = 0;
0047:                 int cnt2 = 0;
0048:                 while (e1.rmq.val[1] == cur1) {
0049:                     cnt1++;
0050:                     e1.next();
0051:                 }
0052:                 while (e2.rmq.val[1] == cur2) {
0053:                     cnt2++;
0054:                     e2.next();
0055:                 }
0056:                 ans += cnt1 * (long) cnt2;
0057:             }
0058:         }
0059:         
0060:         out.println(ans);
0061:         out.close();
0062:     }
0063:     
0064:     class Enumerator {
0065:         int[] a;
0066:         int[] b;
0067:         int[] pnt;
0068:         RMQ rmq;
0069:         
0070:         Enumerator(int[] a, int[] b) {
0071:             Arrays.sort(a);
0072:             Arrays.sort(b);
0073:             this.a = a;
0074:             this.b = b;
0075:             pnt = new int [N];
0076:             rmq = new RMQ(N);
0077:             for (int i = 0; i < N; i++)
0078:                 rmq.val[i + rmq.n] = a[i] + b[0];
0079:             rmq.build();
0080:         }
0081:         
0082:         void next() {
0083:             if (rmq.val[1] != INF) {
0084:                 int i = rmq.ind[1];
0085:                 if (pnt[i] < N - 1) {
0086:                     int old = b[pnt[i]++];
0087:                     int cur = b[pnt[i]];
0088:                     rmq.set(i, rmq.val[i + rmq.n] + cur - old);
0089:                 } else {
0090:                     rmq.set(i, INF);
0091:                 }
0092:             }
0093:         }
0094:     }
0095:     
0096:     class RMQ {
0097:         int n;
0098:         int[] val;
0099:         int[] ind;
0100:         
0101:         RMQ(int n) {
0102:             this.n = n;
0103:             this.val = new int [2 * n];
0104:             this.ind = new int [2 * n];
0105:             for (int i = 0; i < n; i++)
0106:                 ind[i + n] = i;
0107:         }
0108:         
0109:         void build() {
0110:             for (int v = n - 1; v > 0; v--) {
0111:                 int l = (v << 1);
0112:                 int r = l + 1;
0113:                 if (val[l] < val[r]) {
0114:                     val[v] = val[l];
0115:                     ind[v] = ind[l];
0116:                 } else {
0117:                     val[v] = val[r];
0118:                     ind[v] = ind[r];
0119:                 }
0120:             }
0121:         }
0122:
0123:         void set(int i, int nv) {
0124:             val[i += n] = nv;
0125:             for (int v = i >>= 1; v > 0; v >>= 1) {
0126:                 int l = (v << 1);
0127:                 int r = l + 1;
0128:                 if (val[l] < val[r]) {
0129:                     val[v] = val[l];
0130:                     ind[v] = ind[l];
0131:                 } else {
0132:                     val[v] = val[r];
0133:                     ind[v] = ind[r];
0134:                 }
0135:             }
0136:         }
0137:     }
0138: }