본문 바로가기
카테고리 없음

Pytorch Autograd (자동미분) 개념

by 빛나는존재 2024. 8. 5.


Pytorch 에서는 일일이 사용자가 손으로 미분된 식을 구현해서 coding 하지 않아도 자동으로 미분해주는 autograd 라는 기능이 있다. Autograd 라는 pytorch 내의 패키지가 "computational graph" (이건 우리가 흔히 생각하는 그래프라는 막대나 선형그래프 같은 차트가 아니라 노드와 엣지로 구성된 일종의 마인드맵같이 생긴 그래프를 의미) 를 생성한다. 이 때 Tensor 는 node 가 되고 Tensor 간 계산을 수행하는 function 이 edge 가 된다. 

아래 code 에서 autograd 를 적용시킨 간단한 사례를 보여주고 있는데 크게 단계별로 나눠보면


1. Data generation
2. Weight generation 이 때, weight 는 autograd 의 대상이므로 requires_grad = True 로 해놔야 한다.
3. (여기서 부터 learning 시작)
3.1 forward propagation: y값의 예측값을 수식에 따라 계산, 
3.2 Loss 계산
3.3 backward propagation: Autograd 를 사용하여 각 autograd 대상인 모든 tensor (여기서는 a,b,c,d 라는 weight) 에 대한 Loss 의 편미분값들을 계산.
3.4 gradient descent: weight 를 backward propagation 된 값들과 learning rate 를 적용시켜 update

또한, Pytorch 에서는 torch.autograd.Function 의 subclass 를 사용자의 니즈에 맞게 customizing 하게 정의해서 autograd operator 를 정의하고 forward, backward function 을 수행할 수 있다. 아래의 code 에서는 class Legendre polynomial 수식으로 y값을 예측하는 딥러닝 프로세스를 custom autograd function 으로 표현한다. 


torch.autograd.Function 을 상속받은 LegendrePolynomial3 이라는 class 생성하고 
그 안에서 forward, backward propagation 실시함. Def forward() 에는 예측해야 하는 수식 formula 가 들어있고 backward() 에는 미분된 formula 가 들어있음. 
Autograd function 을 상속받긴 하지만 그렇다고 해서 자동으로 미분을 계산해 준다는 뜻은 아니다. 자동으로 되는 것은 chain rule 이 자동으로 적용되는 것이다. (.. 라고 claude 가 설명해주는데 솔직히 이건 이해를 아직 못했다. 자동으로 미분해 주는거 아녔어?)

 

Legendre Polynomial 수식 미분하는 과정 (출처: Chatgpt 에 물어봄)



이렇게 torch.autograd.Function 을 상속받은 class 로 customized 하게 정의한 autograd 를 정의한 딥러닝 코드는 
다음과 같은 순서로 작동한다.

1. class 정의 
2. Data (x, y) 생성 (이건 예제니까 데이터 랜덤하게 생성한거고 실제 상황에서는 데이터셋 불러오는 과정)
3. 예측대상인 Legerendre Polynomial 3차 수식에도  weight (a,b,c,d) 가 들어간다. 이것들은 random 하게 초기화 하면서 이들은 미분의 대상이므로 requires_grad = True로 해줌. 
4. actual learning part
4.1 Forward pass
4.2 Loss calculation
4.3 backward pass (여기서 autograd 사용)
4.4 weight update (learning rate에 각각의 weight로 Loss function을 대상으로 편미분한 값을곱한 값을 이전 weight 값으로부터 빼줘서 새로 update 된 weight 계산)

 


출처: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html

반응형