10#include "quest/include/quest.h"
12#include <catch2/catch_test_macros.hpp>
13#include <catch2/generators/catch_generators_range.hpp>
15#include "tests/utils/qvector.hpp"
16#include "tests/utils/qmatrix.hpp"
17#include "tests/utils/cache.hpp"
18#include "tests/utils/compare.hpp"
19#include "tests/utils/convert.hpp"
20#include "tests/utils/evolve.hpp"
21#include "tests/utils/linalg.hpp"
22#include "tests/utils/lists.hpp"
23#include "tests/utils/macros.hpp"
24#include "tests/utils/measure.hpp"
25#include "tests/utils/random.hpp"
// Tag argument passed to every TEST_CASE in this file, e.g.
// TEST_CASE("initBlankState", TEST_CATEGORY). It expands to LABEL_UNIT_TAG
// followed by the literal "[initialisations]". LABEL_UNIT_TAG is presumably a
// string-literal macro from tests/utils/macros.hpp, so the two concatenate into
// one tag string. TODO confirm.
34#define TEST_CATEGORY \
35 LABEL_UNIT_TAG "[initialisations]"
38void TEST_ON_CACHED_QUREGS(quregCache quregs,
auto testFunc) {
40 for (
auto& [label, qureg]: quregs) {
42 DYNAMIC_SECTION( label ) {
50void TEST_ON_CACHED_QUREGS(quregCache quregs,
auto apiFunc,
auto refState) {
54 auto testFunc = [&](
Qureg qureg) {
56 REQUIRE_AGREE( qureg, refState );
59 TEST_ON_CACHED_QUREGS(quregs, testFunc);
74 SECTION( LABEL_CORRECTNESS ) {
76 SECTION( LABEL_STATEVEC ) { TEST_ON_CACHED_QUREGS(getCachedStatevecs(),
initBlankState, getRefStatevec()); }
77 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(),
initBlankState, getRefDensmatr()); }
86 SECTION( LABEL_CORRECTNESS ) {
88 qvector refVec = getRefStatevec(); refVec[0] = 1;
89 qmatrix refMat = getRefDensmatr(); refMat[0][0] = 1;
91 SECTION( LABEL_STATEVEC ) { TEST_ON_CACHED_QUREGS(getCachedStatevecs(),
initZeroState, refVec); }
92 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(),
initZeroState, refMat); }
101 SECTION( LABEL_CORRECTNESS ) {
103 int numQubits = getNumCachedQubits();
104 qreal vecElem = 1. / std::sqrt(getPow2(numQubits));
105 qreal matElem = 1. / getPow2(numQubits);
107 qvector refVec = getConstantVector(getPow2(numQubits), vecElem);
108 qmatrix refMat = getConstantMatrix(getPow2(numQubits), matElem);
110 SECTION( LABEL_STATEVEC ) { TEST_ON_CACHED_QUREGS(getCachedStatevecs(),
initPlusState, refVec); }
111 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(),
initPlusState, refMat); }
120 SECTION( LABEL_CORRECTNESS ) {
122 int numQubits = getNumCachedQubits();
123 int numInds = (int) getPow2(numQubits);
124 int stateInd = GENERATE_COPY( range(0,numInds) );
126 qvector refVec = getRefStatevec(); refVec[stateInd] = 1;
127 qmatrix refMat = getRefDensmatr(); refMat[stateInd][stateInd] = 1;
131 SECTION( LABEL_STATEVEC ) { TEST_ON_CACHED_QUREGS(getCachedStatevecs(), apiFunc, refVec); }
132 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(), apiFunc, refMat); }
141 SECTION( LABEL_CORRECTNESS ) {
143 qvector refVec = getRefStatevec(); setToDebugState(refVec);
144 qmatrix refMat = getRefDensmatr(); setToDebugState(refMat);
146 SECTION( LABEL_STATEVEC ) { TEST_ON_CACHED_QUREGS(getCachedStatevecs(),
initDebugState, refVec); }
147 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(),
initDebugState, refMat); }
156 SECTION( LABEL_CORRECTNESS ) {
159 GENERATE( range(0,10) );
161 auto testFunc = [&](
Qureg qureg) {
170 REQUIRE( qureg.cpuAmps[0] != qureg.cpuAmps[1] );
173 REQUIRE_AGREE( prob, 1 );
176 REQUIRE_AGREE( purity, 1 );
179 SECTION( LABEL_STATEVEC ) { TEST_ON_CACHED_QUREGS(getCachedStatevecs(), testFunc); }
180 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(), testFunc); }
189 SECTION( LABEL_CORRECTNESS ) {
193 GENERATE( range(0,10) );
194 int numPureStates = GENERATE( 1, 2, 10 );
196 auto testFunc = [&](
Qureg qureg) {
205 REQUIRE( qureg.cpuAmps[0] != qureg.cpuAmps[1] );
208 REQUIRE_AGREE( prob, 1 );
211 if (numPureStates == 1)
212 REQUIRE_AGREE( purity, 1 );
214 REQUIRE( purity < 1 );
217 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(), testFunc); }
226 SECTION( LABEL_CORRECTNESS ) {
229 qvector refVec = getRandomVector(getPow2(getNumCachedQubits()));
230 qmatrix refMat = getOuterProduct(refVec, refVec);
234 SECTION( LABEL_STATEVEC ) { TEST_ON_CACHED_QUREGS(getCachedStatevecs(), apiFunc, refVec); }
235 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(), apiFunc, refMat); }
246 SECTION( LABEL_CORRECTNESS ) {
248 int numTotalAmps = getPow2(getNumCachedQubits());
249 int numSetAmps = GENERATE_COPY( range(0,numTotalAmps+1) );
250 int startInd = GENERATE_COPY( range(0,numTotalAmps-numSetAmps) );
251 qvector amps = getRandomVector(numSetAmps);
253 auto testFunc = [&](
Qureg qureg) {
256 qvector refVec = getRandomVector(numTotalAmps);
257 setQuregToReference(qureg, refVec);
260 setSubVector(refVec, amps, startInd);
265 REQUIRE_AGREE( qureg, refVec );
268 SECTION( LABEL_STATEVEC ) { TEST_ON_CACHED_QUREGS(getCachedStatevecs(), testFunc); }
277 SECTION( LABEL_CORRECTNESS ) {
279 int numTotalRows = getPow2(getNumCachedQubits());
280 int numTotalAmps = numTotalRows * numTotalRows;
283 GENERATE( range(0,1000) );
285 int startInd =
getRandomInt(0, numTotalAmps - numSetAmps);
286 qvector amps = getRandomVector(numSetAmps);
288 auto testFunc = [&](
Qureg qureg) {
291 qmatrix refMat = getRandomMatrix(numTotalRows);
292 setQuregToReference(qureg, refMat);
295 refMat = getTranspose(refMat);
297 refMat = getTranspose(refMat);
303 REQUIRE_AGREE( qureg, refMat );
306 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(), testFunc); }
318 SECTION( LABEL_CORRECTNESS ) {
320 int numTotalRowsCols = getPow2(getNumCachedQubits());
323 GENERATE( range(0,1000) );
326 int startRow =
getRandomInt(0, numTotalRowsCols - numSetRows);
327 int startCol =
getRandomInt(0, numTotalRowsCols - numSetCols);
330 qmatrix amps = getRandomNonSquareMatrix(numSetRows, numSetCols);
332 auto testFunc = [&](
Qureg qureg) {
335 qmatrix refMat = getRandomMatrix(numTotalRowsCols);
336 setQuregToReference(qureg, refMat);
339 std::vector<qcomp*> rowPtrs(numSetRows);
340 for (
size_t r=0; r<numSetRows; r++)
341 rowPtrs[r] = amps[r].data();
348 REQUIRE_AGREE( qureg, refMat );
351 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(), testFunc); }
361 SECTION( LABEL_CORRECTNESS ) {
363 GENERATE( range(0,10) );
364 qindex dim = getPow2(getNumCachedQubits());
365 qvector refVec = getRandomVector(dim);
366 qmatrix refMat = getRandomMatrix(dim);
369 if (doScalarsAgree(getTrace(refMat), 0))
370 refMat[0][0] += 1/(qreal) dim;
377 refVec = getNormalised(refVec);
380 refMat /= getReferenceProbability(refMat);
382 SECTION( LABEL_STATEVEC ) { TEST_ON_CACHED_QUREGS(getCachedStatevecs(), funcVec, refVec); }
383 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(), funcMat, refMat); }
392 SECTION( LABEL_CORRECTNESS ) {
394 GENERATE( range(0,10) );
395 int numQubits = getNumCachedQubits();
396 int numTerms = GENERATE_COPY( 1, numQubits, getPow2(2*numQubits) );
397 PauliStrSum sum = createRandomPauliStrSum(numQubits, numTerms);
398 qmatrix refMat = getMatrix(sum, numQubits);
402 SECTION( LABEL_DENSMATR ) { TEST_ON_CACHED_QUREGS(getCachedDensmatrs(), apiFunc, refMat); }
qreal calcPurity(Qureg qureg)
qreal calcTotalProb(Qureg qureg)
void setDensityQuregFlatAmps(Qureg qureg, qindex startInd, qcomp *amps, qindex numAmps)
void setQuregToReducedDensityMatrix(Qureg out, Qureg in, int *retainQubits, int numRetainQubits)
void setQuregToPauliStrSum(Qureg qureg, PauliStrSum sum)
void setQuregAmps(Qureg qureg, qindex startInd, qcomp *amps, qindex numAmps)
void setQuregToPartialTrace(Qureg out, Qureg in, int *traceOutQubits, int numTraceQubits)
qreal setQuregToRenormalized(Qureg qureg)
void setQuregToClone(Qureg targetQureg, Qureg copyQureg)
void setDensityQuregAmps(Qureg qureg, qindex startRow, qindex startCol, qcomp **amps, qindex numRows, qindex numCols)
void setQuregToSuperposition(qcomp facOut, Qureg out, qcomp fac1, Qureg qureg1, qcomp fac2, Qureg qureg2)
void initArbitraryPureState(Qureg qureg, qcomp *amps)
void initRandomPureState(Qureg qureg)
void initPlusState(Qureg qureg)
void initZeroState(Qureg qureg)
void initPureState(Qureg qureg, Qureg pure)
void initDebugState(Qureg qureg)
void initRandomMixedState(Qureg qureg, qindex numPureStates)
void initClassicalState(Qureg qureg, qindex ind)
void initBlankState(Qureg qureg)
void syncQuregFromGpu(Qureg qureg)
void setSubMatrix(qmatrix &dest, qmatrix sub, size_t r, size_t c)
int getRandomInt(int min, int maxExcl)
TEST_CASE("initBlankState", TEST_CATEGORY)