from flask import Flask, Response, render_template, g
from typing import Any, Dict, List
import pandas as pd
from sqlalchemy import desc
from sqlalchemy.orm import aliased
from sqlalchemy.exc import SQLAlchemyError
from database import DatabaseManager
from models import create_stock_model
import yaml
from util import clean_string

# Flaskアプリケーションを作成


def create_app():
    app = Flask(__name__)
    app.secret_key = "secret_key_for_flash_messages"

    @app.before_request
    def load_stock_data() -> None:
        """
        各リクエスト前に実行される関数。
        設定ファイル("config.yaml")から株価指数の設定を読み込み、各指数の直近2日間の終値データを
        データベースから取得し、Flaskのグローバルオブジェクト(g)に 'index_data' として保存する。

        - 各指数について、直近2日分のデータが存在すれば、今日の終値とその差分を計算して格納。
        - 取得できない場合は "N/A" もしくは "Error" を設定し、エラー発生時は空の辞書とする。

        Returns:
            None
        """

        db_manager = DatabaseManager()
        g.index_data = {}
        try:
            with open("config.yaml", "r", encoding="utf-8") as file:
                config = yaml.safe_load(file)
                indices = config["indices"]

                # セッションの管理を with ステートメントで実施
                with db_manager.get_session() as session:
                    for index in indices:
                        try:
                            symbol = clean_string(index["symbol"])
                            name = index["name"]
                            # 直近2日間のデータを取得
                            StockModel = create_stock_model(
                                f"stock_{symbol}", db_manager.Base
                            )
                            recent_data = (
                                session.query(StockModel)
                                .order_by(desc(StockModel.date))
                                .limit(2)
                                .all()
                            )

                            if len(recent_data) == 2:
                                today_price = recent_data[0].close
                                yesterday_price = recent_data[1].close
                                price_diff = today_price - yesterday_price
                                # 差分に応じて表示を調整
                                g.index_data[name] = {
                                    "price": f"{today_price:,.2f}",
                                    "diff": round(price_diff, 2),
                                }
                            else:
                                g.index_data[name] = {
                                    "price": "N/A",
                                    "diff": "N/A",
                                }
                        except Exception as e:
                            # インデックス単位のエラー処理
                            g.index_data[index.get("name", "Unknown")] = {
                                "price": "Error",
                                "diff": "Error",
                            }
        except Exception as e:
            g.index_data = {}


    # 全テンプレートに変数を反映
    @app.context_processor
    def load_index() -> Dict[str, Any]:
        """
        Flaskのコンテキストプロセッサ関数。
        グローバルオブジェクト(g)に格納された 'index_data' をテンプレートで使用できるように辞書形式で返却する。

        Returns:
            Dict[str, Any]: テンプレートコンテキストに渡す辞書。'index_data' キーに対して g.index_data の値を持つ。
        """

        return {"index_data": g.index_data}
    
    
    def get_index_data(code: str) -> Dict[str, List[Any]]:
        """
        指定されたコードに対応する指標データの最新12行を取得し、
        JSON形式（辞書形式）に変換して返す関数。

        Args:
            code (str): 指標データのコード（テーブル名の一部として利用）

        Returns:
            Dict[str, List[Any]]: 各カラム名をキー、各カラムのデータリストを値とする辞書。
                                  データ取得に失敗した場合は空の辞書を返す。
        """

        db_manager = DatabaseManager()

        try:
            # with ステートメントでセッションを管理
            with db_manager.get_session() as session:
                session = db_manager.get_session()
                # データ取得
                StockModel = create_stock_model(
                    f"stock_{code}", db_manager.Base)

                # サブクエリとして定義
                last_12_rows = (
                    session.query(StockModel)
                    .order_by(desc(StockModel.date))  # 降順で取得
                    .limit(12)  # 最新12行を取得
                    .subquery()  # サブクエリ化
                )

                # サブクエリにエイリアスを付ける
                aliased_subquery = aliased(StockModel, last_12_rows)

                # 再度昇順にしてORMモデルとして取得
                stock_datas = (
                    session.query(aliased_subquery)
                    .order_by(aliased_subquery.date.asc())
                    .all()
                )

                # stock_datasの各オブジェクトの属性を辞書形式に変換してリストに格納
                dicts = [vars(stock_data) for stock_data in stock_datas]
                # メタ情報を除いてデータフレームに変換
                df = pd.DataFrame(dicts).drop("_sa_instance_state", axis=1)
                # JSON形式に変換
                json_data = df.to_dict(orient="list")
                return json_data
        except SQLAlchemyError as e:
            print(f"Database error occurred: {e}")
            return pd.DataFrame().to_dict(orient="list")  # エラー時に空のJSONを返す

    # ルートの定義
    @app.route("/")
    def index() -> Response:
        """
        ルート ('/') にアクセスされた際に、株価指数のデータを取得し、
        テンプレートに渡してHTMLをレンダリングするエンドポイント

        Returns:
            Response: レンダリングされたHTMLのレスポンス
        """
        n225_chart_data = get_index_data("N225")
        dji_chart_data = get_index_data("DJI")
        return render_template(
            "index.html", n225_chart_data=n225_chart_data, dji_chart_data=dji_chart_data
        )

    return app
