BCEWithLogitsLoss
Here are some additional explanations on using Binary Cross-Entropy Loss with PyTorch in Python.
I was recently asked how to find the code behind BCEWithLogitsLoss (https://discuss.pytorch.org/t/implementation-of-binary-cross-entropy/98715/2), a PyTorch function whose call goes through handle_torch_function. It took me some time to understand, so I will share what I found in case it helps:
The question was whether BCEWithLogitsLoss = BCELoss + sigmoid(). The same approach works if you want to trace the code of any of the functions that appear in the ret dictionary in https://github.com/pytorch/pytorch/blob/master/torch/overrides.py
BCEWithLogitsLoss combines a Sigmoid layer and BCELoss in one single class. This is more numerically stable than using a plain Sigmoid followed by a BCELoss because, by combining the operations into one layer, it can take advantage of the log-sum-exp trick for numerical stability; see https://en.wikipedia.org/wiki/LogSumExp
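To see why fusing the two steps matters, here is a minimal pure-Python sketch (only the standard math module, no PyTorch) comparing a naive sigmoid-then-BCE computation with the rewritten form max(x, 0) - x*y + log(1 + exp(-|x|)), which is what the log-sum-exp trick gives for this loss. The function names are my own, and this only illustrates the idea; it is not PyTorch's actual code.

```python
import math

def bce_naive(x: float, y: float) -> float:
    """Sigmoid followed by binary cross-entropy, computed the straightforward way."""
    p = 1.0 / (1.0 + math.exp(-x))  # sigmoid(x)
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))

def bce_with_logits_stable(x: float, y: float) -> float:
    """Same loss, rewritten with the log-sum-exp trick:
    max(x, 0) - x*y + log(1 + exp(-|x|)). exp() only ever sees a non-positive
    argument, so it cannot overflow, and no log(0) can occur."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))

# For moderate logits the two versions agree.
print(bce_naive(0.5, 1.0), bce_with_logits_stable(0.5, 1.0))

# For a large logit, sigmoid(100.0) rounds to exactly 1.0 in float64,
# so the naive version ends up evaluating log(1.0 - 1.0) = log(0).
try:
    bce_naive(100.0, 0.0)
except ValueError as err:  # math.log(0.0) raises ValueError
    print("naive failed:", err)
print("stable:", bce_with_logits_stable(100.0, 0.0))  # stable: 100.0
```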
BCEWithLogitsLoss in detail
- The code of the BCEWithLogitsLoss class is in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py, and its forward method is:
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.binary_cross_entropy_with_logits(input, target,
                                                  self.weight,
                                                  pos_weight=self.pos_weight,
                                                  reduction=self.reduction)
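Since forward just hands its stored arguments to the functional version, the module and F.binary_cross_entropy_with_logits should give the same result. A quick check, assuming a reasonably recent PyTorch install; the tensors are arbitrary example data. The second comparison answers the original question: without pos_weight, the fused loss matches Sigmoid followed by BCELoss when the logits are moderate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)                      # raw scores, no sigmoid applied
targets = torch.randint(0, 2, (4, 3)).float()   # 0/1 labels stored as floats
pos_weight = torch.tensor([1.0, 2.0, 0.5])      # one weight per class (column)

# The module stores pos_weight and forwards everything to F in forward().
module_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, targets)
functional_loss = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=pos_weight)
print(torch.allclose(module_loss, functional_loss))  # True

# Without pos_weight, the fused loss equals Sigmoid + BCELoss for moderate logits.
fused = nn.BCEWithLogitsLoss()(logits, targets)
unfused = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(fused, unfused))  # True
```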
Here F is torch.nn.functional, defined in functional.py: https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
There you will find the function being called:
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None):
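The weight and reduction parameters in this signature can be checked against a hand-written formula. This is a sketch that assumes the per-element loss given in the PyTorch docs, l = -w * [y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]; size_average and reduce are deprecated in favour of reduction, so they are left at their defaults. The input values are arbitrary and kept small so the naive log(sigmoid) is safe.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[0.2, -1.5], [2.0, 0.7]])  # logits
y = torch.tensor([[1.0, 0.0], [1.0, 1.0]])   # targets
w = torch.tensor([1.0, 3.0])                 # rescaling weight, broadcast over the last dim

# Element-wise loss written out by hand (fine for small logits like these).
p = torch.sigmoid(x)
manual = -w * (y * torch.log(p) + (1 - y) * torch.log(1 - p))

none = F.binary_cross_entropy_with_logits(x, y, weight=w, reduction='none')
mean = F.binary_cross_entropy_with_logits(x, y, weight=w)  # default reduction='mean'
total = F.binary_cross_entropy_with_logits(x, y, weight=w, reduction='sum')

print(none.shape)                           # torch.Size([2, 2]), same as the input
print(torch.allclose(none, manual))         # True
print(torch.allclose(mean, manual.mean()))  # True: plain mean of the weighted terms
print(torch.allclose(total, manual.sum()))  # True
```

Note that with reduction='mean' the result is the plain mean of the already-weighted elements, not a weighted average divided by the sum of weights.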
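binary_cross_entropy_with_logits first checks whether any of its tensor arguments overrides __torch_function__; if one does, it dispatches the call through handle_torch_function, defined in https://github.com/pytorch/pytorch/blob/master/torch/overrides.py. Otherwise it goes on to call torch.binary_cross_entropy_with_logits, which is bound to the C++ implementation.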
In overrides.py you will find an entry for binary_cross_entropy_with_logits in the ret dictionary, which lists the functions that can be overridden in PyTorch.
handle_torch_function is the Python-side implementation of the __torch_function__ protocol.
More info in https://github.com/pytorch/pytorch/issues/24015
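Both pieces can be observed from Python. This sketch assumes that the ret dictionary mentioned above is the one returned by torch.overrides.get_testing_overrides(), which maps each overridable function to a dummy lambda with the same signature. The class Intercept is hypothetical, made up for this example: it is not a Tensor, but because it defines a __torch_function__ classmethod, the call is expected to be routed through handle_torch_function to it instead of computing a loss.

```python
import inspect

import torch
import torch.nn.functional as F
from torch.overrides import get_testing_overrides

# 1. The function has an entry in the overrides dictionary (values are dummy lambdas).
overrides = get_testing_overrides()
print(F.binary_cross_entropy_with_logits in overrides)  # True
print(inspect.signature(overrides[F.binary_cross_entropy_with_logits]))

# 2. A hypothetical tensor-like class: not a Tensor subclass, but it defines
#    __torch_function__, so PyTorch hands the call over to it.
class Intercept:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Instead of computing a loss, report which function was dispatched.
        return f"intercepted {func.__name__}"

print(F.binary_cross_entropy_with_logits(Intercept(), torch.zeros(3)))
# intercepted binary_cross_entropy_with_logits
```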
Finally, the code that actually runs is this C++ function:
Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t ...
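The C++ body is cut off above, so I won't reproduce it. As a rough Python sketch of the kind of arithmetic involved: the documented loss with pos_weight can be rearranged so that it only needs log(sigmoid(x)), which F.logsigmoid computes stably. The helper name bce_with_logits_sketch is hypothetical, and the exact C++ expression has changed between PyTorch versions, so treat this as a mathematically equivalent formula rather than a transcription of the source.

```python
import torch
import torch.nn.functional as F

def bce_with_logits_sketch(x, y, pos_weight=None):
    """Hypothetical helper: per-element loss written only in terms of log(sigmoid(x)).

    Uses log(1 - sigmoid(x)) = log(sigmoid(x)) - x, which turns
    -[pw*y*log(s) + (1 - y)*log(1 - s)] into
    (1 - y) * x - (1 + (pw - 1) * y) * log(s), with s = sigmoid(x).
    """
    log_sig = F.logsigmoid(x)                  # stable log(sigmoid(x))
    if pos_weight is not None:
        log_weight = 1 + (pos_weight - 1) * y  # pos_weight where y == 1, 1 where y == 0
        log_sig = log_sig * log_weight
    return (1 - y) * x - log_sig

torch.manual_seed(0)
x = torch.randn(5, 3) * 20                     # includes large-magnitude logits
y = torch.randint(0, 2, (5, 3)).float()
pw = torch.tensor([0.5, 1.0, 4.0])             # one pos_weight per class

ref = F.binary_cross_entropy_with_logits(x, y, pos_weight=pw, reduction='none')
print(torch.allclose(bce_with_logits_sketch(x, y, pw), ref, atol=1e-5))  # True
```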