2015-04-30 24 views

Trả lời

6
import org.apache.spark.mllib.linalg.{Vectors,Vector,Matrix,SingularValueDecomposition,DenseMatrix,DenseVector} 
import org.apache.spark.mllib.linalg.distributed.RowMatrix 

def computeInverse(X: RowMatrix): DenseMatrix = { 
    val nCoef = X.numCols.toInt 
    val svd = X.computeSVD(nCoef, computeU = true) 
    if (svd.s.size < nCoef) { 
    sys.error(s"RowMatrix.computeInverse called on singular matrix.") 
    } 

    // Create the inv diagonal matrix from S 
    val invS = DenseMatrix.diag(new DenseVector(svd.s.toArray.map(x => math.pow(x,-1)))) 

    // U cannot be a RowMatrix 
    val U = new DenseMatrix(svd.U.numRows().toInt,svd.U.numCols().toInt,svd.U.rows.collect.flatMap(x => x.toArray)) 

    // If you could make V distributed, then this may be better. However its alreadly local...so maybe this is fine. 
    val V = svd.V 
    // inv(X) = V*inv(S)*transpose(U) --- the U is already transposed. 
    (V.multiply(invS)).multiply(U) 
    } 
3

Tôi đã có vấn đề sử dụng chức năng này với tùy chọn

conf.set("spark.sql.shuffle.partitions", "12") 

Các hàng trong RowMatrix bị xáo trộn.

Dưới đây là một bản cập nhật mà làm việc cho tôi

import org.apache.spark.mllib.linalg.{DenseMatrix,DenseVector} 
import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix 

def computeInverse(X: IndexedRowMatrix) 
: DenseMatrix = 
{ 
    val nCoef = X.numCols.toInt 
    val svd = X.computeSVD(nCoef, computeU = true) 
    if (svd.s.size < nCoef) { 
    sys.error(s"IndexedRowMatrix.computeInverse called on singular matrix.") 
    } 

    // Create the inv diagonal matrix from S 
    val invS = DenseMatrix.diag(new DenseVector(svd.s.toArray.map(x => math.pow(x, -1)))) 

    // U cannot be a RowMatrix 
    val U = svd.U.toBlockMatrix().toLocalMatrix().multiply(DenseMatrix.eye(svd.U.numRows().toInt)).transpose 

    val V = svd.V 
    (V.multiply(invS)).multiply(U) 
} 
0

Matrix U trả về bởi X.computeSVD có kích thước mxk nơi m là số hàng của bản gốc (phân phối) RowMatrix X. Một mong chờ m phải lớn (có thể lớn hơn k), vì vậy không nên thu thập nó trong trình điều khiển nếu chúng tôi muốn mã của chúng tôi mở rộng đến các giá trị thực sự lớn là m.

Tôi sẽ nói cả hai giải pháp dưới đây đều bị lỗ hổng này. Câu trả lời được đưa ra bởi @Alexander Kharlamov gọi val U = svd.U.toBlockMatrix().toLocalMatrix() thu thập ma trận trong trình điều khiển. Điều tương tự cũng xảy ra với câu trả lời được đưa ra bởi @Climbs_lika_Spyder (btw đá nick của bạn !!), gọi số svd.U.rows.collect.flatMap(x => x.toArray). Tôi thà đề nghị dựa vào một phép nhân ma trận phân tán như mã Scala được đăng here.

+0

Tôi không thấy bất kỳ phép tính nghịch đảo nào tại liên kết bạn đã thêm. –

+0

@Climbs_lika_Spyder Liên kết là về phép nhân ma trận phân tán để thay thế phép nhân ma trận cục bộ '(V.multiply (invS)) nhân (U)' trong dòng cuối cùng của giải pháp, do đó bạn không cần thu thập 'U' trong tài xế. Tôi nghĩ rằng 'V' và' invS' không đủ lớn để gây ra vấn đề. – Pablo

Các vấn đề liên quan