import yfinance as yf
from datetime import datetime, timedelta
from sqlalchemy.exc import SQLAlchemyError
import pandas as pd
import numpy as np
import talib as ta
import time
from database_manager import DatabaseManager
from models import create_stock_model, get_stock_class
from rci import RCI


def append_stock_data(db: DatabaseManager, code: str) -> None:
    """
    指定した銘柄コードの株価データを取得し、データベースに保存する。

    Args:
        db (DatabaseManager(): データベース管理オブジェクト。
        code (str): 銘柄コード。

    Returns:
        None
    """

    # 翌日の日付を取得
    tomorrow = (datetime.today() + timedelta(days=1)).strftime("%Y-%m-%d")

    # yfinanceから本日分の株価データを取得
    ticker = f"{code}.T"
    ohlcv_df = yf.download(ticker, end=tomorrow)

    # マルチインデックスまたは通常のインデックスに対応
    if isinstance(ohlcv_df.columns, pd.MultiIndex):  # ----------（1）
        # マルチインデックスを平坦化し、カラム名を整理
        ohlcv_df.columns = ohlcv_df.columns.get_level_values(0)

    ohlcv_df = ohlcv_df.reset_index()
    # DateカラムをYYYY-MM-DD形式に変換
    ohlcv_df["Date"] = pd.to_datetime(ohlcv_df["Date"]).dt.date

    # データフレームのカラム名を小文字に変換
    ohlcv_df.columns = [col.lower() for col in ohlcv_df.columns]

    # 銘柄単位のモデルクラス
    StockModel = create_stock_model(f"stock_{code}", db.Base)  # ----------（2）
    db.init_db()

    # ohlcv_df のカラムを使って空のデータフレームを作成
    df = pd.DataFrame(columns=ohlcv_df.columns)

    session = db.get_session()
    stock_datas = session.query(StockModel).all()

    # データが存在する場合はデータフレームを上書き
    if stock_datas:
        # stock_datasの各オブジェクトの属性を辞書形式に変換してリストに格納
        dicts = [vars(stock_data) for stock_data in stock_datas]
        df = pd.DataFrame(dicts).drop("_sa_instance_state", axis=1)

    if not df.empty and "date" in df.columns and "date" in ohlcv_df.columns:
        min_date = df["date"].min()  # df の最小日付を取得
        diff_df = ohlcv_df[
            (ohlcv_df["date"] >= min_date) & (~ohlcv_df["date"].isin(df["date"]))
        ]
    else:
        diff_df = ohlcv_df  # dfが空 または dateカラムがない場合はそのまま

    # print(diff_df)

    # 差分がない場合は更新分データがないものを判断して処理を中断
    if diff_df.empty:
        return

    # データフレームを結合 空でないデータフレームのみ結合
    frames = [df, diff_df]
    temp_df = pd.concat([frame for frame in frames if not frame.empty])

    # dateカラムをインデックスに設定
    temp_df.set_index("date", inplace=True)

    # 終値を取得して各オシレーターを算出
    close = temp_df["close"]

    # 移動平均の算出
    temp_df["ma5"] = ta.SMA(close, timeperiod=5)
    temp_df["ma25"] = ta.SMA(close, timeperiod=25)
    temp_df["ma75"] = ta.SMA(close, timeperiod=75)

    # ボリンジャーバンドの算出
    upper2, _, lower2 = ta.BBANDS(close, timeperiod=25, nbdevup=2, nbdevdn=2, matype=0)
    temp_df["upper2"], temp_df["lower2"] = upper2, lower2

    # MACDの算出
    macd, macdsignal, hist = ta.MACD(
        close, fastperiod=12, slowperiod=26, signalperiod=9
    )
    temp_df["macd"] = macd
    temp_df["macd_signal"] = macdsignal
    temp_df["hist"] = hist

    # RSIの算出
    rsi14 = ta.RSI(close, timeperiod=14)
    rsi28 = ta.RSI(close, timeperiod=28)
    temp_df["rsi14"], temp_df["rsi28"] = rsi14, rsi28

    # RCIの算出
    rci9 = RCI(close, timeperiod=9)
    rci26 = RCI(close, timeperiod=26)
    temp_df["rci9"], temp_df["rci26"] = rci9, rci26

    # NaNをNoneに変換
    temp_df = temp_df.replace({np.nan: None})

    # データベースに保存するデータフレームの差分
    insert_df = temp_df[~temp_df.index.isin(df["date"])]

    # NaNをNoneに変換
    insert_df = insert_df.replace({np.nan: None})

    # データベースに一括保存
    try:
        session.bulk_insert_mappings(
            StockModel,
            [
                {
                    "date": index,
                    **{k.lower(): v for k, v in row.to_dict().items()},
                }
                for index, row in insert_df.iterrows()
            ],
        )
        session.commit()  # コミットしてデータを保存
    except SQLAlchemyError as e:
        session.rollback()
        print(f"Error occurred: {e}")


db_manager = DatabaseManager()
session = db_manager.get_session()

# stocksテーブル 全件取得
stock = get_stock_class(db_manager.Base)
stocks = session.query(stock).all()

for stock in stocks:
    # 株価データを追加
    append_stock_data(db_manager, stock.code)
    time.sleep(10)
