2017-06-13 25 views
6

Trong NumPy tôi có thể làm một phép nhân ma trận đơn giản như thế này:Làm thế nào để làm dấu chấm sản phẩm của ma trận trong PyTorch

a = numpy.arange(2*3).reshape(3,2) 
b = numpy.arange(2).reshape(2,1) 
print(a) 
print(b) 
print(a.dot(b)) 

Tuy nhiên, khi tôi đang cố gắng này với PyTorch tensors, điều này không làm việc:

a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2) 
b = torch.Tensor([[2, 1]]).view(2, -1) 
print(a) 
print(a.size()) 

print(b) 
print(b.size()) 

print(torch.dot(a, b)) 

mã này ném các lỗi sau:

RuntimeError: inconsistent tensor size at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503

Bất kỳ ý tưởng làm thế nào một dấu chấm sản phẩm đơn giản có thể được thực hiện trong P yTorch?

Trả lời

13

Bạn đang tìm kiếm

torch.mm(a,b) 

Lưu ý rằng torch.dot() cư xử khác nhau với np.dot(). Đã có một số cuộc thảo luận về những gì sẽ là mong muốn here. Cụ thể, torch.dot() xử lý cả hai ab dưới dạng véc tơ 1D (bất kể hình dạng ban đầu của chúng) và tính toán sản phẩm bên trong của chúng. Lỗi này được ném, vì hành vi này làm cho a vectơ có chiều dài 6 và b vectơ có độ dài 2; do đó không thể tính toán được sản phẩm bên trong của chúng. Đối với phép nhân ma trận trong PyTorch, sử dụng torch.mm(). Ngược lại, np.dot() của Numpy linh hoạt hơn; nó tính toán sản phẩm bên trong cho mảng 1D và thực hiện phép nhân ma trận cho mảng 2D.

5

Xây dựng về câu trả lời mexmex, nếu bạn muốn làm một phép nhân ma trận bạn có thể làm điều đó trong ba cách:

AB = A.mm(B) # computes A.B (matrix multiplication) 
# or 
AB = torch.mm(A, B) 
# or even simpler 
AB = A @ B # Python 3.5+ 

Đối với nhân tố khôn ngoan, bạn chỉ có thể làm (nếu A và B có cùng hình dạng)

A * B # element-wise matrix multiplication (Hadamard product) 
Các vấn đề liên quan