@@ -111,6 +111,60 @@ SparseArray<T> arithOp(const SparseArray<T> &lhs, const Array<T> &rhs,
111111 return out;
112112}
113113
114+ #define SPARSE_ARITH_OP_FUNC_DEF (FUNC ) \
115+ template <typename T> \
116+ FUNC##_def<T> FUNC##_func();
117+
118+ #define SPARSE_ARITH_OP_FUNC (FUNC, TYPE, INFIX ) \
119+ template <> \
120+ FUNC##_def<TYPE> FUNC##_func<TYPE>() { \
121+ return cusparse##INFIX##FUNC; \
122+ }
123+
124+ #if CUDA_VERSION >= 11000
125+
126+ template <typename T>
127+ using csrgeam2_buffer_size_def = cusparseStatus_t (*)(
128+ cusparseHandle_t, int , int , const T *, const cusparseMatDescr_t, int ,
129+ const T *, const int *, const int *, const T *, const cusparseMatDescr_t,
130+ int , const T *, const int *, const int *, const cusparseMatDescr_t,
131+ const T *, const int *, const int *, size_t *);
132+
133+ #define SPARSE_ARITH_OP_BUFFER_SIZE_FUNC_DEF (FUNC ) \
134+ template <typename T> \
135+ FUNC##_buffer_size_def<T> FUNC##_buffer_size_func();
136+
137+ SPARSE_ARITH_OP_BUFFER_SIZE_FUNC_DEF (csrgeam2);
138+
139+ #define SPARSE_ARITH_OP_BUFFER_SIZE_FUNC (FUNC, TYPE, INFIX ) \
140+ template <> \
141+ FUNC##_buffer_size_def<TYPE> FUNC##_buffer_size_func<TYPE>() { \
142+ return cusparse##INFIX##FUNC##_bufferSizeExt; \
143+ }
144+
145+ SPARSE_ARITH_OP_BUFFER_SIZE_FUNC (csrgeam2, float , S);
146+ SPARSE_ARITH_OP_BUFFER_SIZE_FUNC (csrgeam2, double , D);
147+ SPARSE_ARITH_OP_BUFFER_SIZE_FUNC (csrgeam2, cfloat, C);
148+ SPARSE_ARITH_OP_BUFFER_SIZE_FUNC (csrgeam2, cdouble, Z);
149+
150+ template <typename T>
151+ using csrgeam2_def = cusparseStatus_t (*)(cusparseHandle_t, int , int , const T *,
152+ const cusparseMatDescr_t, int ,
153+ const T *, const int *, const int *,
154+ const T *, const cusparseMatDescr_t,
155+ int , const T *, const int *,
156+ const int *, const cusparseMatDescr_t,
157+ T *, int *, int *, void *);
158+
159+ SPARSE_ARITH_OP_FUNC_DEF (csrgeam2);
160+
161+ SPARSE_ARITH_OP_FUNC (csrgeam2, float , S);
162+ SPARSE_ARITH_OP_FUNC (csrgeam2, double , D);
163+ SPARSE_ARITH_OP_FUNC (csrgeam2, cfloat, C);
164+ SPARSE_ARITH_OP_FUNC (csrgeam2, cdouble, Z);
165+
166+ #else
167+
114168template <typename T>
115169using csrgeam_def = cusparseStatus_t (*)(cusparseHandle_t, int , int , const T *,
116170 const cusparseMatDescr_t, int ,
@@ -120,23 +174,15 @@ using csrgeam_def = cusparseStatus_t (*)(cusparseHandle_t, int, int, const T *,
120174 const int *, const cusparseMatDescr_t,
121175 T *, int *, int *);
122176
123- #define SPARSE_ARITH_OP_FUNC_DEF (FUNC ) \
124- template <typename T> \
125- FUNC##_def<T> FUNC##_func();
126-
127177SPARSE_ARITH_OP_FUNC_DEF (csrgeam);
128178
129- #define SPARSE_ARITH_OP_FUNC (FUNC, TYPE, INFIX ) \
130- template <> \
131- FUNC##_def<TYPE> FUNC##_func<TYPE>() { \
132- return cusparse##INFIX##FUNC; \
133- }
134-
135179SPARSE_ARITH_OP_FUNC (csrgeam, float , S);
136180SPARSE_ARITH_OP_FUNC (csrgeam, double , D);
137181SPARSE_ARITH_OP_FUNC (csrgeam, cfloat, C);
138182SPARSE_ARITH_OP_FUNC (csrgeam, cdouble, Z);
139183
184+ #endif
185+
140186template <typename T, af_op_t op>
141187SparseArray<T> arithOp (const SparseArray<T> &lhs, const SparseArray<T> &rhs) {
142188 lhs.eval ();
@@ -163,9 +209,28 @@ SparseArray<T> arithOp(const SparseArray<T> &lhs, const SparseArray<T> &rhs) {
163209 int baseC, nnzC;
164210 int *nnzcDevHostPtr = &nnzC;
165211
212+ T alpha = scalar<T>(1 );
213+ T beta = op == af_sub_t ? scalar<T>(-1 ) : alpha;
214+
215+ #if CUDA_VERSION >= 11000
216+ size_t pBufferSize = 0 ;
217+
218+ csrgeam2_buffer_size_func<T>()(
219+ sparseHandle (), M, N, &alpha, desc, nnzA, lhs.getValues ().get (),
220+ csrRowPtrA, csrColPtrA, &beta, desc, nnzB, rhs.getValues ().get (),
221+ csrRowPtrB, csrColPtrB, desc, NULL , csrRowPtrC, NULL , &pBufferSize);
222+
223+ auto tmpBuffer = createEmptyArray<char >(dim4 (pBufferSize));
224+
225+ CUSPARSE_CHECK (cusparseXcsrgeam2Nnz (
226+ sparseHandle (), M, N, desc, nnzA, csrRowPtrA, csrColPtrA, desc, nnzB,
227+ csrRowPtrB, csrColPtrB, desc, csrRowPtrC, nnzcDevHostPtr,
228+ tmpBuffer.get ()));
229+ #else
166230 CUSPARSE_CHECK (cusparseXcsrgeamNnz (
167231 sparseHandle (), M, N, desc, nnzA, csrRowPtrA, csrColPtrA, desc, nnzB,
168232 csrRowPtrB, csrColPtrB, desc, csrRowPtrC, nnzcDevHostPtr));
233+ #endif
169234 if (NULL != nnzcDevHostPtr) {
170235 nnzC = *nnzcDevHostPtr;
171236 } else {
@@ -181,15 +246,18 @@ SparseArray<T> arithOp(const SparseArray<T> &lhs, const SparseArray<T> &rhs) {
181246
182247 auto outColIdx = createEmptyArray<int >(dim4 (nnzC));
183248 auto outValues = createEmptyArray<T>(dim4 (nnzC));
184-
185- T alpha = scalar<T>(1 );
186- T beta = op == af_sub_t ? scalar<T>(-1 ) : alpha;
187-
249+ #if CUDA_VERSION >= 11000
250+ csrgeam2_func<T>()(sparseHandle (), M, N, &alpha, desc, nnzA,
251+ lhs.getValues ().get (), csrRowPtrA, csrColPtrA, &beta,
252+ desc, nnzB, rhs.getValues ().get (), csrRowPtrB,
253+ csrColPtrB, desc, outValues.get (), csrRowPtrC,
254+ outColIdx.get (), tmpBuffer.get ());
255+ #else
188256 csrgeam_func<T>()(sparseHandle (), M, N, &alpha, desc, nnzA,
189257 lhs.getValues ().get (), csrRowPtrA, csrColPtrA, &beta,
190258 desc, nnzB, rhs.getValues ().get (), csrRowPtrB, csrColPtrB,
191259 desc, outValues.get (), csrRowPtrC, outColIdx.get ());
192-
260+ # endif
193261 SparseArray<T> retVal = createArrayDataSparseArray (
194262 ldims, outValues, outRowIdx, outColIdx, sfmt);
195263 return retVal;
0 commit comments