diff --git a/src/domain/stock.py b/src/domain/stock.py index 4d92a3e..1ddea2c 100644 --- a/src/domain/stock.py +++ b/src/domain/stock.py @@ -27,6 +27,15 @@ class StockDict(TypedDict): updated_at: datetime +@dataclass +class StockInfo: + symbol: str + quantity: int + price: float + avg_cost: float + percentage: float + + @dataclass class CreateStock: user_id: int diff --git a/src/tests/test_stock_usecase.py b/src/tests/test_stock_usecase.py index 521f662..683807d 100644 --- a/src/tests/test_stock_usecase.py +++ b/src/tests/test_stock_usecase.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import Mock, ANY, patch from usecase.stock import StockUsecase -from domain.stock import CreateStock, Stock +from domain.stock import CreateStock, Stock, StockInfo from domain.portfolio import Portfolio, Holding, PortfolioInfo from domain.enum import ActionType, StockType @@ -660,3 +660,149 @@ def test_get_portfolio_info_with_valid_holdings(self, mock_get_stock_price, stoc portfolio_repo.get.assert_called_once_with(user_id=user_id) mock_get_stock_price.assert_called_once_with(stock_info=[("AAPL", StockType.STOCKS), ("SPY", StockType.ETF)]) assert result == expected_result + + +class TestStockUsecaseGetStockInfoList: + @patch.object(StockUsecase, "_get_stock_price") + def test_get_stock_info_list_no_portfolio(self, mock_get_stock_price, stock_usecase): + # Arrange + usecase, _, portfolio_repo = stock_usecase + user_id = 1 + portfolio_repo.get.return_value = None + expected_result = {StockType.ETF.value: [], StockType.STOCKS.value: [], "CASH": []} + + # Act + result = usecase.get_stock_info_list(user_id) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id=user_id) + mock_get_stock_price.assert_not_called() + assert result == expected_result + + @patch.object(StockUsecase, "_get_stock_price") + def test_get_stock_info_list_empty_portfolio(self, mock_get_stock_price, stock_usecase): + # Arrange + usecase, _, portfolio_repo = stock_usecase + user_id = 1 + portfolio = Portfolio( + user_id=user_id, + cash_balance=0.0, + total_money_in=0.0, + holdings=[], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = portfolio + expected_result = {StockType.ETF.value: [], StockType.STOCKS.value: [], "CASH": []} + + # Act + result = usecase.get_stock_info_list(user_id) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id=user_id) + mock_get_stock_price.assert_not_called() + assert result == expected_result + + @patch.object(StockUsecase, "_get_stock_price") + def test_get_stock_info_list_no_valid_holdings(self, mock_get_stock_price, stock_usecase): + # Arrange + usecase, _, portfolio_repo = stock_usecase + user_id = 1 + portfolio = Portfolio( + user_id=user_id, + cash_balance=1000.0, + total_money_in=2000.0, + holdings=[Holding(symbol="AAPL", shares=0, stock_type=StockType.STOCKS, total_cost=0.0)], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = portfolio + expected_result = {StockType.ETF.value: [], StockType.STOCKS.value: [], "CASH": []} + + # Act + result = usecase.get_stock_info_list(user_id) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id=user_id) + mock_get_stock_price.assert_not_called() + assert result == expected_result + + @patch.object(StockUsecase, "_get_stock_price") + def test_get_stock_info_list_with_valid_holdings(self, mock_get_stock_price, stock_usecase): + # Arrange + usecase, _, portfolio_repo = stock_usecase + user_id = 1 + portfolio = Portfolio( + user_id=user_id, + cash_balance=1000.0, + total_money_in=2000.0, + holdings=[ + Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0), + Holding(symbol="TSLA", shares=20, stock_type=StockType.STOCKS, total_cost=2000.0), + Holding(symbol="SPY", shares=30, stock_type=StockType.ETF, total_cost=3000.0), + Holding(symbol="QQQ", shares=5, stock_type=StockType.ETF, total_cost=1000.0), + ], + created_at=ANY, + updated_at=ANY, + ) + + mock_get_stock_price.return_value = {"AAPL": 200.0, "TSLA": 500.0, "SPY": 400.0, "QQQ": 300.0} + portfolio_repo.get.return_value = portfolio + expected_result = { + StockType.ETF.value: [ + StockInfo( + symbol="SPY", + quantity=30, + price=400.0, + avg_cost=100.0, # 3000 / 30 + percentage=45.0, # 12000 / (1000 + 10*200 + 20*500 + 30*400 + 5*300) + ), + StockInfo( + symbol="QQQ", + quantity=5, + price=300.0, + avg_cost=200.0, # 1000 / 5 + percentage=6.0, # 1500 / (1000 + 10*200 + 20*500 + 30*400 + 5*300) + ), + ], + StockType.STOCKS.value: [ + StockInfo( + symbol="AAPL", + quantity=10, + price=200.0, + avg_cost=150.0, # 1500 / 10 + percentage=8.0, # 2000 / (1000 + 10*200 + 20*500 + 30*400 + 5*300) + ), + StockInfo( + symbol="TSLA", + quantity=20, + price=500.0, + avg_cost=100.0, # 2000 / 20 + percentage=38.0, # 10000 / (1000 + 10*200 + 20*500 + 30*400 + 5*300) + ), + ], + "CASH": [ + StockInfo( + symbol="CASH", + quantity=1, + price=1000.0, + avg_cost=0.0, + percentage=4.0, # 1000 / (1000 + 10*200 + 20*500 + 30*400 + 5*300) + ) + ], + } + + # Act + result = usecase.get_stock_info_list(user_id) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id=user_id) + mock_get_stock_price.assert_called_once_with( + stock_info=[ + ("AAPL", StockType.STOCKS), + ("TSLA", StockType.STOCKS), + ("SPY", StockType.ETF), + ("QQQ", StockType.ETF), + ] + ) + assert result == expected_result diff --git a/src/usecase/stock.py b/src/usecase/stock.py index c8a4964..d988b7c 100644 --- a/src/usecase/stock.py +++ b/src/usecase/stock.py @@ -4,7 +4,7 @@ from .base import AbstractStockUsecase from adapters.base import AbstractStockRepository, AbstractPortfolioRepository from domain.portfolio import Portfolio, Holding, PortfolioInfo -from domain.stock import CreateStock, Stock +from domain.stock import CreateStock, Stock, StockInfo from domain.enum import ActionType, StockType ETF_KEY = "navPrice" @@ -105,6 +105,56 @@ def get_portfolio_info(self, user_id: int) -> PortfolioInfo: roi=roi, ) + def get_stock_info_list(self, user_id: int) -> Dict[str, List[StockInfo]]: + result = {StockType.ETF.value: [], StockType.STOCKS.value: [], "CASH": []} + portfolio = self.portfolio_repo.get(user_id=user_id) + if portfolio is None or portfolio.total_money_in == 0.0: + return result + + valid_holdings = [ + (holding.symbol, holding.shares, holding.stock_type, holding.total_cost) + for holding in portfolio.holdings + if holding.shares > 0 + ] + if not valid_holdings: + return result + + # Fetch prices in batch + stock_info = [(symbol, stock_type) for symbol, _, stock_type, _ in valid_holdings] + stock_price_by_symbol = self._get_stock_price(stock_info=stock_info) + + # Calculate total stock value and total value + total_stock_price = sum( + shares * stock_price_by_symbol.get(symbol, 0.0) for symbol, shares, _, _ in valid_holdings + ) + total_value = total_stock_price + portfolio.cash_balance + + for symbol, shares, stock_type, total_cost in valid_holdings: + stock_price = stock_price_by_symbol.get(symbol, 0.0) + stock_total_value = shares * stock_price + + result[stock_type.value].append( + StockInfo( + symbol=symbol, + quantity=shares, + price=stock_price, + avg_cost=round(total_cost / shares, 2), + percentage=round(stock_total_value / total_value, 2) * 100, + ) + ) + + result["CASH"].append( + StockInfo( + symbol="CASH", + quantity=1, + price=portfolio.cash_balance, + avg_cost=0, + percentage=round(portfolio.cash_balance / total_value, 2) * 100, + ) + ) + + return result + def _get_stock_price(self, stock_info: List[Tuple[str, StockType]]) -> Dict[str, float]: if not stock_info: return {}