00001 00033 #ifndef _MSC_VER 00034 # include <itpp/config.h> 00035 #else 00036 # include <itpp/config_msvc.h> 00037 #endif 00038 00039 #if defined(HAVE_LAPACK) 00040 # include <itpp/base/lapack.h> 00041 #endif 00042 00043 #include <itpp/base/ls_solve.h> 00044 00045 00046 namespace itpp { 00047 00048 // ----------- ls_solve_chol ----------------------------------------------------------- 00049 00050 #if defined(HAVE_LAPACK) 00051 00052 bool ls_solve_chol(const mat &A, const vec &b, vec &x) 00053 { 00054 int n, lda, ldb, nrhs, info; 00055 n = lda = ldb = A.rows(); 00056 nrhs = 1; 00057 char uplo='U'; 00058 00059 it_assert1(A.cols() == n, "ls_solve_chol: System-matrix is not square"); 00060 it_assert1(n == b.size(), "The number of rows in A must equal the length of b!"); 00061 00062 ivec ipiv(n); 00063 x = b; 00064 mat Chol = A; 00065 00066 dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info); 00067 00068 return (info==0); 00069 } 00070 00071 00072 bool ls_solve_chol(const mat &A, const mat &B, mat &X) 00073 { 00074 int n, lda, ldb, nrhs, info; 00075 n = lda = ldb = A.rows(); 00076 nrhs = B.cols(); 00077 char uplo='U'; 00078 00079 it_assert1(A.cols() == n, "ls_solve_chol: System-matrix is not square"); 00080 it_assert1(n == B.rows(), "The number of rows in A must equal the length of B!"); 00081 00082 ivec ipiv(n); 00083 X = B; 00084 mat Chol = A; 00085 00086 dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info); 00087 00088 return (info==0); 00089 } 00090 00091 bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x) 00092 { 00093 int n, lda, ldb, nrhs, info; 00094 n = lda = ldb = A.rows(); 00095 nrhs = 1; 00096 char uplo='U'; 00097 00098 it_assert1(A.cols() == n, "ls_solve_chol: System-matrix is not square"); 00099 it_assert1(n == b.size(), "The number of rows in A must equal the length of b!"); 00100 00101 ivec ipiv(n); 00102 x = b; 00103 cmat Chol = A; 00104 00105 zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info); 00106 00107 return (info==0); 00108 } 00109 00110 bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X) 00111 { 00112 int n, lda, ldb, nrhs, info; 00113 n = lda = ldb = A.rows(); 00114 nrhs = B.cols(); 00115 char uplo='U'; 00116 00117 it_assert1(A.cols() == n, "ls_solve_chol: System-matrix is not square"); 00118 it_assert1(n == B.rows(), "The number of rows in A must equal the length of B!"); 00119 00120 ivec ipiv(n); 00121 X = B; 00122 cmat Chol = A; 00123 00124 zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info); 00125 00126 return (info==0); 00127 } 00128 00129 #else 00130 00131 bool ls_solve_chol(const mat &A, const vec &b, vec &x) 00132 { 00133 it_error("LAPACK library is needed to use ls_solve_chol() function"); 00134 return false; 00135 } 00136 00137 bool ls_solve_chol(const mat &A, const mat &B, mat &X) 00138 { 00139 it_error("LAPACK library is needed to use ls_solve_chol() function"); 00140 return false; 00141 } 00142 00143 bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x) 00144 { 00145 it_error("LAPACK library is needed to use ls_solve_chol() function"); 00146 return false; 00147 } 00148 00149 bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X) 00150 { 00151 it_error("LAPACK library is needed to use ls_solve_chol() function"); 00152 return false; 00153 } 00154 00155 #endif // HAVE_LAPACK 00156 00157 vec ls_solve_chol(const mat &A, const vec &b) 00158 { 00159 vec x; 00160 bool info; 00161 info = ls_solve_chol(A, b, x); 00162 it_assert1(info, "ls_solve_chol: Failed solving the system"); 00163 return x; 00164 } 00165 00166 mat ls_solve_chol(const mat &A, const mat &B) 00167 { 00168 mat X; 00169 bool info; 00170 info = ls_solve_chol(A, B, X); 00171 it_assert1(info, "ls_solve_chol: Failed solving the system"); 00172 return X; 00173 } 00174 00175 cvec ls_solve_chol(const cmat &A, const cvec &b) 00176 { 00177 cvec x; 00178 bool info; 00179 info = ls_solve_chol(A, b, x); 00180 it_assert1(info, "ls_solve_chol: Failed solving the system"); 00181 return x; 00182 } 00183 00184 cmat ls_solve_chol(const cmat &A, const cmat &B) 00185 { 00186 cmat X; 00187 bool info; 00188 info = ls_solve_chol(A, B, X); 00189 it_assert1(info, "ls_solve_chol: Failed solving the system"); 00190 return X; 00191 } 00192 00193 00194 // --------- ls_solve --------------------------------------------------------------- 00195 #if defined(HAVE_LAPACK) 00196 00197 bool ls_solve(const mat &A, const vec &b, vec &x) 00198 { 00199 int n, lda, ldb, nrhs, info; 00200 n = lda = ldb = A.rows(); 00201 nrhs = 1; 00202 00203 it_assert1(A.cols() == n, "ls_solve: System-matrix is not square"); 00204 it_assert1(n == b.size(), "The number of rows in A must equal the length of b!"); 00205 00206 ivec ipiv(n); 00207 x = b; 00208 mat LU = A; 00209 00210 dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info); 00211 00212 return (info==0); 00213 } 00214 00215 bool ls_solve(const mat &A, const mat &B, mat &X) 00216 { 00217 int n, lda, ldb, nrhs, info; 00218 n = lda = ldb = A.rows(); 00219 nrhs = B.cols(); 00220 00221 it_assert1(A.cols() == n, "ls_solve: System-matrix is not square"); 00222 it_assert1(n == B.rows(), "The number of rows in A must equal the length of B!"); 00223 00224 ivec ipiv(n); 00225 X = B; 00226 mat LU = A; 00227 00228 dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info); 00229 00230 return (info==0); 00231 } 00232 00233 bool ls_solve(const cmat &A, const cvec &b, cvec &x) 00234 { 00235 int n, lda, ldb, nrhs, info; 00236 n = lda = ldb = A.rows(); 00237 nrhs = 1; 00238 00239 it_assert1(A.cols() == n, "ls_solve: System-matrix is not square"); 00240 it_assert1(n == b.size(), "The number of rows in A must equal the length of b!"); 00241 00242 ivec ipiv(n); 00243 x = b; 00244 cmat LU = A; 00245 00246 zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info); 00247 00248 return (info==0); 00249 } 00250 00251 bool ls_solve(const cmat &A, const cmat &B, cmat &X) 00252 { 00253 int n, lda, ldb, nrhs, info; 00254 n = lda = ldb = A.rows(); 00255 nrhs = B.cols(); 00256 00257 it_assert1(A.cols() == n, "ls_solve: System-matrix is not square"); 00258 it_assert1(n == B.rows(), "The number of rows in A must equal the length of B!"); 00259 00260 ivec ipiv(n); 00261 X = B; 00262 cmat LU = A; 00263 00264 zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info); 00265 00266 return (info==0); 00267 } 00268 00269 #else 00270 00271 bool ls_solve(const mat &A, const vec &b, vec &x) 00272 { 00273 it_error("LAPACK library is needed to use ls_solve() function"); 00274 return false; 00275 } 00276 00277 bool ls_solve(const mat &A, const mat &B, mat &X) 00278 { 00279 it_error("LAPACK library is needed to use ls_solve() function"); 00280 return false; 00281 } 00282 00283 bool ls_solve(const cmat &A, const cvec &b, cvec &x) 00284 { 00285 it_error("LAPACK library is needed to use ls_solve() function"); 00286 return false; 00287 } 00288 00289 bool ls_solve(const cmat &A, const cmat &B, cmat &X) 00290 { 00291 it_error("LAPACK library is needed to use ls_solve() function"); 00292 return false; 00293 } 00294 00295 #endif // HAVE_LAPACK 00296 00297 vec ls_solve(const mat &A, const vec &b) 00298 { 00299 vec x; 00300 bool info; 00301 info = ls_solve(A, b, x); 00302 it_assert1(info, "ls_solve: Failed solving the system"); 00303 return x; 00304 } 00305 00306 mat ls_solve(const mat &A, const mat &B) 00307 { 00308 mat X; 00309 bool info; 00310 info = ls_solve(A, B, X); 00311 it_assert1(info, "ls_solve: Failed solving the system"); 00312 return X; 00313 } 00314 00315 cvec ls_solve(const cmat &A, const cvec &b) 00316 { 00317 cvec x; 00318 bool info; 00319 info = ls_solve(A, b, x); 00320 it_assert1(info, "ls_solve: Failed solving the system"); 00321 return x; 00322 } 00323 00324 cmat ls_solve(const cmat &A, const cmat &B) 00325 { 00326 cmat X; 00327 bool info; 00328 info = ls_solve(A, B, X); 00329 it_assert1(info, "ls_solve: Failed solving the system"); 00330 return X; 00331 } 00332 00333 00334 // ----------------- ls_solve_od ------------------------------------------------------------------ 00335 #if defined(HAVE_LAPACK) 00336 00337 bool ls_solve_od(const mat &A, const vec &b, vec &x) 00338 { 00339 int m, n, lda, ldb, nrhs, lwork, info; 00340 char trans='N'; 00341 m = lda = ldb = A.rows(); 00342 n = A.cols(); 00343 nrhs = 1; 00344 lwork = n + std::max(m,nrhs); 00345 00346 it_assert1(m >= n, "The system is under-determined!"); 00347 it_assert1(m == b.size(), "The number of rows in A must equal the length of b!"); 00348 00349 vec work(lwork); 00350 x = b; 00351 mat QR = A; 00352 00353 dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info); 00354 x.set_size(n, true); 00355 00356 return (info==0); 00357 } 00358 00359 bool ls_solve_od(const mat &A, const mat &B, mat &X) 00360 { 00361 int m, n, lda, ldb, nrhs, lwork, info; 00362 char trans='N'; 00363 m = lda = ldb = A.rows(); 00364 n = A.cols(); 00365 nrhs = B.cols(); 00366 lwork = n + std::max(m,nrhs); 00367 00368 it_assert1(m >= n, "The system is under-determined!"); 00369 it_assert1(m == B.rows(), "The number of rows in A must equal the length of b!"); 00370 00371 vec work(lwork); 00372 X = B; 00373 mat QR = A; 00374 00375 dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info); 00376 X.set_size(n, nrhs, true); 00377 00378 return (info==0); 00379 } 00380 00381 bool ls_solve_od(const cmat &A, const cvec &b, cvec &x) 00382 { 00383 int m, n, lda, ldb, nrhs, lwork, info; 00384 char trans='N'; 00385 m = lda = ldb = A.rows(); 00386 n = A.cols(); 00387 nrhs = 1; 00388 lwork = n + std::max(m,nrhs); 00389 00390 it_assert1(m >= n, "The system is under-determined!"); 00391 it_assert1(m == b.size(), "The number of rows in A must equal the length of b!"); 00392 00393 cvec work(lwork); 00394 x = b; 00395 cmat QR = A; 00396 00397 zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info); 00398 x.set_size(n, true); 00399 00400 return (info==0); 00401 } 00402 00403 bool ls_solve_od(const cmat &A, const cmat &B, cmat &X) 00404 { 00405 int m, n, lda, ldb, nrhs, lwork, info; 00406 char trans='N'; 00407 m = lda = ldb = A.rows(); 00408 n = A.cols(); 00409 nrhs = B.cols(); 00410 lwork = n + std::max(m,nrhs); 00411 00412 it_assert1(m >= n, "The system is under-determined!"); 00413 it_assert1(m == B.rows(), "The number of rows in A must equal the length of b!"); 00414 00415 cvec work(lwork); 00416 X = B; 00417 cmat QR = A; 00418 00419 zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info); 00420 X.set_size(n, nrhs, true); 00421 00422 return (info==0); 00423 } 00424 00425 #else 00426 00427 bool ls_solve_od(const mat &A, const vec &b, vec &x) 00428 { 00429 it_error("LAPACK library is needed to use ls_solve_od() function"); 00430 return false; 00431 } 00432 00433 bool ls_solve_od(const mat &A, const mat &B, mat &X) 00434 { 00435 it_error("LAPACK library is needed to use ls_solve_od() function"); 00436 return false; 00437 } 00438 00439 bool ls_solve_od(const cmat &A, const cvec &b, cvec &x) 00440 { 00441 it_error("LAPACK library is needed to use ls_solve_od() function"); 00442 return false; 00443 } 00444 00445 bool ls_solve_od(const cmat &A, const cmat &B, cmat &X) 00446 { 00447 it_error("LAPACK library is needed to use ls_solve_od() function"); 00448 return false; 00449 } 00450 00451 #endif // HAVE_LAPACK 00452 00453 vec ls_solve_od(const mat &A, const vec &b) 00454 { 00455 vec x; 00456 bool info; 00457 info = ls_solve_od(A, b, x); 00458 it_assert1(info, "ls_solve_od: Failed solving the system"); 00459 return x; 00460 } 00461 00462 mat ls_solve_od(const mat &A, const mat &B) 00463 { 00464 mat X; 00465 bool info; 00466 info = ls_solve_od(A, B, X); 00467 it_assert1(info, "ls_solve_od: Failed solving the system"); 00468 return X; 00469 } 00470 00471 cvec ls_solve_od(const cmat &A, const cvec &b) 00472 { 00473 cvec x; 00474 bool info; 00475 info = ls_solve_od(A, b, x); 00476 it_assert1(info, "ls_solve_od: Failed solving the system"); 00477 return x; 00478 } 00479 00480 cmat ls_solve_od(const cmat &A, const cmat &B) 00481 { 00482 cmat X; 00483 bool info; 00484 info = ls_solve_od(A, B, X); 00485 it_assert1(info, "ls_solve_od: Failed solving the system"); 00486 return X; 00487 } 00488 00489 // ------------------- ls_solve_ud ----------------------------------------------------------- 00490 #if defined(HAVE_LAPACK) 00491 00492 bool ls_solve_ud(const mat &A, const vec &b, vec &x) 00493 { 00494 int m, n, lda, ldb, nrhs, lwork, info; 00495 char trans='N'; 00496 m = lda = A.rows(); 00497 n = A.cols(); 00498 ldb = n; 00499 nrhs = 1; 00500 lwork = m + std::max(n,nrhs); 00501 00502 it_assert1(m < n, "The system is over-determined!"); 00503 it_assert1(m == b.size(), "The number of rows in A must equal the length of b!"); 00504 00505 vec work(lwork); 00506 x = b; 00507 x.set_size(n, true); 00508 mat QR = A; 00509 00510 dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info); 00511 00512 return (info==0); 00513 } 00514 00515 bool ls_solve_ud(const mat &A, const mat &B, mat &X) 00516 { 00517 int m, n, lda, ldb, nrhs, lwork, info; 00518 char trans='N'; 00519 m = lda = A.rows(); 00520 n = A.cols(); 00521 ldb = n; 00522 nrhs = B.cols(); 00523 lwork = m + std::max(n,nrhs); 00524 00525 it_assert1(m < n, "The system is over-determined!"); 00526 it_assert1(m == B.rows(), "The number of rows in A must equal the length of b!"); 00527 00528 vec work(lwork); 00529 X = B; 00530 X.set_size(n, std::max(m, nrhs), true); 00531 mat QR = A; 00532 00533 dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info); 00534 X.set_size(n, nrhs, true); 00535 00536 return (info==0); 00537 } 00538 00539 bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x) 00540 { 00541 int m, n, lda, ldb, nrhs, lwork, info; 00542 char trans='N'; 00543 m = lda = A.rows(); 00544 n = A.cols(); 00545 ldb = n; 00546 nrhs = 1; 00547 lwork = m + std::max(n,nrhs); 00548 00549 it_assert1(m < n, "The system is over-determined!"); 00550 it_assert1(m == b.size(), "The number of rows in A must equal the length of b!"); 00551 00552 cvec work(lwork); 00553 x = b; 00554 x.set_size(n, true); 00555 cmat QR = A; 00556 00557 zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info); 00558 00559 return (info==0); 00560 } 00561 00562 bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X) 00563 { 00564 int m, n, lda, ldb, nrhs, lwork, info; 00565 char trans='N'; 00566 m = lda = A.rows(); 00567 n = A.cols(); 00568 ldb = n; 00569 nrhs = B.cols(); 00570 lwork = m + std::max(n,nrhs); 00571 00572 it_assert1(m < n, "The system is over-determined!"); 00573 it_assert1(m == B.rows(), "The number of rows in A must equal the length of b!"); 00574 00575 cvec work(lwork); 00576 X = B; 00577 X.set_size(n, std::max(m, nrhs), true); 00578 cmat QR = A; 00579 00580 zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info); 00581 X.set_size(n, nrhs, true); 00582 00583 return (info==0); 00584 } 00585 00586 #else 00587 00588 bool ls_solve_ud(const mat &A, const vec &b, vec &x) 00589 { 00590 it_error("LAPACK library is needed to use ls_solve_ud() function"); 00591 return false; 00592 } 00593 00594 bool ls_solve_ud(const mat &A, const mat &B, mat &X) 00595 { 00596 it_error("LAPACK library is needed to use ls_solve_ud() function"); 00597 return false; 00598 } 00599 00600 bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x) 00601 { 00602 it_error("LAPACK library is needed to use ls_solve_ud() function"); 00603 return false; 00604 } 00605 00606 bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X) 00607 { 00608 it_error("LAPACK library is needed to use ls_solve_ud() function"); 00609 return false; 00610 } 00611 00612 #endif // HAVE_LAPACK 00613 00614 00615 vec ls_solve_ud(const mat &A, const vec &b) 00616 { 00617 vec x; 00618 bool info; 00619 info = ls_solve_ud(A, b, x); 00620 it_assert1(info, "ls_solve_ud: Failed solving the system"); 00621 return x; 00622 } 00623 00624 mat ls_solve_ud(const mat &A, const mat &B) 00625 { 00626 mat X; 00627 bool info; 00628 info = ls_solve_ud(A, B, X); 00629 it_assert1(info, "ls_solve_ud: Failed solving the system"); 00630 return X; 00631 } 00632 00633 cvec ls_solve_ud(const cmat &A, const cvec &b) 00634 { 00635 cvec x; 00636 bool info; 00637 info = ls_solve_ud(A, b, x); 00638 it_assert1(info, "ls_solve_ud: Failed solving the system"); 00639 return x; 00640 } 00641 00642 cmat ls_solve_ud(const cmat &A, const cmat &B) 00643 { 00644 cmat X; 00645 bool info; 00646 info = ls_solve_ud(A, B, X); 00647 it_assert1(info, "ls_solve_ud: Failed solving the system"); 00648 return X; 00649 } 00650 00651 00652 // ---------------------- backslash ----------------------------------------- 00653 00654 bool backslash(const mat &A, const vec &b, vec &x) 00655 { 00656 int m=A.rows(), n=A.cols(); 00657 bool info; 00658 00659 if (m == n) 00660 info = ls_solve(A,b,x); 00661 else if (m > n) 00662 info = ls_solve_od(A,b,x); 00663 else 00664 info = ls_solve_ud(A,b,x); 00665 00666 return info; 00667 } 00668 00669 00670 vec backslash(const mat &A, const vec &b) 00671 { 00672 vec x; 00673 bool info; 00674 info = backslash(A, b, x); 00675 it_assert1(info, "backslash(): solution was not found"); 00676 return x; 00677 } 00678 00679 00680 bool backslash(const mat &A, const mat &B, mat &X) 00681 { 00682 int m=A.rows(), n=A.cols(); 00683 bool info; 00684 00685 if (m == n) 00686 info = ls_solve(A, B, X); 00687 else if (m > n) 00688 info = ls_solve_od(A, B, X); 00689 else 00690 info = ls_solve_ud(A, B, X); 00691 00692 return info; 00693 } 00694 00695 00696 mat backslash(const mat &A, const mat &B) 00697 { 00698 mat X; 00699 bool info; 00700 info = backslash(A, B, X); 00701 it_assert1(info, "backslash(): solution was not found"); 00702 return X; 00703 } 00704 00705 00706 bool backslash(const cmat &A, const cvec &b, cvec &x) 00707 { 00708 int m=A.rows(), n=A.cols(); 00709 bool info; 00710 00711 if (m == n) 00712 info = ls_solve(A,b,x); 00713 else if (m > n) 00714 info = ls_solve_od(A,b,x); 00715 else 00716 info = ls_solve_ud(A,b,x); 00717 00718 return info; 00719 } 00720 00721 00722 cvec backslash(const cmat &A, const cvec &b) 00723 { 00724 cvec x; 00725 bool info; 00726 info = backslash(A, b, x); 00727 it_assert1(info, "backslash(): solution was not found"); 00728 return x; 00729 } 00730 00731 00732 bool backslash(const cmat &A, const cmat &B, cmat &X) 00733 { 00734 int m=A.rows(), n=A.cols(); 00735 bool info; 00736 00737 if (m == n) 00738 info = ls_solve(A, B, X); 00739 else if (m > n) 00740 info = ls_solve_od(A, B, X); 00741 else 00742 info = ls_solve_ud(A, B, X); 00743 00744 return info; 00745 } 00746 00747 cmat backslash(const cmat &A, const cmat &B) 00748 { 00749 cmat X; 00750 bool info; 00751 info = backslash(A, B, X); 00752 it_assert1(info, "backslash(): solution was not found"); 00753 return X; 00754 } 00755 00756 00757 // -------------------------------------------------------------------------- 00758 00759 vec forward_substitution(const mat &L, const vec &b) 00760 { 00761 int n = L.rows(); 00762 vec x(n); 00763 00764 forward_substitution(L, b, x); 00765 00766 return x; 00767 } 00768 00769 void forward_substitution(const mat &L, const vec &b, vec &x) 00770 { 00771 it_assert( L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size(), 00772 "forward_substitution: dimension mismatch" ); 00773 int n = L.rows(), i, j; 00774 double temp; 00775 00776 x(0)=b(0)/L(0,0); 00777 for (i=1;i<n;i++) { 00778 // Should be: x(i)=((b(i)-L(i,i,0,i-1)*x(0,i-1))/L(i,i))(0); but this is to slow. 00779 //i_pos=i*L._row_offset(); 00780 temp=0; 00781 for (j=0; j<i; j++) { 00782 temp += L._elem(i,j) * x(j); 00783 //temp+=L._data()[i_pos+j]*x(j); 00784 } 00785 x(i) = (b(i)-temp)/L._elem(i,i); 00786 //x(i)=(b(i)-temp)/L._data()[i_pos+i]; 00787 } 00788 } 00789 00790 vec forward_substitution(const mat &L, int p, const vec &b) 00791 { 00792 int n = L.rows(); 00793 vec x(n); 00794 00795 forward_substitution(L, p, b, x); 00796 00797 return x; 00798 } 00799 00800 void forward_substitution(const mat &L, int p, const vec &b, vec &x) 00801 { 00802 it_assert( L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size() && p <= L.rows()/2, 00803 "forward_substitution: dimension mismatch"); 00804 int n = L.rows(), i, j; 00805 00806 x=b; 00807 00808 for (j=0;j<n;j++) { 00809 x(j)/=L(j,j); 00810 for (i=j+1;i<std::min(j+p+1,n);i++) { 00811 x(i)-=L(i,j)*x(j); 00812 } 00813 } 00814 } 00815 00816 vec backward_substitution(const mat &U, const vec &b) 00817 { 00818 vec x(U.rows()); 00819 backward_substitution(U, b, x); 00820 00821 return x; 00822 } 00823 00824 void backward_substitution(const mat &U, const vec &b, vec &x) 00825 { 00826 it_assert( U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size(), 00827 "backward_substitution: dimension mismatch" ); 00828 int n = U.rows(), i, j; 00829 double temp; 00830 00831 x(n-1)=b(n-1)/U(n-1,n-1); 00832 for (i=n-2; i>=0; i--) { 00833 // Should be: x(i)=((b(i)-U(i,i,i+1,n-1)*x(i+1,n-1))/U(i,i))(0); but this is too slow. 00834 temp=0; 00835 //i_pos=i*U._row_offset(); 00836 for (j=i+1; j<n; j++) { 00837 temp += U._elem(i,j) * x(j); 00838 //temp+=U._data()[i_pos+j]*x(j); 00839 } 00840 x(i) = (b(i)-temp)/U._elem(i,i); 00841 //x(i)=(b(i)-temp)/U._data()[i_pos+i]; 00842 } 00843 } 00844 00845 vec backward_substitution(const mat &U, int q, const vec &b) 00846 { 00847 vec x(U.rows()); 00848 backward_substitution(U, q, b, x); 00849 00850 return x; 00851 } 00852 00853 void backward_substitution(const mat &U, int q, const vec &b, vec &x) 00854 { 00855 it_assert( U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size() && q <= U.rows()/2, 00856 "backward_substitution: dimension mismatch" ); 00857 int n = U.rows(), i, j; 00858 00859 x=b; 00860 00861 for (j=n-1; j>=0; j--) { 00862 x(j) /= U(j,j); 00863 for (i=std::max(0,j-q); i<j; i++) { 00864 x(i)-=U(i,j)*x(j); 00865 } 00866 } 00867 } 00868 00869 } // namespace itpp
Generated on Fri Jun 8 02:08:51 2007 for IT++ by Doxygen 1.5.2