from app import create_app
from typing import Any
import yfinance as yf
import pandas as pd
import talib as ta
import numpy as np
from sqlalchemy.exc import SQLAlchemyError
from datetime import datetime, timedelta
import time
from database import DatabaseManager
from models import create_stock_model, get_stock_class
from util import RCI
import sys

app = create_app()


# yfinanceから株価データを取得してテーブルに追加
def get_stock(db: Any, code: str) -> None:
    """
    指定された銘柄コードに対して株価データを取得し、各種テクニカル指標を計算後、
    データベースに保存する処理を行う関数

    Args:
        db (Any): データベース管理オブジェクト（例: DatabaseManager）
        code (str): 株価指数のシンボル
    """
    # 銘柄単位のモデルクラス
    StockModel = create_stock_model(f"stock_{code}", db.Base)
    db.init_db()

    # 開始日付
    start = "2019-01-01"

    # 今日の日付を取得
    today = datetime.today().strftime("%Y-%m-%d")
    # 翌日の日付を取得
    tomorrow = (datetime.today() + timedelta(days=1)).strftime("%Y-%m-%d")

    # yfinanceから本日分の株価データを取得
    ticker = f"{code}.T"
    ohlcv_df = yf.download(ticker, start=start, end=tomorrow)

    # マルチインデックスまたは通常のインデックスに対応
    if isinstance(ohlcv_df.columns, pd.MultiIndex):
        # マルチインデックスを平坦化し、カラム名を整理
        ohlcv_df.columns = ohlcv_df.columns.get_level_values(0)

    ohlcv_df = ohlcv_df.reset_index()
    # DateカラムをYYYY-MM-DD形式に変換
    ohlcv_df["Date"] = pd.to_datetime(ohlcv_df["Date"]).dt.date

    # データフレームのカラム名を小文字に変換
    ohlcv_df.columns = [col.lower() for col in ohlcv_df.columns]

    if len(ohlcv_df) == 0:
        return

    # 銘柄単位のモデルクラス
    StockModel = create_stock_model(f"stock_{code}", db.Base)
    db.init_db()

    # ohlcv_df のカラムを使って空のデータフレームを作成
    df = pd.DataFrame(columns=ohlcv_df.columns)

    session = db.get_session()
    stock_datas = session.query(StockModel).all()

    # データが存在する場合はデータフレームを上書き
    if stock_datas:
        # stock_datasの各オブジェクトの属性を辞書形式に変換してリストに格納
        dicts = [vars(stock_data) for stock_data in stock_datas]
        df = pd.DataFrame(dicts).drop("_sa_instance_state", axis=1)

    # データフレームの差分
    diff_df = ohlcv_df[~ohlcv_df["date"].isin(df["date"])] if not df.empty else ohlcv_df

    # 差分がない場合は更新分データがないものを判断して処理を中断
    if diff_df.empty:
        return

    # データフレームを結合 空でないデータフレームのみ結合
    frames = [df, diff_df]
    temp_df = pd.concat([frame for frame in frames if not frame.empty])

    # dateカラムをインデックスに設定
    temp_df.set_index("date", inplace=True)

    # 終値を取得して各オシレーターを算出
    close = temp_df["close"]

    # 移動平均の算出
    temp_df["ma5"] = ta.SMA(close, timeperiod=5)
    temp_df["ma25"] = ta.SMA(close, timeperiod=25)
    temp_df["ma75"] = ta.SMA(close, timeperiod=75)

    # ボリンジャーバンドの算出
    upper2, _, lower2 = ta.BBANDS(close, timeperiod=25, nbdevup=2, nbdevdn=2, matype=0)
    temp_df["upper2"], temp_df["lower2"] = upper2, lower2

    # MACDの算出
    macd, macdsignal, hist = ta.MACD(
        close, fastperiod=12, slowperiod=26, signalperiod=9
    )
    temp_df["macd"] = macd
    temp_df["macd_signal"] = macdsignal
    temp_df["hist"] = hist

    # RSIの算出
    rsi14 = ta.RSI(close, timeperiod=14)
    rsi28 = ta.RSI(close, timeperiod=28)
    temp_df["rsi14"], temp_df["rsi28"] = rsi14, rsi28

    # RCIの算出
    rci9 = RCI(close, timeperiod=9)
    rci26 = RCI(close, timeperiod=26)
    temp_df["rci9"], temp_df["rci26"] = rci9, rci26

    # NaNをNoneに変換
    temp_df = temp_df.replace({np.nan: None})

    # DBに保存するデータフレームの差分
    insert_df = temp_df[~temp_df.index.isin(df["date"])]

    # データベースに保存
    try:
        session.bulk_insert_mappings(
            StockModel,
            [
                {"date": index, **{k.lower(): v for k, v in row.to_dict().items()}}
                for index, row in insert_df.iterrows()
            ],
        )
        session.commit()  # コミットしてデータを保存
    except SQLAlchemyError as e:
        session.rollback()
        print(f"Error occurred: {e}")
    finally:
        session.close()


def save_stock() -> None:
    """
    各銘柄の株価データ取得処理を実行
    """
    db_manager = DatabaseManager()
    session = db_manager.get_session()

    # stocksテーブル 全件取得
    stock = get_stock_class(db_manager.Base)
    stocks = session.query(stock).all()
    for stock in stocks:
        get_stock(db_manager, stock.code)
        time.sleep(10)


if __name__ == "__main__":
    save_stock()
