add loss container
This commit is contained in:
9
engine/util/container.py
Normal file
9
engine/util/container.py
Normal file
@@ -0,0 +1,9 @@
|
||||
class LossContainer:
|
||||
def __init__(self, weight, loss):
|
||||
self.weight = weight
|
||||
self.loss = loss
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.weight > 0:
|
||||
return self.weight * self.loss(*args, **kwargs)
|
||||
return 0.0
|
||||
Reference in New Issue
Block a user