From 627b2326df0af7ef7565616cf62bf4487b4021d8 Mon Sep 17 00:00:00 2001 From: epistemophiliac Date: Fri, 19 Jun 2026 01:29:28 -0400 Subject: [PATCH] Add Python strategy engine, parameter optimization, and faster Docker builds. Support builtin and custom generate_signals strategies with SQLite persistence, exhaustive grid scans (VectorBT comb optimization for MA crossover), professional backtest/optimize UI, and split harvester/app requirements with BuildKit pip cache. --- .dockerignore | 13 + Dockerfile | 10 +- Dockerfile.harvester | 7 +- README.md | 29 +- app.py | 471 ++++++++++++++++++++-------- backtest.py | 89 +----- engine.py | 185 +++++++++++ metrics.py | 62 ++++ requirements-app.txt | 6 + requirements-harvester.txt | 5 + requirements.txt | 14 +- strategies/__init__.py | 5 + strategies/builtin/__init__.py | 0 strategies/builtin/ma_crossover.py | 93 ++++++ strategies/builtin/rsi_reversion.py | 79 +++++ strategies/executor.py | 178 +++++++++++ strategies/registry.py | 72 +++++ strategy_db.py | 80 +++-- 18 files changed, 1139 insertions(+), 259 deletions(-) create mode 100644 .dockerignore create mode 100644 engine.py create mode 100644 metrics.py create mode 100644 requirements-app.txt create mode 100644 requirements-harvester.txt create mode 100644 strategies/__init__.py create mode 100644 strategies/builtin/__init__.py create mode 100644 strategies/builtin/ma_crossover.py create mode 100644 strategies/builtin/rsi_reversion.py create mode 100644 strategies/executor.py create mode 100644 strategies/registry.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..fca12c5 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,13 @@ +.git +.gitignore +__pycache__ +*.py[cod] +*.egg-info +.venv +venv +.env +data +*.db +.streamlit/secrets.toml +README.md +.agent-tools diff --git a/Dockerfile b/Dockerfile index aeaa7a9..0b2781c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,8 @@ +# syntax=docker/dockerfile:1 FROM python:3.11-slim-bookworm ENV PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 \ - PIP_NO_CACHE_DIR=1 \ PARQUET_DIR=/data/parquet \ STRATEGY_DB_PATH=/data/strategies/strategies.db @@ -12,10 +12,12 @@ RUN apt-get update \ && apt-get install -y --no-install-recommends bash curl \ && rm -rf /var/lib/apt/lists/* -COPY requirements.txt . -RUN pip install --no-cache-dir -r requirements.txt +COPY requirements-app.txt . +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -r requirements-app.txt -COPY telemetry.py auth.py strategy_db.py backtest.py app.py sync.py ./ +COPY telemetry.py auth.py strategy_db.py metrics.py engine.py backtest.py app.py sync.py ./ +COPY strategies ./strategies COPY .streamlit /app/.streamlit EXPOSE 8501 diff --git a/Dockerfile.harvester b/Dockerfile.harvester index 14a1726..7aced44 100644 --- a/Dockerfile.harvester +++ b/Dockerfile.harvester @@ -1,8 +1,8 @@ +# syntax=docker/dockerfile:1 FROM python:3.11-slim-bookworm ENV PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 \ - PIP_NO_CACHE_DIR=1 \ PARQUET_DIR=/data/parquet \ TZ=America/New_York @@ -12,8 +12,9 @@ RUN apt-get update \ && apt-get install -y --no-install-recommends bash cron curl \ && rm -rf /var/lib/apt/lists/* -COPY requirements.txt . -RUN pip install --no-cache-dir -r requirements.txt +COPY requirements-harvester.txt . +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -r requirements-harvester.txt COPY telemetry.py sync.py ./ COPY scripts/harvester-entrypoint.sh /usr/local/bin/harvester-entrypoint.sh diff --git a/README.md b/README.md index ef67505..0000fc8 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,34 @@ streamlit run app.py For full OIDC locally, set `OIDC_CLIENT_SECRET` and register `http://localhost:8501` as a redirect URI in Authentik. -## Manual sync +## Research workflow + +1. **Backtest** โ€” run a single parameter set and inspect equity curve, drawdown, trades. +2. **Optimize** โ€” exhaustive parameter scan (grid search) ranked by Sharpe, Sortino, return, or drawdown. +3. **Python** โ€” view builtin source or author custom strategies with `generate_signals()`. +4. **Library** โ€” save/load strategies per user (SQLite), including custom Python source code. + +### Custom Python strategy contract + +```python +PARAM_GRID = {"fast_window": list(range(10, 41, 5)), "slow_window": list(range(50, 151, 10))} +DEFAULT_PARAMS = {"fast_window": 20, "slow_window": 50} + +def generate_signals(close, high, low, volume, **params): + # return boolean entry/exit Series aligned to close + return entries, exits +``` + +Builtins: `ma_crossover` (vectorized VectorBT comb scan), `rsi_reversion` (grid scan). + +## Docker build speed + +- Harvester image installs only `requirements-harvester.txt` (no VectorBT/Streamlit). +- App image uses BuildKit pip cache (`RUN --mount=type=cache`). +- `.dockerignore` keeps git/cache out of build context. + +Enable BuildKit on Coolify/build host for cache mounts. + ```bash python sync.py --seed # full history diff --git a/app.py b/app.py index bf2b844..946b5a2 100644 --- a/app.py +++ b/app.py @@ -1,183 +1,404 @@ -"""QuantTrade Streamlit dashboard.""" +"""QuantTrade research workstation.""" from __future__ import annotations import os import pandas as pd +import plotly.express as px import plotly.graph_objects as go import streamlit as st from plotly.subplots import make_subplots from auth import get_current_user, logout -from backtest import load_ohlcv, run_ma_crossover +from engine import ( + get_strategy_source, + load_ohlcv, + run_backtest, + run_optimization, +) +from strategies.executor import CUSTOM_TEMPLATE +from strategies.registry import list_builtins from strategy_db import delete_strategy, init_db, list_strategies, load_strategy, save_strategy from telemetry import capture_exception, init_telemetry init_telemetry("quant-streamlit") init_db() -st.set_page_config( - page_title="QuantTrade", - page_icon="๐Ÿ“ˆ", - layout="wide", -) +st.set_page_config(page_title="QuantTrade", page_icon="๐Ÿ“ˆ", layout="wide") -DEFAULT_TICKERS = os.environ.get( - "CORE_TICKERS", - "SPY,QQQ,AAPL,MSFT,GOOGL,AMZN,NVDA,META,IWM,TLT", -).split(",") +DEFAULT_TICKERS = [ + t.strip().upper() + for t in os.environ.get( + "CORE_TICKERS", + "SPY,QQQ,AAPL,MSFT,GOOGL,AMZN,NVDA,META,IWM,TLT", + ).split(",") + if t.strip() +] + +METRICS = { + "sharpe_ratio": "Sharpe Ratio", + "sortino_ratio": "Sortino Ratio", + "total_return": "Total Return", + "max_drawdown": "Max Drawdown (minimize)", +} -def render_equity_chart(result) -> None: +def sidebar_account(user: str) -> None: + st.subheader("Account") + st.write(f"**{user}**") + if st.button("Logout", use_container_width=True): + logout() + st.rerun() + + +def sidebar_market() -> tuple[str, float, float]: + st.subheader("Market") + ticker = st.selectbox("Ticker", options=DEFAULT_TICKERS) + init_cash = st.number_input("Initial capital ($)", min_value=1000.0, value=10_000.0, step=1000.0) + fees = st.number_input("Fees (per trade, fraction)", min_value=0.0, max_value=0.05, value=0.001, step=0.0005) + return ticker, init_cash, fees + + +def sidebar_strategy_picker() -> str: + st.subheader("Strategy") + builtin_options = {b.key: b.display_name for b in list_builtins()} + kind = st.radio("Type", options=["Built-in", "Custom Python"], horizontal=True) + if kind == "Built-in": + return st.selectbox( + "Model", + options=list(builtin_options.keys()), + format_func=lambda k: builtin_options[k], + ) + if "custom_code" not in st.session_state: + st.session_state.custom_code = CUSTOM_TEMPLATE + return "custom" + + +def render_metrics_row(result) -> None: + c1, c2, c3, c4, c5, c6 = st.columns(6) + c1.metric("Sharpe", f"{result.sharpe_ratio:.2f}") + c2.metric("Sortino", f"{result.sortino_ratio:.2f}") + c3.metric("Return", f"{result.total_return:.1%}") + c4.metric("Max DD", f"{result.max_drawdown:.1%}") + c5.metric("Win rate", f"{result.win_rate:.1%}") + c6.metric("Trades", f"{result.total_trades:,}") + + +def render_backtest_chart(result) -> None: fig = make_subplots( - rows=2, + rows=3, cols=1, shared_xaxes=True, - vertical_spacing=0.08, - row_heights=[0.65, 0.35], - subplot_titles=(f"{result.ticker} Price", "Strategy Equity"), + vertical_spacing=0.05, + row_heights=[0.5, 0.25, 0.25], + subplot_titles=(f"{result.ticker} โ€” price & signals", "Equity curve", "Position"), ) - fig.add_trace( - go.Scatter(x=result.price.index, y=result.price.values, name="Close", line=dict(color="#60a5fa")), + go.Scatter(x=result.price.index, y=result.price, name="Close", line=dict(color="#60a5fa")), + row=1, + col=1, + ) + buys = result.entries & ~result.entries.shift(1, fill_value=False) + sells = result.exits & ~result.exits.shift(1, fill_value=False) + fig.add_trace( + go.Scatter( + x=result.price.index[buys], + y=result.price[buys], + mode="markers", + name="Entry", + marker=dict(color="#34d399", size=8, symbol="triangle-up"), + ), row=1, col=1, ) fig.add_trace( go.Scatter( - x=result.equity_curve.index, - y=result.equity_curve.values, - name="Equity", - line=dict(color="#34d399"), + x=result.price.index[sells], + y=result.price[sells], + mode="markers", + name="Exit", + marker=dict(color="#f87171", size=8, symbol="triangle-down"), ), + row=1, + col=1, + ) + fig.add_trace( + go.Scatter(x=result.equity_curve.index, y=result.equity_curve, name="Equity", line=dict(color="#a78bfa")), row=2, col=1, ) - fig.update_layout(height=640, template="plotly_dark", margin=dict(l=20, r=20, t=40, b=20)) + position = result.entries.astype(int).replace(0, -1).cumsum().clip(lower=0) + fig.add_trace( + go.Scatter(x=position.index, y=position, name="In market", fill="tozeroy", line=dict(color="#22d3ee")), + row=3, + col=1, + ) + fig.update_layout(height=760, template="plotly_dark", margin=dict(l=12, r=12, t=40, b=12), showlegend=False) st.plotly_chart(fig, use_container_width=True) +def render_heatmap(results: pd.DataFrame, x: str, y: str, metric: str) -> None: + if x not in results.columns or y not in results.columns: + return + pivot = results.pivot_table(index=y, columns=x, values="score", aggfunc="mean") + fig = px.imshow( + pivot, + labels=dict(x=x, y=y, color="Score"), + color_continuous_scale="Viridis", + aspect="auto", + title=f"Parameter surface โ€” {METRICS.get(metric, metric)}", + ) + fig.update_layout(template="plotly_dark", height=420) + st.plotly_chart(fig, use_container_width=True) + + +def tab_backtest(user: str, ticker: str, init_cash: float, fees: float, strategy_key: str, source_code: str) -> None: + st.markdown("### Single run") + st.caption("Validate one parameter set before running a full scan.") + + params: dict = {"init_cash": init_cash, "fees": fees} + if strategy_key != "custom": + builtin = next(b for b in list_builtins() if b.key == strategy_key) + st.info(builtin.description) + for key, default in builtin.default_params.items(): + if isinstance(default, int): + params[key] = st.number_input(key, value=int(default), step=1) + else: + params[key] = st.number_input(key, value=float(default)) + else: + params.update(st.session_state.get("custom_defaults", {})) + + if st.button("Run backtest", type="primary"): + try: + load_ohlcv(ticker) + result = run_backtest( + ticker=ticker, + strategy_key=strategy_key, + params=params, + source_code=source_code if strategy_key == "custom" else None, + ) + st.session_state["last_backtest"] = result + except Exception as exc: + capture_exception(exc) + st.error(str(exc)) + + result = st.session_state.get("last_backtest") + if result and result.ticker == ticker.upper(): + render_metrics_row(result) + st.json(result.params) + render_backtest_chart(result) + + +def tab_optimize(user: str, ticker: str, init_cash: float, fees: float, strategy_key: str, source_code: str) -> None: + st.markdown("### Parameter scan") + st.caption("Exhaustively test parameter combinations and rank by objective.") + + metric = st.selectbox("Objective", options=list(METRICS.keys()), format_func=lambda m: METRICS[m]) + + if strategy_key != "custom": + builtin = next(b for b in list_builtins() if b.key == strategy_key) + st.write("**Search space**") + st.json(builtin.param_grid) + combo_hint = len(builtin.param_grid.get("window_pool", [])) + if strategy_key == "ma_crossover": + n = len(builtin.param_grid["window_pool"]) + combo_hint = n * (n - 1) // 2 + elif strategy_key == "rsi_reversion": + import itertools + + combo_hint = sum( + 1 + for p, os, ob in itertools.product( + builtin.param_grid["rsi_period"], + builtin.param_grid["oversold"], + builtin.param_grid["overbought"], + ) + if os < ob + ) + else: + import itertools + + combo_hint = sum(1 for _ in itertools.product(*builtin.param_grid.values())) + st.write(f"~**{combo_hint:,}** combinations") + else: + st.write("Uses `PARAM_GRID` defined in your Python strategy.") + + if st.button("Run optimization", type="primary"): + with st.spinner("Scanning parameter spaceโ€ฆ"): + try: + load_ohlcv(ticker) + opt = run_optimization( + ticker=ticker, + strategy_key=strategy_key, + metric=metric, + init_cash=init_cash, + fees=fees, + source_code=source_code if strategy_key == "custom" else None, + ) + st.session_state["last_optimization"] = opt + except Exception as exc: + capture_exception(exc) + st.error(str(exc)) + + opt = st.session_state.get("last_optimization") + if opt and opt.ticker == ticker.upper() and opt.strategy_key == strategy_key: + st.success( + f"Tested **{opt.combinations_tested:,}** combinations ยท " + f"Best {METRICS[opt.metric]}: **{opt.best_score:.3f}**" + ) + st.write("**Optimal parameters**") + st.json(opt.best_params) + + top = opt.results.head(25) + st.dataframe( + top.style.format( + { + "score": "{:.3f}", + "sharpe_ratio": "{:.2f}", + "sortino_ratio": "{:.2f}", + "total_return": "{:.1%}", + "max_drawdown": "{:.1%}", + "win_rate": "{:.1%}", + }, + na_rep="โ€”", + ), + use_container_width=True, + height=360, + ) + + if strategy_key == "ma_crossover": + render_heatmap(opt.results, "fast_window", "slow_window", opt.metric) + elif strategy_key == "rsi_reversion": + render_heatmap(opt.results, "rsi_period", "oversold", opt.metric) + + if st.button("Apply best params to backtest"): + st.session_state["apply_best_params"] = opt.best_params + st.toast("Best parameters saved โ€” switch to Backtest tab.") + + +def tab_editor(strategy_key: str, source_code: str) -> str: + st.markdown("### Strategy code") + st.caption( + "Write Python that defines `generate_signals(close, high, low, volume, **params)` " + "returning `(entries, exits)` booleans. Optional: `PARAM_GRID` and `DEFAULT_PARAMS`." + ) + + if strategy_key == "custom": + code = st.text_area("Python strategy", value=source_code, height=420, label_visibility="collapsed") + st.session_state["custom_code"] = code + return code + + st.code(get_strategy_source(strategy_key), language="python") + return "" + + +def tab_library(user: str, ticker: str, strategy_key: str, source_code: str, params: dict) -> None: + st.markdown("### Saved strategies") + saved = list_strategies(user) + names = [s.name for s in saved] + pick = st.selectbox("Load saved", ["โ€”"] + names) + + name = st.text_input("Save as", placeholder="SPY MA sweep v1") + c1, c2 = st.columns(2) + with c1: + if st.button("Save", use_container_width=True): + if not name.strip(): + st.error("Name required.") + else: + save_strategy( + user, + name.strip(), + ticker, + strategy_key, + params, + source_code if strategy_key == "custom" else None, + ) + st.success(f"Saved **{name.strip()}**") + st.rerun() + with c2: + if st.button("Delete", use_container_width=True) and pick != "โ€”": + delete_strategy(user, pick) + st.success(f"Deleted **{pick}**") + st.rerun() + + if pick != "โ€”": + loaded = load_strategy(user, pick) + if loaded and st.button("Apply loaded strategy", type="primary"): + st.session_state.active_strategy_key = loaded.strategy_key + st.session_state.active_ticker = loaded.ticker + st.session_state.active_params = loaded.params + if loaded.source_code: + st.session_state.custom_code = loaded.source_code + st.session_state.apply_best_params = { + k: v for k, v in loaded.params.items() if k not in ("init_cash", "fees") + } + st.rerun() + + if saved: + st.dataframe( + pd.DataFrame( + [ + { + "name": s.name, + "ticker": s.ticker, + "strategy": s.strategy_key, + "updated": s.created_at[:19], + } + for s in saved + ] + ), + use_container_width=True, + hide_index=True, + ) + + def main() -> None: user = get_current_user() if not user: return st.title("QuantTrade") - st.caption("VectorBT backtests on local Parquet market data") + st.caption("Research desk ยท Python strategies ยท VectorBT parameter scans ยท Parquet data") with st.sidebar: - st.subheader("Account") - st.write(f"Signed in as **{user}**") - if st.button("Logout", use_container_width=True): - logout() - st.rerun() - + sidebar_account(user) st.divider() - st.subheader("Strategy") - ticker = st.selectbox( - "Ticker", - options=[t.strip().upper() for t in DEFAULT_TICKERS if t.strip()], - index=0, - ) - fast_window = st.slider("Fast MA", min_value=5, max_value=100, value=20, step=1) - slow_window = st.slider("Slow MA", min_value=20, max_value=250, value=50, step=1) - init_cash = st.number_input("Initial cash", min_value=1000.0, value=10_000.0, step=1000.0) - fees = st.number_input("Fees (fraction)", min_value=0.0, max_value=0.05, value=0.001, step=0.0005) - - run_clicked = st.button("Run Backtest", type="primary", use_container_width=True) - + ticker, init_cash, fees = sidebar_market() + if "active_ticker" in st.session_state: + ticker = st.session_state.active_ticker st.divider() - st.subheader("Saved Strategies") - saved = list_strategies(user) - saved_names = [s.name for s in saved] - selected_name = st.selectbox("Load strategy", options=["โ€”"] + saved_names) + strategy_key = sidebar_strategy_picker() + if "active_strategy_key" in st.session_state: + strategy_key = st.session_state.active_strategy_key - strategy_name = st.text_input("Strategy name", placeholder="My SPY crossover") - col_save, col_delete = st.columns(2) - with col_save: - save_clicked = st.button("Save Strategy", use_container_width=True) - with col_delete: - delete_clicked = st.button("Delete", use_container_width=True) + source_code = st.session_state.get("custom_code", CUSTOM_TEMPLATE) + params: dict = {"init_cash": init_cash, "fees": fees} + if strategy_key != "custom": + builtin = next(b for b in list_builtins() if b.key == strategy_key) + params.update(builtin.default_params) - params = { - "fast_window": fast_window, - "slow_window": slow_window, - "init_cash": init_cash, - "fees": fees, - } + if "active_params" in st.session_state: + params.update(st.session_state.active_params) - if save_clicked: - if not strategy_name.strip(): - st.sidebar.error("Enter a strategy name before saving.") - else: - save_strategy(user, strategy_name.strip(), ticker, params) - st.sidebar.success(f"Saved '{strategy_name.strip()}'.") - st.rerun() + best = st.session_state.pop("apply_best_params", None) + if best: + params.update(best) - if delete_clicked and selected_name != "โ€”": - delete_strategy(user, selected_name) - st.sidebar.success(f"Deleted '{selected_name}'.") - st.rerun() + tab_bt, tab_opt, tab_code, tab_save = st.tabs(["Backtest", "Optimize", "Python", "Library"]) - active_ticker = ticker - active_params = dict(params) + with tab_code: + source_code = tab_editor(strategy_key, source_code) + st.session_state.custom_code = source_code - if selected_name != "โ€”": - loaded = load_strategy(user, selected_name) - if loaded: - active_ticker = loaded.ticker - active_params.update(loaded.params) - st.info(f"Loaded strategy **{loaded.name}** ({loaded.ticker}). Adjust sliders or run.") + with tab_save: + tab_library(user, ticker, strategy_key, source_code, params) - if run_clicked or selected_name != "โ€”": - try: - load_ohlcv(active_ticker) - result = run_ma_crossover( - ticker=active_ticker, - fast_window=int(active_params["fast_window"]), - slow_window=int(active_params["slow_window"]), - init_cash=float(active_params.get("init_cash", init_cash)), - fees=float(active_params.get("fees", fees)), - ) + with tab_bt: + tab_backtest(user, ticker, init_cash, fees, strategy_key, source_code) - c1, c2, c3, c4 = st.columns(4) - c1.metric("Sharpe Ratio", f"{result.sharpe_ratio:.2f}") - c2.metric("Max Drawdown", f"{result.max_drawdown:.1%}") - c3.metric("Total Return", f"{result.total_return:.1%}") - c4.metric("Bars", f"{len(result.price):,}") - - render_equity_chart(result) - - with st.expander("Raw stats"): - st.write( - pd.DataFrame( - { - "Metric": ["Ticker", "Fast MA", "Slow MA", "Sharpe", "Max DD", "Return"], - "Value": [ - result.ticker, - result.fast_window, - result.slow_window, - result.sharpe_ratio, - result.max_drawdown, - result.total_return, - ], - } - ) - ) - except FileNotFoundError: - st.warning( - f"No Parquet data for **{active_ticker}** yet. " - "Wait for the harvester seed job or check container logs." - ) - except ValueError as exc: - st.error(str(exc)) - except Exception as exc: - capture_exception(exc) - st.error("Backtest failed. The error was reported to Bugsink.") - st.exception(exc) - else: - st.info("Configure parameters in the sidebar and click **Run Backtest**.") + with tab_opt: + tab_optimize(user, ticker, init_cash, fees, strategy_key, source_code) if __name__ == "__main__": diff --git a/backtest.py b/backtest.py index ecd0f60..230da4b 100644 --- a/backtest.py +++ b/backtest.py @@ -1,88 +1,5 @@ -"""VectorBT backtest engine reading local Parquet OHLCV data.""" +"""Backward-compatible exports.""" -from __future__ import annotations +from engine import BacktestResult, load_ohlcv, run_backtest -import os -from dataclasses import dataclass -from pathlib import Path - -import pandas as pd -import vectorbt as vbt - - -@dataclass(frozen=True) -class BacktestResult: - ticker: str - fast_window: int - slow_window: int - sharpe_ratio: float - max_drawdown: float - total_return: float - equity_curve: pd.Series - price: pd.Series - - -def parquet_dir() -> Path: - return Path(os.environ.get("PARQUET_DIR", "/data/parquet")) - - -def load_ohlcv(ticker: str) -> pd.DataFrame: - path = parquet_dir() / f"{ticker.upper()}.parquet" - if not path.exists(): - raise FileNotFoundError(f"No Parquet file for {ticker.upper()} at {path}") - - df = pd.read_parquet(path) - if "Date" in df.columns: - df = df.set_index("Date") - df.index = pd.to_datetime(df.index) - df = df.sort_index() - return df - - -def run_ma_crossover( - ticker: str, - fast_window: int, - slow_window: int, - init_cash: float = 10_000.0, - fees: float = 0.001, -) -> BacktestResult: - if fast_window >= slow_window: - raise ValueError("Fast MA window must be smaller than slow MA window") - - ohlcv = load_ohlcv(ticker) - close = ohlcv["Close"].astype(float) - - fast_ma = vbt.MA.run(close, fast_window, short_name="fast") - slow_ma = vbt.MA.run(close, slow_window, short_name="slow") - - entries = fast_ma.ma_crossed_above(slow_ma) - exits = fast_ma.ma_crossed_below(slow_ma) - - portfolio = vbt.Portfolio.from_signals( - close, - entries=entries, - exits=exits, - init_cash=init_cash, - fees=fees, - freq="1D", - ) - - stats = portfolio.stats() - sharpe = float(stats.get("Sharpe Ratio", 0.0) or 0.0) - max_dd = float(stats.get("Max Drawdown [%]", 0.0) or 0.0) / 100.0 - total_return = float(stats.get("Total Return [%]", 0.0) or 0.0) / 100.0 - - equity = portfolio.value() - if isinstance(equity, pd.DataFrame): - equity = equity.iloc[:, 0] - - return BacktestResult( - ticker=ticker.upper(), - fast_window=fast_window, - slow_window=slow_window, - sharpe_ratio=sharpe, - max_drawdown=max_dd, - total_return=total_return, - equity_curve=equity, - price=close, - ) +__all__ = ["BacktestResult", "load_ohlcv", "run_backtest"] diff --git a/engine.py b/engine.py new file mode 100644 index 0000000..62fd38f --- /dev/null +++ b/engine.py @@ -0,0 +1,185 @@ +"""Backtest and optimization engine.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import pandas as pd +import vectorbt as vbt + +from strategies.executor import ( + CUSTOM_TEMPLATE, + StrategyError, + load_custom_strategy, + optimize_custom, + run_builtin_signals, + run_custom_signals, +) +from strategies.registry import BuiltinStrategy, get_builtin + + +@dataclass(frozen=True) +class BacktestResult: + ticker: str + strategy_key: str + params: dict[str, Any] + sharpe_ratio: float + sortino_ratio: float + max_drawdown: float + total_return: float + win_rate: float + total_trades: int + equity_curve: pd.Series + price: pd.Series + entries: pd.Series + exits: pd.Series + + +@dataclass(frozen=True) +class OptimizationResult: + ticker: str + strategy_key: str + metric: str + best_params: dict[str, Any] + best_score: float + results: pd.DataFrame + combinations_tested: int + + +def parquet_dir() -> Path: + return Path(os.environ.get("PARQUET_DIR", "/data/parquet")) + + +def load_ohlcv(ticker: str) -> pd.DataFrame: + path = parquet_dir() / f"{ticker.upper()}.parquet" + if not path.exists(): + raise FileNotFoundError(f"No Parquet file for {ticker.upper()} at {path}") + + df = pd.read_parquet(path) + if "Date" in df.columns: + df = df.set_index("Date") + df.index = pd.to_datetime(df.index) + return df.sort_index() + + +from metrics import run_from_signals, safe_float as _safe_float +def _portfolio_from_signals( + close: pd.Series, + entries: pd.Series, + exits: pd.Series, + init_cash: float, + fees: float, +) -> vbt.Portfolio: + return vbt.Portfolio.from_signals( + close, + entries=entries, + exits=exits, + init_cash=init_cash, + fees=fees, + freq="1D", + ) + + +def run_backtest( + ticker: str, + strategy_key: str, + params: dict[str, Any], + source_code: str | None = None, + init_cash: float = 10_000.0, + fees: float = 0.001, +) -> BacktestResult: + ohlcv = load_ohlcv(ticker) + close = ohlcv["Close"].astype(float) + + runtime_params = {k: v for k, v in params.items() if k not in ("init_cash", "fees")} + init_cash = float(params.get("init_cash", init_cash)) + fees = float(params.get("fees", fees)) + + if strategy_key == "custom": + if not source_code: + raise StrategyError("Custom strategy requires source_code") + entries, exits, _, merged = run_custom_signals(source_code, ohlcv, runtime_params) + params = merged + else: + builtin = get_builtin(strategy_key) + entries, exits = run_builtin_signals(builtin, ohlcv, runtime_params) + params = {**builtin.default_params, **runtime_params} + + portfolio = _portfolio_from_signals(close, entries, exits, init_cash, fees) + stats = portfolio.stats() + equity = portfolio.value() + if isinstance(equity, pd.DataFrame): + equity = equity.iloc[:, 0] + + return BacktestResult( + ticker=ticker.upper(), + strategy_key=strategy_key, + params=params, + sharpe_ratio=_safe_float(stats.get("Sharpe Ratio")), + sortino_ratio=_safe_float(stats.get("Sortino Ratio")), + max_drawdown=_safe_float(stats.get("Max Drawdown [%]")) / 100.0, + total_return=_safe_float(stats.get("Total Return [%]")) / 100.0, + win_rate=_safe_float(stats.get("Win Rate [%]")) / 100.0, + total_trades=int(stats.get("Total Trades", 0) or 0), + equity_curve=equity, + price=close, + entries=entries, + exits=exits, + ) + + +def run_optimization( + ticker: str, + strategy_key: str, + metric: str = "sharpe_ratio", + init_cash: float = 10_000.0, + fees: float = 0.001, + source_code: str | None = None, + param_grid: dict | None = None, +) -> OptimizationResult: + ohlcv = load_ohlcv(ticker) + close = ohlcv["Close"].astype(float) + + if strategy_key == "custom": + if not source_code: + raise StrategyError("Custom strategy requires source_code") + results = optimize_custom( + source_code, + close, + ohlcv, + init_cash=init_cash, + fees=fees, + metric=metric, + param_grid=param_grid, + ) + else: + builtin = get_builtin(strategy_key) + results = builtin.optimize(close, init_cash, fees, metric, grid_override=param_grid) + + if results.empty: + raise StrategyError("Optimization produced no valid parameter combinations.") + + best = results.iloc[0] + param_cols = [c for c in results.columns if c not in { + "score", "sharpe_ratio", "sortino_ratio", "max_drawdown", "total_return", "win_rate", "total_trades", + }] + best_params = {col: best[col] for col in param_cols} + + return OptimizationResult( + ticker=ticker.upper(), + strategy_key=strategy_key, + metric=metric, + best_params=best_params, + best_score=float(best["score"]), + results=results, + combinations_tested=len(results), + ) + + +def get_strategy_source(strategy_key: str, source_code: str | None = None) -> str: + if strategy_key == "custom": + return source_code or CUSTOM_TEMPLATE + return get_builtin(strategy_key).source_code diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000..5f17284 --- /dev/null +++ b/metrics.py @@ -0,0 +1,62 @@ +"""Portfolio metric helpers shared by engine and optimizers.""" + +from __future__ import annotations + +from typing import Any + +import pandas as pd +import vectorbt as vbt + + +def safe_float(value: Any) -> float: + try: + if value is None or (isinstance(value, float) and value != value): + return 0.0 + return float(value) + except (TypeError, ValueError): + return 0.0 + + +def run_from_signals( + close: pd.Series, + entries: pd.Series, + exits: pd.Series, + init_cash: float, + fees: float, + params: dict[str, Any], + metric: str = "sharpe_ratio", +) -> dict[str, Any]: + portfolio = vbt.Portfolio.from_signals( + close, + entries=entries, + exits=exits, + init_cash=init_cash, + fees=fees, + freq="1D", + ) + stats = portfolio.stats() + + sharpe = safe_float(stats.get("Sharpe Ratio")) + sortino = safe_float(stats.get("Sortino Ratio")) + max_dd = safe_float(stats.get("Max Drawdown [%]")) / 100.0 + total_return = safe_float(stats.get("Total Return [%]")) / 100.0 + win_rate = safe_float(stats.get("Win Rate [%]")) / 100.0 + total_trades = int(stats.get("Total Trades", 0) or 0) + + score_map = { + "sharpe_ratio": sharpe, + "sortino_ratio": sortino, + "total_return": total_return, + "max_drawdown": -max_dd, + } + + return { + **params, + "sharpe_ratio": sharpe, + "sortino_ratio": sortino, + "max_drawdown": max_dd, + "total_return": total_return, + "win_rate": win_rate, + "total_trades": total_trades, + "score": score_map.get(metric, sharpe), + } diff --git a/requirements-app.txt b/requirements-app.txt new file mode 100644 index 0000000..a56805c --- /dev/null +++ b/requirements-app.txt @@ -0,0 +1,6 @@ +-r requirements-harvester.txt +streamlit>=1.32.0 +vectorbt>=0.26.0,<1.0.0 +plotly>=5.18.0 +authlib>=1.3.0 +requests>=2.31.0 diff --git a/requirements-harvester.txt b/requirements-harvester.txt new file mode 100644 index 0000000..10adb5a --- /dev/null +++ b/requirements-harvester.txt @@ -0,0 +1,5 @@ +yfinance>=0.2.36 +pandas>=2.1.0 +numpy>=1.26.0 +pyarrow>=15.0.0 +sentry-sdk>=2.0.0 diff --git a/requirements.txt b/requirements.txt index 970c39a..21606ad 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,4 @@ -streamlit>=1.32.0 -vectorbt>=0.26.0 -yfinance>=0.2.36 -pandas>=2.1.0 -numpy>=1.26.0 -pyarrow>=15.0.0 -plotly>=5.18.0 -sentry-sdk>=2.0.0 -authlib>=1.3.0 -requests>=2.31.0 +# App image: pip install -r requirements-app.txt +# Harvester image: pip install -r requirements-harvester.txt (no VectorBT/Streamlit) + +-r requirements-app.txt diff --git a/strategies/__init__.py b/strategies/__init__.py new file mode 100644 index 0000000..57d46f6 --- /dev/null +++ b/strategies/__init__.py @@ -0,0 +1,5 @@ +"""Strategy registry and execution.""" + +from strategies.registry import BUILTIN_STRATEGIES, get_builtin, list_builtins + +__all__ = ["BUILTIN_STRATEGIES", "get_builtin", "list_builtins"] diff --git a/strategies/builtin/__init__.py b/strategies/builtin/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/strategies/builtin/ma_crossover.py b/strategies/builtin/ma_crossover.py new file mode 100644 index 0000000..41c8a27 --- /dev/null +++ b/strategies/builtin/ma_crossover.py @@ -0,0 +1,93 @@ +"""Moving-average crossover โ€” predefined Python strategy.""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import vectorbt as vbt + +STRATEGY_KEY = "ma_crossover" +DISPLAY_NAME = "MA Crossover" +DESCRIPTION = "Enter when fast MA crosses above slow MA; exit on cross below." + +PARAM_GRID = { + "window_pool": list(range(5, 101, 5)), +} + +DEFAULT_PARAMS = { + "fast_window": 20, + "slow_window": 50, +} + + +def generate_signals( + close: pd.Series, + high: pd.Series, + low: pd.Series, + volume: pd.Series, + fast_window: int = 20, + slow_window: int = 50, + **_kwargs, +) -> tuple[pd.Series, pd.Series]: + if fast_window >= slow_window: + raise ValueError("fast_window must be smaller than slow_window") + + fast_ma = vbt.MA.run(close, fast_window).ma + slow_ma = vbt.MA.run(close, slow_window).ma + entries = fast_ma.vbt.crossed_above(slow_ma).fillna(False) + exits = fast_ma.vbt.crossed_below(slow_ma).fillna(False) + return entries, exits + + +def optimize_vectorized( + close: pd.Series, + window_pool: list[int] | None = None, + init_cash: float = 10_000.0, + fees: float = 0.001, + metric: str = "sharpe_ratio", +) -> pd.DataFrame: + """VectorBT combinatorial scan across all fast/slow pairs (fast < slow).""" + pool = np.array(window_pool or PARAM_GRID["window_pool"], dtype=int) + fast_ma, slow_ma = vbt.MA.run_combs(close, pool, r=2, short_names=["fast", "slow"]) + entries = fast_ma.ma_crossed_above(slow_ma) + exits = fast_ma.ma_crossed_below(slow_ma) + + portfolio = vbt.Portfolio.from_signals( + close, + entries=entries, + exits=exits, + init_cash=init_cash, + fees=fees, + freq="1D", + ) + + metric_fn = { + "sharpe_ratio": portfolio.sharpe_ratio, + "sortino_ratio": portfolio.sortino_ratio, + "total_return": portfolio.total_return, + "max_drawdown": lambda **_: portfolio.max_drawdown(group_by=False), + }.get(metric, portfolio.sharpe_ratio) + + scores = metric_fn(group_by=False) + max_dd = portfolio.max_drawdown(group_by=False) + total_ret = portfolio.total_return(group_by=False) + sortino = portfolio.sortino_ratio(group_by=False) + trades = portfolio.trades.count(group_by=False) + + rows = [] + for idx, score in scores.items(): + fast_w, slow_w = int(idx[0]), int(idx[1]) + rows.append( + { + "fast_window": fast_w, + "slow_window": slow_w, + "score": float(score) if score == score else float("nan"), + "sharpe_ratio": float(scores[idx]) if scores[idx] == scores[idx] else float("nan"), + "sortino_ratio": float(sortino[idx]) if sortino[idx] == sortino[idx] else float("nan"), + "total_return": float(total_ret[idx]), + "max_drawdown": float(max_dd[idx]), + "total_trades": int(trades[idx]), + } + ) + + return pd.DataFrame(rows).sort_values("score", ascending=False, na_position="last") diff --git a/strategies/builtin/rsi_reversion.py b/strategies/builtin/rsi_reversion.py new file mode 100644 index 0000000..2677fca --- /dev/null +++ b/strategies/builtin/rsi_reversion.py @@ -0,0 +1,79 @@ +"""RSI mean-reversion โ€” predefined Python strategy.""" + +from __future__ import annotations + +import itertools + +import pandas as pd +import vectorbt as vbt + +STRATEGY_KEY = "rsi_reversion" +DISPLAY_NAME = "RSI Mean Reversion" +DESCRIPTION = "Buy when RSI is oversold; sell when RSI is overbought." + +PARAM_GRID = { + "rsi_period": list(range(7, 22, 2)), + "oversold": list(range(20, 36, 5)), + "overbought": list(range(65, 81, 5)), +} + +DEFAULT_PARAMS = { + "rsi_period": 14, + "oversold": 30, + "overbought": 70, +} + + +def generate_signals( + close: pd.Series, + high: pd.Series, + low: pd.Series, + volume: pd.Series, + rsi_period: int = 14, + oversold: float = 30, + overbought: float = 70, + **_kwargs, +) -> tuple[pd.Series, pd.Series]: + if oversold >= overbought: + raise ValueError("oversold must be less than overbought") + + rsi = vbt.RSI.run(close, window=rsi_period).rsi + entries = (rsi < oversold).fillna(False) + exits = (rsi > overbought).fillna(False) + return entries, exits + + +def optimize_grid( + close: pd.Series, + param_grid: dict | None = None, + init_cash: float = 10_000.0, + fees: float = 0.001, + metric: str = "sharpe_ratio", +) -> pd.DataFrame: + """Exhaustive grid over RSI parameter space.""" + from metrics import run_from_signals + + grid = param_grid or PARAM_GRID + keys = list(grid.keys()) + rows = [] + + for values in itertools.product(*(grid[k] for k in keys)): + params = dict(zip(keys, values)) + if params["oversold"] >= params["overbought"]: + continue + entries, exits = generate_signals(close, close, close, close, **params) + result = run_from_signals( + close=close, + entries=entries, + exits=exits, + init_cash=init_cash, + fees=fees, + params=params, + metric=metric, + ) + rows.append(result) + + frame = pd.DataFrame(rows) + if frame.empty: + return frame + return frame.sort_values("score", ascending=False, na_position="last") diff --git a/strategies/executor.py b/strategies/executor.py new file mode 100644 index 0000000..ff9fbf3 --- /dev/null +++ b/strategies/executor.py @@ -0,0 +1,178 @@ +"""Execute builtin and user-authored Python strategies.""" + +from __future__ import annotations + +import itertools +from typing import Any + +import numpy as np +import pandas as pd +import vectorbt as vbt + +from strategies.registry import BuiltinStrategy, get_builtin + +SAFE_GLOBALS: dict[str, Any] = { + "__builtins__": { + "range": range, + "len": len, + "min": min, + "max": max, + "abs": abs, + "float": float, + "int": int, + "bool": bool, + "list": list, + "dict": dict, + "tuple": tuple, + "zip": zip, + "enumerate": enumerate, + "sum": sum, + "round": round, + }, + "np": np, + "pd": pd, + "vbt": vbt, +} + +CUSTOM_TEMPLATE = '''import pandas as pd +import vectorbt as vbt + +# Optional: define PARAM_GRID for optimization scans +PARAM_GRID = { + "fast_window": list(range(10, 41, 5)), + "slow_window": list(range(50, 151, 10)), +} + +DEFAULT_PARAMS = { + "fast_window": 20, + "slow_window": 50, +} + + +def generate_signals(close, high, low, volume, **params): + """Return (entries, exits) as boolean Series aligned to close.""" + fast_w = int(params.get("fast_window", 20)) + slow_w = int(params.get("slow_window", 50)) + if fast_w >= slow_w: + raise ValueError("fast_window must be < slow_window") + + fast_ma = vbt.MA.run(close, fast_w, short_name="fast") + slow_ma = vbt.MA.run(close, slow_w, short_name="slow") + entries = fast_ma.ma_crossed_above(slow_ma).fillna(False) + exits = fast_ma.ma_crossed_below(slow_ma).fillna(False) + return entries, exits +''' + + +class StrategyError(ValueError): + pass + + +def load_custom_strategy(source_code: str) -> tuple[Any, dict, dict]: + namespace: dict[str, Any] = {} + try: + exec(source_code, SAFE_GLOBALS, namespace) + except Exception as exc: + raise StrategyError(f"Strategy compile error: {exc}") from exc + + generate = namespace.get("generate_signals") + if not callable(generate): + raise StrategyError("Custom strategy must define generate_signals(close, high, low, volume, **params)") + + param_grid = namespace.get("PARAM_GRID", {}) + default_params = namespace.get("DEFAULT_PARAMS", {}) + if not isinstance(param_grid, dict): + raise StrategyError("PARAM_GRID must be a dict of param -> list of values") + if not isinstance(default_params, dict): + raise StrategyError("DEFAULT_PARAMS must be a dict") + + return generate, param_grid, default_params + + +def run_builtin_signals( + builtin: BuiltinStrategy, + ohlcv: pd.DataFrame, + params: dict[str, Any], +) -> tuple[pd.Series, pd.Series]: + close = ohlcv["Close"].astype(float) + high = ohlcv.get("High", close).astype(float) + low = ohlcv.get("Low", close).astype(float) + volume = ohlcv.get("Volume", pd.Series(0, index=close.index)).astype(float) + entries, exits = builtin.generate_signals(close, high, low, volume, **params) + return _coerce_signals(entries, exits, close.index) + + +def run_custom_signals( + source_code: str, + ohlcv: pd.DataFrame, + params: dict[str, Any], +) -> tuple[pd.Series, pd.Series, dict, dict]: + generate, param_grid, defaults = load_custom_strategy(source_code) + merged = {**defaults, **params} + close = ohlcv["Close"].astype(float) + high = ohlcv.get("High", close).astype(float) + low = ohlcv.get("Low", close).astype(float) + volume = ohlcv.get("Volume", pd.Series(0, index=close.index)).astype(float) + entries, exits = generate(close, high, low, volume, **merged) + return _coerce_signals(entries, exits, close.index), param_grid, merged + + +def _coerce_signals(entries, exits, index: pd.Index) -> tuple[pd.Series, pd.Series]: + e = pd.Series(entries, index=index).fillna(False).astype(bool) + x = pd.Series(exits, index=index).fillna(False).astype(bool) + return e, x + + +def optimize_custom( + source_code: str, + close: pd.Series, + ohlcv: pd.DataFrame, + init_cash: float, + fees: float, + metric: str, + param_grid: dict | None = None, + max_combos: int = 2_500, +) -> pd.DataFrame: + from metrics import run_from_signals + + generate, grid, defaults = load_custom_strategy(source_code) + grid = param_grid or grid + if not grid: + raise StrategyError("Define PARAM_GRID in your strategy to run optimization.") + + keys = list(grid.keys()) + combos = list(itertools.product(*(grid[k] for k in keys))) + if len(combos) > max_combos: + raise StrategyError( + f"Grid has {len(combos):,} combinations (max {max_combos:,}). " + "Widen step sizes or narrow ranges in PARAM_GRID." + ) + + high = ohlcv.get("High", close).astype(float) + low = ohlcv.get("Low", close).astype(float) + volume = ohlcv.get("Volume", pd.Series(0, index=close.index)).astype(float) + + rows = [] + for values in combos: + params = {**defaults, **dict(zip(keys, values))} + try: + entries, exits = generate(close, high, low, volume, **params) + entries, exits = _coerce_signals(entries, exits, close.index) + rows.append( + run_from_signals( + close=close, + entries=entries, + exits=exits, + init_cash=init_cash, + fees=fees, + params=params, + metric=metric, + ) + ) + except Exception: + continue + + frame = pd.DataFrame(rows) + if frame.empty: + return frame + return frame.sort_values("score", ascending=False, na_position="last") diff --git a/strategies/registry.py b/strategies/registry.py new file mode 100644 index 0000000..62afe03 --- /dev/null +++ b/strategies/registry.py @@ -0,0 +1,72 @@ +"""Builtin strategy registry.""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +from strategies.builtin import ma_crossover, rsi_reversion + +SignalFn = Callable[..., tuple[Any, Any]] + + +@dataclass(frozen=True) +class BuiltinStrategy: + key: str + display_name: str + description: str + module: Any + generate_signals: SignalFn + default_params: dict[str, Any] + param_grid: dict[str, list[Any]] + source_code: str + + def optimize(self, close, init_cash: float, fees: float, metric: str, grid_override: dict | None = None): + if self.key == "ma_crossover": + pool = (grid_override or {}).get("window_pool", self.param_grid.get("window_pool")) + return self.module.optimize_vectorized( + close, + window_pool=pool, + init_cash=init_cash, + fees=fees, + metric=metric, + ) + if self.key == "rsi_reversion": + return self.module.optimize_grid( + close, + param_grid=grid_override or self.param_grid, + init_cash=init_cash, + fees=fees, + metric=metric, + ) + raise NotImplementedError(f"No optimizer for {self.key}") + + +def _register(module) -> BuiltinStrategy: + return BuiltinStrategy( + key=module.STRATEGY_KEY, + display_name=module.DISPLAY_NAME, + description=module.DESCRIPTION, + module=module, + generate_signals=module.generate_signals, + default_params=dict(module.DEFAULT_PARAMS), + param_grid=dict(module.PARAM_GRID), + source_code=inspect.getsource(module), + ) + + +BUILTIN_STRATEGIES: dict[str, BuiltinStrategy] = { + ma_crossover.STRATEGY_KEY: _register(ma_crossover), + rsi_reversion.STRATEGY_KEY: _register(rsi_reversion), +} + + +def list_builtins() -> list[BuiltinStrategy]: + return list(BUILTIN_STRATEGIES.values()) + + +def get_builtin(key: str) -> BuiltinStrategy: + if key not in BUILTIN_STRATEGIES: + raise KeyError(f"Unknown builtin strategy: {key}") + return BUILTIN_STRATEGIES[key] diff --git a/strategy_db.py b/strategy_db.py index 075ab9d..5a43af0 100644 --- a/strategy_db.py +++ b/strategy_db.py @@ -1,4 +1,4 @@ -"""SQLite persistence for user-saved strategies.""" +"""SQLite persistence for user-saved strategies (builtin + custom Python).""" from __future__ import annotations @@ -17,7 +17,9 @@ class SavedStrategy: username: str name: str ticker: str + strategy_key: str params: dict[str, Any] + source_code: str | None created_at: str @@ -25,6 +27,16 @@ def _db_path() -> str: return os.environ.get("STRATEGY_DB_PATH", "/data/strategies/strategies.db") +def _migrate(conn: sqlite3.Connection) -> None: + cols = {row[1] for row in conn.execute("PRAGMA table_info(strategies)")} + if "strategy_key" not in cols: + conn.execute( + "ALTER TABLE strategies ADD COLUMN strategy_key TEXT NOT NULL DEFAULT 'ma_crossover'" + ) + if "source_code" not in cols: + conn.execute("ALTER TABLE strategies ADD COLUMN source_code TEXT") + + def init_db() -> None: path = _db_path() os.makedirs(os.path.dirname(path), exist_ok=True) @@ -36,12 +48,15 @@ def init_db() -> None: username TEXT NOT NULL, name TEXT NOT NULL, ticker TEXT NOT NULL, + strategy_key TEXT NOT NULL DEFAULT 'ma_crossover', params_json TEXT NOT NULL, + source_code TEXT, created_at TEXT NOT NULL, UNIQUE(username, name) ) """ ) + _migrate(conn) conn.commit() @@ -55,24 +70,51 @@ def _connect(): conn.close() +def _row_to_strategy(row: sqlite3.Row) -> SavedStrategy: + return SavedStrategy( + id=row["id"], + username=row["username"], + name=row["name"], + ticker=row["ticker"], + strategy_key=row["strategy_key"] if "strategy_key" in row.keys() else "ma_crossover", + params=json.loads(row["params_json"]), + source_code=row["source_code"] if "source_code" in row.keys() else None, + created_at=row["created_at"], + ) + + def save_strategy( username: str, name: str, ticker: str, + strategy_key: str, params: dict[str, Any], + source_code: str | None = None, ) -> None: created_at = datetime.now(timezone.utc).isoformat() with _connect() as conn: conn.execute( """ - INSERT INTO strategies (username, name, ticker, params_json, created_at) - VALUES (?, ?, ?, ?, ?) + INSERT INTO strategies ( + username, name, ticker, strategy_key, params_json, source_code, created_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?) ON CONFLICT(username, name) DO UPDATE SET ticker = excluded.ticker, + strategy_key = excluded.strategy_key, params_json = excluded.params_json, + source_code = excluded.source_code, created_at = excluded.created_at """, - (username, name.strip(), ticker.upper(), json.dumps(params), created_at), + ( + username, + name.strip(), + ticker.upper(), + strategy_key, + json.dumps(params), + source_code, + created_at, + ), ) conn.commit() @@ -81,49 +123,27 @@ def list_strategies(username: str) -> list[SavedStrategy]: with _connect() as conn: rows = conn.execute( """ - SELECT id, username, name, ticker, params_json, created_at + SELECT id, username, name, ticker, strategy_key, params_json, source_code, created_at FROM strategies WHERE username = ? ORDER BY created_at DESC """, (username,), ).fetchall() - - return [ - SavedStrategy( - id=row["id"], - username=row["username"], - name=row["name"], - ticker=row["ticker"], - params=json.loads(row["params_json"]), - created_at=row["created_at"], - ) - for row in rows - ] + return [_row_to_strategy(row) for row in rows] def load_strategy(username: str, name: str) -> SavedStrategy | None: with _connect() as conn: row = conn.execute( """ - SELECT id, username, name, ticker, params_json, created_at + SELECT id, username, name, ticker, strategy_key, params_json, source_code, created_at FROM strategies WHERE username = ? AND name = ? """, (username, name), ).fetchone() - - if row is None: - return None - - return SavedStrategy( - id=row["id"], - username=row["username"], - name=row["name"], - ticker=row["ticker"], - params=json.loads(row["params_json"]), - created_at=row["created_at"], - ) + return _row_to_strategy(row) if row else None def delete_strategy(username: str, name: str) -> None: