diff --git a/go.mod b/go.mod index adcfa06c..92f98d2e 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( golang.org/x/crypto v0.25.0 gopkg.in/ini.v1 v1.67.0 gopkg.in/mail.v2 v2.3.1 + xorm.io/builder v0.3.13 xorm.io/xorm v1.3.9 ) @@ -67,5 +68,4 @@ require ( google.golang.org/protobuf v1.34.1 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - xorm.io/builder v0.3.12 // indirect ) diff --git a/go.sum b/go.sum index cd85f45f..dbf4fa91 100644 --- a/go.sum +++ b/go.sum @@ -163,5 +163,7 @@ nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYm rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= xorm.io/builder v0.3.12 h1:ASZYX7fQmy+o8UJdhlLHSW57JDOkM8DNhcAF5d0LiJM= xorm.io/builder v0.3.12/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= +xorm.io/builder v0.3.13 h1:a3jmiVVL19psGeXx8GIurTp7p0IIgqeDmwhcR6BAOAo= +xorm.io/builder v0.3.13/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= xorm.io/xorm v1.3.9 h1:TUovzS0ko+IQ1XnNLfs5dqK1cJl1H5uHpWbWqAQ04nU= xorm.io/xorm v1.3.9/go.mod h1:LsCCffeeYp63ssk0pKumP6l96WZcHix7ChpurcLNuMw= diff --git a/pkg/api/transactions.go b/pkg/api/transactions.go index 555b1020..8b27373b 100644 --- a/pkg/api/transactions.go +++ b/pkg/api/transactions.go @@ -62,7 +62,14 @@ func (a *TransactionsApi) TransactionCountHandler(c *core.Context) (any, *errs.E return nil, errs.Or(err, errs.ErrOperationFailed) } - totalCount, err := a.transactions.GetTransactionCount(c, uid, transactionCountReq.MaxTime, transactionCountReq.MinTime, transactionCountReq.Type, allCategoryIds, allAccountIds, transactionCountReq.AmountFilter, transactionCountReq.Keyword) + allTagIds, err := a.getTagIds(transactionCountReq.TagIds) + + if err != nil { + log.WarnfWithRequestId(c, "[transactions.TransactionCountHandler] get transaction tag ids error, because %s", err.Error()) + return nil, errs.Or(err, errs.ErrOperationFailed) + } + + totalCount, err := a.transactions.GetTransactionCount(c, uid, transactionCountReq.MaxTime, transactionCountReq.MinTime, transactionCountReq.Type, allCategoryIds, allAccountIds, allTagIds, transactionCountReq.AmountFilter, transactionCountReq.Keyword) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionCountHandler] failed to get transaction count for user \"uid:%d\", because %s", uid, err.Error()) @@ -118,10 +125,17 @@ func (a *TransactionsApi) TransactionListHandler(c *core.Context) (any, *errs.Er return nil, errs.Or(err, errs.ErrOperationFailed) } + allTagIds, err := a.getTagIds(transactionListReq.TagIds) + + if err != nil { + log.WarnfWithRequestId(c, "[transactions.TransactionListHandler] get transaction tag ids error, because %s", err.Error()) + return nil, errs.Or(err, errs.ErrOperationFailed) + } + var totalCount int64 if transactionListReq.WithCount { - totalCount, err = a.transactions.GetTransactionCount(c, uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, transactionListReq.AmountFilter, transactionListReq.Keyword) + totalCount, err = a.transactions.GetTransactionCount(c, uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, allTagIds, transactionListReq.AmountFilter, transactionListReq.Keyword) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionListHandler] failed to get transaction count for user \"uid:%d\", because %s", uid, err.Error()) @@ -129,7 +143,7 @@ func (a *TransactionsApi) TransactionListHandler(c *core.Context) (any, *errs.Er } } - transactions, err := a.transactions.GetTransactionsByMaxTime(c, uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, transactionListReq.AmountFilter, transactionListReq.Keyword, transactionListReq.Page, transactionListReq.Count, true, true) + transactions, err := a.transactions.GetTransactionsByMaxTime(c, uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, allTagIds, transactionListReq.AmountFilter, transactionListReq.Keyword, transactionListReq.Page, transactionListReq.Count, true, true) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionListHandler] failed to get transactions earlier than \"%d\" for user \"uid:%d\", because %s", transactionListReq.MaxTime, uid, err.Error()) @@ -209,7 +223,14 @@ func (a *TransactionsApi) TransactionMonthListHandler(c *core.Context) (any, *er return nil, errs.Or(err, errs.ErrOperationFailed) } - transactions, err := a.transactions.GetTransactionsInMonthByPage(c, uid, transactionListReq.Year, transactionListReq.Month, transactionListReq.Type, allCategoryIds, allAccountIds, transactionListReq.AmountFilter, transactionListReq.Keyword) + allTagIds, err := a.getTagIds(transactionListReq.TagIds) + + if err != nil { + log.WarnfWithRequestId(c, "[transactions.TransactionMonthListHandler] get transaction tag ids error, because %s", err.Error()) + return nil, errs.Or(err, errs.ErrOperationFailed) + } + + transactions, err := a.transactions.GetTransactionsInMonthByPage(c, uid, transactionListReq.Year, transactionListReq.Month, transactionListReq.Type, allCategoryIds, allAccountIds, allTagIds, transactionListReq.AmountFilter, transactionListReq.Keyword) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionMonthListHandler] failed to get transactions in month \"%d-%d\" for user \"uid:%d\", because %s", transactionListReq.Year, transactionListReq.Month, uid, err.Error()) @@ -969,6 +990,20 @@ func (a *TransactionsApi) getCategoryOrSubCategoryIds(c *core.Context, categoryI return allCategoryIds, nil } +func (a *TransactionsApi) getTagIds(tagIds string) ([]int64, error) { + if tagIds == "" || tagIds == "0" { + return nil, nil + } + + requestTagIds, err := utils.StringArrayToInt64Array(strings.Split(tagIds, ",")) + + if err != nil { + return nil, errs.Or(err, errs.ErrTransactionTagIdInvalid) + } + + return requestTagIds, nil +} + func (a *TransactionsApi) getTransactionTagIds(allTransactionTagIds map[int64][]int64) []int64 { allTagIds := make([]int64, 0, len(allTransactionTagIds)) diff --git a/pkg/models/transaction.go b/pkg/models/transaction.go index add9e922..5e255ad1 100644 --- a/pkg/models/transaction.go +++ b/pkg/models/transaction.go @@ -98,6 +98,7 @@ type TransactionCountRequest struct { Type TransactionDbType `form:"type" binding:"min=0,max=4"` CategoryIds string `form:"category_ids"` AccountIds string `form:"account_ids"` + TagIds string `form:"tag_ids"` AmountFilter string `form:"amount_filter" binding:"validAmountFilter"` Keyword string `form:"keyword"` MaxTime int64 `form:"max_time" binding:"min=0"` @@ -109,6 +110,7 @@ type TransactionListByMaxTimeRequest struct { Type TransactionDbType `form:"type" binding:"min=0,max=4"` CategoryIds string `form:"category_ids"` AccountIds string `form:"account_ids"` + TagIds string `form:"tag_ids"` AmountFilter string `form:"amount_filter" binding:"validAmountFilter"` Keyword string `form:"keyword"` MaxTime int64 `form:"max_time" binding:"min=0"` @@ -128,6 +130,7 @@ type TransactionListInMonthByPageRequest struct { Type TransactionDbType `form:"type" binding:"min=0,max=4"` CategoryIds string `form:"category_ids"` AccountIds string `form:"account_ids"` + TagIds string `form:"tag_ids"` AmountFilter string `form:"amount_filter" binding:"validAmountFilter"` Keyword string `form:"keyword"` TrimAccount bool `form:"trim_account"` diff --git a/pkg/models/transaction_tag_index.go b/pkg/models/transaction_tag_index.go index 23e6bc1b..a34dcbf8 100644 --- a/pkg/models/transaction_tag_index.go +++ b/pkg/models/transaction_tag_index.go @@ -3,11 +3,11 @@ package models // TransactionTagIndex represents transaction and transaction tag relation stored in database type TransactionTagIndex struct { TagIndexId int64 `xorm:"PK"` - Uid int64 `xorm:"INDEX(IDX_transaction_tag_index_uid_deleted_tag_id_transaction_id) INDEX(IDX_transaction_tag_index_uid_deleted_tag_id_transaction_time) INDEX(IDX_transaction_tag_index_uid_deleted_transaction_id)"` - Deleted bool `xorm:"INDEX(IDX_transaction_tag_index_uid_deleted_tag_id_transaction_id) INDEX(IDX_transaction_tag_index_uid_deleted_tag_id_transaction_time) INDEX(IDX_transaction_tag_index_uid_deleted_transaction_id) NOT NULL"` - TagId int64 `xorm:"INDEX(IDX_transaction_tag_index_uid_deleted_tag_id_transaction_id) INDEX(IDX_transaction_tag_index_uid_deleted_tag_id_transaction_time)"` + Uid int64 `xorm:"INDEX(IDX_transaction_tag_index_uid_deleted_tag_id_transaction_id) INDEX(IDX_transaction_tag_index_uid_deleted_transaction_time_tag_id) INDEX(IDX_transaction_tag_index_uid_deleted_transaction_id)"` + Deleted bool `xorm:"INDEX(IDX_transaction_tag_index_uid_deleted_tag_id_transaction_id) INDEX(IDX_transaction_tag_index_uid_deleted_transaction_time_tag_id) INDEX(IDX_transaction_tag_index_uid_deleted_transaction_id) NOT NULL"` + TransactionTime int64 `xorm:"INDEX(IDX_transaction_tag_index_uid_deleted_transaction_time_tag_id) NOT NULL"` + TagId int64 `xorm:"INDEX(IDX_transaction_tag_index_uid_deleted_tag_id_transaction_id) INDEX(IDX_transaction_tag_index_uid_deleted_transaction_time_tag_id)"` TransactionId int64 `xorm:"INDEX(IDX_transaction_tag_index_uid_deleted_tag_id_transaction_id) INDEX(IDX_transaction_tag_index_uid_deleted_transaction_id)"` - TransactionTime int64 `xorm:"INDEX(IDX_transaction_tag_index_uid_deleted_tag_id_transaction_time) NOT NULL"` CreatedUnixTime int64 UpdatedUnixTime int64 DeletedUnixTime int64 diff --git a/pkg/services/transactions.go b/pkg/services/transactions.go index 7d598911..9997d0e3 100644 --- a/pkg/services/transactions.go +++ b/pkg/services/transactions.go @@ -5,6 +5,7 @@ import ( "strings" "time" + "xorm.io/builder" "xorm.io/xorm" "github.com/mayswind/ezbookkeeping/pkg/core" @@ -73,11 +74,11 @@ func (s *TransactionService) GetAllTransactions(c *core.Context, uid int64, page // GetAllTransactionsByMaxTime returns all transactions before given time func (s *TransactionService) GetAllTransactionsByMaxTime(c *core.Context, uid int64, maxTransactionTime int64, count int32, noDuplicated bool) ([]*models.Transaction, error) { - return s.GetTransactionsByMaxTime(c, uid, maxTransactionTime, 0, 0, nil, nil, "", "", 1, count, false, noDuplicated) + return s.GetTransactionsByMaxTime(c, uid, maxTransactionTime, 0, 0, nil, nil, nil, "", "", 1, count, false, noDuplicated) } // GetTransactionsByMaxTime returns transactions before given time -func (s *TransactionService) GetTransactionsByMaxTime(c *core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, amountFilter string, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) { +func (s *TransactionService) GetTransactionsByMaxTime(c *core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -101,14 +102,20 @@ func (s *TransactionService) GetTransactionsByMaxTime(c *core.Context, uid int64 actualCount++ } - condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, amountFilter, keyword, noDuplicated) - err = s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...).Limit(int(actualCount), int(count*(page-1))).OrderBy("transaction_time desc").Find(&transactions) + condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, noDuplicated) + sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...) + + if len(tagIds) > 0 { + sess.In("transaction_id", s.getTransactionQueryByTagIdsCondition(uid, maxTransactionTime, minTransactionTime, tagIds)) + } + + err = sess.Limit(int(actualCount), int(count*(page-1))).OrderBy("transaction_time desc").Find(&transactions) return transactions, err } // GetTransactionsInMonthByPage returns all transactions in given year and month -func (s *TransactionService) GetTransactionsInMonthByPage(c *core.Context, uid int64, year int32, month int32, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, amountFilter string, keyword string) ([]*models.Transaction, error) { +func (s *TransactionService) GetTransactionsInMonthByPage(c *core.Context, uid int64, year int32, month int32, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string) ([]*models.Transaction, error) { if uid <= 0 { return nil, errs.ErrUserIdInvalid } @@ -121,8 +128,14 @@ func (s *TransactionService) GetTransactionsInMonthByPage(c *core.Context, uid i var transactions []*models.Transaction - condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, amountFilter, keyword, true) - err = s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...).OrderBy("transaction_time desc").Find(&transactions) + condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true) + sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...) + + if len(tagIds) > 0 { + sess.In("transaction_id", s.getTransactionQueryByTagIdsCondition(uid, maxTransactionTime, minTransactionTime, tagIds)) + } + + err = sess.OrderBy("transaction_time desc").Find(&transactions) transactionsInMonth := make([]*models.Transaction, 0, len(transactions)) @@ -163,17 +176,23 @@ func (s *TransactionService) GetTransactionByTransactionId(c *core.Context, uid // GetAllTransactionCount returns total count of transactions func (s *TransactionService) GetAllTransactionCount(c *core.Context, uid int64) (int64, error) { - return s.GetTransactionCount(c, uid, 0, 0, 0, nil, nil, "", "") + return s.GetTransactionCount(c, uid, 0, 0, 0, nil, nil, nil, "", "") } // GetTransactionCount returns count of transactions -func (s *TransactionService) GetTransactionCount(c *core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, amountFilter string, keyword string) (int64, error) { +func (s *TransactionService) GetTransactionCount(c *core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string) (int64, error) { if uid <= 0 { return 0, errs.ErrUserIdInvalid } - condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, amountFilter, keyword, true) - return s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...).Count(&models.Transaction{}) + condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true) + sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...) + + if len(tagIds) > 0 { + sess.In("transaction_id", s.getTransactionQueryByTagIdsCondition(uid, maxTransactionTime, minTransactionTime, tagIds)) + } + + return sess.Count(&models.Transaction{}) } // CreateTransaction saves a new transaction to database @@ -1340,7 +1359,7 @@ func (s *TransactionService) GetTransactionMapByList(transactions []*models.Tran return transactionMap } -func (s *TransactionService) getTransactionQueryCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, amountFilter string, keyword string, noDuplicated bool) (string, []any) { +func (s *TransactionService) getTransactionQueryCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string, noDuplicated bool) (string, []any) { condition := "uid=? AND deleted=?" conditionParams := make([]any, 0, 16) conditionParams = append(conditionParams, uid) @@ -1496,6 +1515,26 @@ func (s *TransactionService) getTransactionQueryCondition(uid int64, maxTransact return condition, conditionParams } +func (s *TransactionService) getTransactionQueryByTagIdsCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, tagIds []int64) *builder.Builder { + if len(tagIds) > 0 { + condition := builder.And(builder.Eq{"uid": uid}, builder.Eq{"deleted": false}) + + if maxTransactionTime > 0 { + condition = condition.And(builder.Lte{"transaction_time": maxTransactionTime}) + } + + if minTransactionTime > 0 { + condition = condition.And(builder.Gte{"transaction_time": minTransactionTime}) + } + + condition = condition.And(builder.In("tag_id", tagIds)) + + return builder.Select("transaction_id").From("transaction_tag_index").Where(condition) + } + + return nil +} + func (s *TransactionService) isAccountIdValid(transaction *models.Transaction) error { if transaction.Type == models.TRANSACTION_DB_TYPE_MODIFY_BALANCE { if transaction.RelatedAccountId != 0 && transaction.RelatedAccountId != transaction.AccountId { diff --git a/third-party-dependencies.json b/third-party-dependencies.json index d5ab2471..c50adea5 100644 --- a/third-party-dependencies.json +++ b/third-party-dependencies.json @@ -23,6 +23,12 @@ "url": "https://xorm.io/xorm", "licenseUrl": "https://gitea.com/xorm/xorm/src/tag/v1.3.9/LICENSE" }, + { + "name": "SQL builder", + "copyright": "Copyright (c) 2016 The Xorm Authors", + "url": "https://gitea.com/xorm/builder", + "licenseUrl": "https://gitea.com/xorm/builder/src/tag/v0.3.13/LICENSE" + }, { "name": "Logrus", "copyright": "Copyright (c) 2014 Simon Eskildsen",