>>> import torch
>>> a = torch.randn(3, 3)
>>> a
tensor([[ 0.4925,  1.0023, -0.5190],
        [ 0.0464, -1.3224, -0.0238],
        [-0.1801, -0.6056,  1.0795]])
>>> torch.tril(a)
tensor([[ 0.4925,  0.0000,  0.0000],
        [ 0.0464, -1.3224,  0.0000],
        [-0.1801, -0.6056,  1.0795]])
>>> b = torch.randn(4, 6)
>>> b
tensor([[-0.7886, -0.2559, -0.9161,  0.2353,  0.4033, -0.0633],
        [-1.1292, -0.3209, -0.3307,  2.0719,  0.9238, -1.8576],
        [-1.1988, -1.0355, -1.2745, -1.7479,  0.3736, -0.7210],
        [-0.3380,  1.7570, -1.6608, -0.4785,  0.2950, -1.2821]])
>>> torch.tril(b)
tensor([[-0.7886,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292, -0.3209,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1988, -1.0355, -1.2745,  0.0000,  0.0000,  0.0000],
        [-0.3380,  1.7570, -1.6608, -0.4785,  0.0000,  0.0000]])
>>> torch.tril(b, diagonal=1)
tensor([[-0.7886, -0.2559,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292, -0.3209, -0.3307,  0.0000,  0.0000,  0.0000],
        [-1.1988, -1.0355, -1.2745, -1.7479,  0.0000,  0.0000],
        [-0.3380,  1.7570, -1.6608, -0.4785,  0.2950,  0.0000]])
>>> torch.tril(b, diagonal=-1)
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1988, -1.0355,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.3380,  1.7570, -1.6608,  0.0000,  0.0000,  0.0000]])