-
[BOJ 1325] 자바 성능 최적화 (부제: 온몸 비틀기)카테고리 없음 2025. 8. 8. 20:22
해당 문제를 간략하게 설명하면, N개의 노드, M개의 간선을 갖는 방향 그래프에서 임의의 정점 v를 시작 정점으로 그래프 탐색을 수행했을 때 연결요소를 이루는 정점이 가장 많은 시작 정점을 찾는 것이다. 더 쉽게 이야기하면, 그냥 N개의 노드 전체를 시작 정점으로 잡고 DFS/BFS 를 수행하는 문제이다.대부분의 경우 아래의 풀이와 같이 혹은 비슷하게 풀었을 것이다.
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Queue; import java.util.Set; import java.util.Stack; public class Main { public static int maxVal = 0; public static int L = 0; public static void main(String[] args) throws Exception { BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); List<List<Integer>> graph = new ArrayList<>(); String[] split = br.readLine().split(" "); int N = Integer.parseInt(split[0]), M = Integer.parseInt(split[1]); for (int i=0; i<=N; i++) { graph.add(new ArrayList<>()); } for (int i=0; i<M; i++) { split = br.readLine().split(" "); int u = Integer.parseInt(split[0]), v = Integer.parseInt(split[1]); graph.get(v).add(u); } int[] count = new int[N+1]; int maxVal = 0; for (int i=1; i<=N; i++) { boolean[] visited = new boolean[N+1]; count[i] = dfs(graph, visited, i); maxVal = Math.max(maxVal, count[i]); } ArrayList<Integer> ansArr = new ArrayList<>(); for (int i=1; i<=N; i++) { if (count[i] == maxVal) { ansArr.add(i); } } for (Integer ans : ansArr) { System.out.printf("%d ", ans); } } public static int dfs(List<List<Integer>> graph, boolean[] visited, int v) { Stack<Integer> stack = new Stack<>(); stack.add(v); visited[v] = true; int ret = 0; while (!stack.isEmpty()) { Integer now = stack.pop(); ret++; for (Integer adj : graph.get(now)) { if (!visited[adj]) { visited[adj] = true; stack.add(adj); } } } return ret; } }이렇게 풀면 대부분의 경우에 시간초과가 날 것이다. 통과가 되도 거의 10초 중후반에 통과가 될 것이다. (이 문제의 시간제한은 5초이고, java 의 경우 11초가 제한시간이다.)
좀 더 고급개념(tarjan strongly connected component?)을 사용해서 멋드러지게 푸신 분들의 풀이도 보이지만, 나는 해당 개념도 모르고 이해하는 것도 오버킬이라 생각했다. 결정적으로, 그 해법 또한 시간복잡도가 dramatic 하게 줄지는 않는 것으로 보았다. (얼핏 확인했을 때 O(NM)?)
어려운거 하지말고, 할 수 있는 것에서 최적화를 시도해보자. 위 코드를 살펴보면 문제가 될만한 부분이 몇가지 있다.
- 그래프 초기화 시 new ArrayList<>() 로 할당
- dfs 를 돌 때마다 new Stack<>() 생성
- dfs 를 돌 때마다 new visited[N+1] 생성
"메모리 할당은 일반적인 연산보다 비싼 행위이니 느려지지 않을까? 객체 생성을 최소화 해보자"
하나씩 격파해보자.
그래프 초기화 시 new ArrayList<>() 로 할당
이 문제는 N이 1만이기에 1만개의 ArrayList 를 생성해야한다. 뿐만 아니라, 간선의 개수는 10만이기 때문에 이 ArrayList 는 길이가 10만까지 증가할 수 있다. initialCapacity 를 주지 않으면 grow(ArrayList 내부에서 사용하는 배열보다 길어지면, 길이가 2배인 배열을 생성하고 Array.copy) 가 최대 12~13회 정도 일어날 것이다.
(20 * sigma 2^n from n=0 to n=13 ~= 32만?)
- grow 로 인해 생기는 메모리 할당요청/overhead 를 줄이기 위해 initialCapacity 를 10만으로 주면 되지 않을까?
- 10k(N) * 100k(M) = 1G 이므로 무조건 OOME 가 발생할 것이다. (자료형(정수 4bytes)까지 고려하면 4GB)
- 아 그러면 그래프를 2차원 리스트로 표현할 때, 모든 정점에 대해 초기화 하지말고, 간선을 가진 노드만 초기화 해주자.
- Map 을 사용하면 그렇게 구현할 수 있겠다.
- Wrapper type 을 반드시 사용해야하는 Collection 보다는 primitive type array 를 사용하는 것이 빠르지 않을까?
- 그 중에서도 자료형의 크기가 작으면 보다 적은 메모리 공간에 다량의 데이터가 올라가니 cache hit 이 증가해서 성능적으로 이점이 있지 않을까?
dfs 를 돌 때마다 new Stack<>() 생성
DFS 를 N번 돌리는 상황이기 때문에, Stack 또한 N번 생성되고 있다. N이 1만일 때 스택이 1만번 생성될 수 있다.
- new Stack<>() 말고 stack.clear() 로 초기화해서 재사용하면 좀 더 효율적이지 않을까?
- java.util.Stack 의 clear api 의 내부 구현은 O(n) 으로 구현되어 있다.
- 그러므로, fixed size stack 을 직접 구현해서 clear 작업을 O(1)으로 만들자
- java.util.Stack 의 clear api 의 내부 구현은 O(n) 으로 구현되어 있다.
dfs 를 돌 때마다 new visited[N+1] 생성
마찬가지로 DFS 를 N번 돌리는 상황이기 때문에 visited 배열도 N번 생성된다.
- visited 의 타입을 short[] 로 선언하고, 방문체크 하는 로직을 약간만 수정해주면 잦은 객체 생성을 피할 수 있다.
위와 같은 발상을 가지고 최적화 작업을 수행했고 그 코드는 아래와 같다.
import java.io.BufferedReader; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; public class Main { public static void main(String[] args) throws Exception { BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); String[] split = br.readLine().split(" "); int N = Short.parseShort(split[0]); int M = Integer.parseInt(split[1]); short[][] graph = new short[N+1][]; Map<Short, List<Short>> temp = new HashMap<>(); for (int i=0; i<M; i++) { split = br.readLine().split(" "); Short u = Short.parseShort(split[0]), v = Short.parseShort(split[1]); if (temp.get(v) == null) { temp.put(v, new ArrayList<>(Arrays.asList(u))); } else { temp.get(v).add(u); } } for (Short key : temp.keySet()) { graph[key] = new short[temp.get(key).size()]; for (int i=0; i<temp.get(key).size(); i++) { graph[key][i] = temp.get(key).get(i); } } short[] visited = new short[N+1]; short[] stack = new short[N+1]; short stack_top = 0; short[] count = new short[N+1]; short ret = 0; short maxVal = 0; for (short i=1; i<=N; i++) { stack_top = 0; stack[stack_top++] = i; visited[i] = i; ret = 0; while (stack_top > 0) { int now = stack[--stack_top]; ret++; if (graph[now] != null) { for (short adj : graph[now]) { if (visited[adj] <= (i-1)) { visited[adj] = i; stack[stack_top++] = adj; } } } } count[i] = ret; maxVal = maxVal > count[i] ? maxVal : count[i]; } for (int i=1; i<=N; i++) { if (count[i] == maxVal) { System.out.printf("%d ", i); } } } }이렇게 최적화 했더니 약 3000ms 로 통과할 수 있었다.
Main Factor 분석
10초 -> 3초로 수행시간이 줄어들었는데, 이에 가장 영향을 많이 준 요소는 무엇이었을까?
최적화 포인트 몇가지를 독립 변인으로 두고 실험해본 결과 Map 을 이용해서 Graph 를 필요한만큼만 초기화하는 것은 거의 효과가 없었고.. 가장 큰 시간 단축을 이룬 요소는 primitive type array 를 사용한 것이었다.
간단한 테스트로 그 차이를 확인해보자.
실험1. ArrayList 로 원소 100000000개 추가하기
import java.io.IOException; import java.util.ArrayList; public class Test { public static void main(String[] args) throws IOException { long start = System.currentTimeMillis(); int N = 100000000; ArrayList<Integer> list = new ArrayList<>(); for (int i = 0; i < N; i++) { list.add(i); } System.out.println("After operation1: " + (System.currentTimeMillis() - start)); start = System.currentTimeMillis(); for (int i = 0; i < N; i++) { list.get(i); } System.out.println("After operation1-1: " + (System.currentTimeMillis() - start)); } }출력
After operation1: 5242 After operation1-1: 67실험2. ArrayList 에 initialCapacity 주고 원소 100000000개 추가하기
import java.io.IOException; import java.util.ArrayList; public class gen { public static void main(String[] args) throws IOException { long start = System.currentTimeMillis(); int N = 100000000; start = System.currentTimeMillis(); ArrayList<Object> list2 = new ArrayList<>(N); for (int i = 0; i < N; i++) { list2.add(i); } System.out.println("After operation2: " + (System.currentTimeMillis() - start)); start = System.currentTimeMillis(); for (int i = 0; i < N; i++) { list2.get(i); } System.out.println("After operation2-1: " + (System.currentTimeMillis() - start)); } }출력
After operation2: 1452 After operation2-1: 95실험3. int[] 에 원소 100000000개 추가하기
import java.io.IOException; import java.util.ArrayList; public class gen { public static void main(String[] args) throws IOException { long start = System.currentTimeMillis(); int N = 100000000; start = System.currentTimeMillis(); int[] list3 = new int[N]; for (int i = 0; i < N; i++) { list3[i] = i; } System.out.println("After operation3: " + (System.currentTimeMillis() - start)); start = System.currentTimeMillis(); for (int i = 0; i < N; i++) { int i1 = list3[i]; } System.out.println("After operation3-1: " + (System.currentTimeMillis() - start)); } }출력
After operation3: 142 After operation3-1: 6수치가 너무 dramatic 하게 차이가 나서 실험3번에서 리스트 접근을 통해 i1 을 할당했더라도, unused variable 이다보니 compile time 에 접근안하도록 최적화했나 생각도 들었다.
ArrayList 도 내부적으로는 배열을 들고 있어서, initialCapacity 를 잘주면 grow 에 의한 overhead 를 피할 수 있어 primitive type array 와 속도차이가 많이 안날 것으로 생각했는데 차이가 너무 많이 나서 놀라웠다. Wrapper 타입을 boxing - unboxing 하는데에 오버헤드가 꽤 큰 것으로 추정이 된다. (뇌피셜)
더 정확하게 분석하기 위해 intellij java profiler 를 통해 flame graph 를 관찰/사용해봤지만... 해석하기에 내 지식이 너무 부족하다. 공부를 더 한뒤 나중에 다시 돌아보면 깨닫는게 있을 것 같다.