# main.py
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List, Optional
import uuid
import os
from datetime import datetime
import requests
from qdrant_client import QdrantClient
from qdrant_client.http import models
from fastapi.middleware.cors import CORSMiddleware
import time
# Database setup
from sqlalchemy import create_engine, Column, String, DateTime, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
# Image processing
from PIL import Image
import io
import torch
from transformers import CLIPProcessor, CLIPModel
import logging
from qdrant_client.http.models import VectorParams, Distance

# Basic logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Configuration
COLLECTION_NAME = "posts"
# 1280 = 768 (nomic-embed-text) + 512 (CLIP ViT-B/32 image features), concatenated
VECTOR_SIZE = 1280
DATABASE_URL = "sqlite:///./imageboard.db"
QDRANT_URL = "http://qdrant:6333"
OLLAMA_URL = "http://ollama:11434"
EMBEDDING_MODEL = "nomic-embed-text"
IMAGE_MODEL = "openai/clip-vit-base-patch32"
IMAGE_SIZE = (224, 224)
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Component initialization
Base = declarative_base()
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# CLIP initialization
clip_model = CLIPModel.from_pretrained(IMAGE_MODEL)
clip_processor = CLIPProcessor.from_pretrained(IMAGE_MODEL)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model = clip_model.to(device)

# Create a QdrantClient with retries
def create_qdrant_client():
    max_attempts = 5
    attempt = 0
    while attempt < max_attempts:
        try:
            client = QdrantClient(QDRANT_URL)
            # Connection check
            client.get_collections()
            logger.info("Successfully connected to Qdrant")
            return client
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} to connect to Qdrant failed: {str(e)}")
            attempt += 1
            time.sleep(2)
    raise RuntimeError(f"Could not connect to Qdrant after {max_attempts} attempts")

# Qdrant client initialization
try:
    qdrant_client = create_qdrant_client()
except Exception as e:
    logger.error(f"Qdrant initialization error: {str(e)}")
    raise

# Check that the collection exists, creating it if necessary
def ensure_collection_exists():
    max_attempts = 5
    attempt = 0
    while attempt < max_attempts:
        try:
            # Check whether the collection already exists
            qdrant_client.get_collection(collection_name=COLLECTION_NAME)
            logger.info(f"Collection '{COLLECTION_NAME}' exists")
            return
        except Exception as e:
            if "not found" in str(e).lower():
                logger.info(f"Creating collection '{COLLECTION_NAME}'...")
                try:
                    qdrant_client.create_collection(
                        collection_name=COLLECTION_NAME,
                        vectors_config=VectorParams(
                            size=VECTOR_SIZE,
                            distance=Distance.COSINE
                        )
                    )
                    logger.info(f"Collection '{COLLECTION_NAME}' created")
                    return
                except Exception as create_error:
                    logger.error(f"Creation error: {str(create_error)}")
            else:
                logger.error(f"Connection error: {str(e)}")
            attempt += 1
            time.sleep(2)
    raise RuntimeError(f"Could not initialize the collection after {max_attempts} attempts")

# Run the collection check
try:
    ensure_collection_exists()
except Exception as e:
    logger.error(f"Collection initialization error: {str(e)}")
    raise

# FastAPI initialization
app = FastAPI()
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Database model
class Post(Base):
    __tablename__ = "posts"
    id = Column(String, primary_key=True, index=True)
    text = Column(Text, nullable=True)
    image = Column(String, nullable=True)
    created_at = Column(DateTime)

Base.metadata.create_all(bind=engine)

# Pydantic model for responses
class PostResponse(BaseModel):
    id: str
    text: Optional[str] = None
    image: Optional[str] = None
    created_at: datetime
    vector: Optional[List[float]] = None

    class Config:
        orm_mode = True

# Pydantic model for vector queries
class VectorQuery(BaseModel):
    vector: List[float]

# Utility functions
def generate_text_embedding(text: str) -> List[float]:
    response = requests.post(
        f"{OLLAMA_URL}/api/embeddings",
        json={"model": EMBEDDING_MODEL, "prompt": text}
    )
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail="Embedding generation failed")
    return response.json()["embedding"]

def generate_image_embedding(image_bytes: bytes) -> List[float]:
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    inputs = clip_processor(
        images=image,
        return_tensors="pt",
        padding=True
    ).to(device)
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
    return features.cpu().numpy().tolist()[0]

def process_image(image_bytes: bytes) -> bytes:
    img = Image.open(io.BytesIO(image_bytes))
    img = img.convert("RGB")
    img = img.resize(IMAGE_SIZE)
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    return buffer.getvalue()
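
# The create_post handler below keeps a commented-out call to a generate_thumbnail
# helper that is not defined anywhere in this file. A minimal sketch is given here
# as an assumption (fixed 128x128 JPEG thumbnails); adjust size and format before
# enabling the commented-out thumbnail code in create_post.
def generate_thumbnail(image_bytes: bytes, size: tuple = (128, 128)) -> bytes:
    # Downscale in place while preserving aspect ratio, then re-encode as JPEG
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    img.thumbnail(size)
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    return buffer.getvalue()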

# API endpoints
@app.post("/posts/", response_model=PostResponse)
async def create_post(
    text: Optional[str] = Form(None),
    image: Optional[UploadFile] = File(None)
):
    db = SessionLocal()
    try:
        post_id = str(uuid.uuid4())
        image_path = None
        thumbnail_path = None  # Placeholder for thumbnail
        embeddings = []
        if text:
            logger.info("Generating text embedding")
            text_embedding = generate_text_embedding(text)
            embeddings.extend(text_embedding)
        if image:
            logger.info("Processing image")
            image_bytes = await image.read()
            # Save original image
            image_path = f"{UPLOAD_DIR}/{post_id}.jpg"
            with open(image_path, "wb") as f:
                f.write(image_bytes)
            # Create processed (resized) image for embeddings
            processed_image = process_image(image_bytes)

            # Generate thumbnail as placeholder (example implementation)
            #thumbnail = generate_thumbnail(image_bytes)  # Implement your thumbnail generation
            #thumbnail_path = f"{UPLOAD_DIR}/{post_id}_thumbnail.jpg"
            #with open(thumbnail_path, "wb") as f:
            #    f.write(thumbnail)

            # Generate embeddings from the processed image
            image_embedding = generate_image_embedding(processed_image)
            embeddings.extend(image_embedding)
        # NOTE: the collection expects vectors of size VECTOR_SIZE (text and image
        # embeddings concatenated); a post with only text or only an image produces
        # a shorter vector, and Qdrant will reject the upsert below.
        logger.info("Saving data to Qdrant")
        qdrant_client.upsert(
            collection_name=COLLECTION_NAME,
            points=[models.PointStruct(
                id=post_id,
                vector=embeddings,
                payload={"post_id": post_id}
            )]
        )
        logger.info("Saving the post to the database")
        db_post = Post(
            id=post_id,
            text=text,
            image=image_path,
            #thumbnail=thumbnail_path,  # Add a thumbnail field to the Post model first
            created_at=datetime.now()
        )
        db.add(db_post)
        db.commit()
        db.refresh(db_post)
        response = PostResponse(
            id=db_post.id,
            text=db_post.text,
            image=db_post.image,
            #thumbnail=db_post.thumbnail,  # Update the response model to include thumbnail
            created_at=db_post.created_at,
            vector=embeddings
        )
        logger.info("Post created successfully: %s", response)
        return response
    except Exception as e:
        db.rollback()
        logger.exception("Error while creating the post")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db.close()
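
# A hedged usage sketch for the endpoint above: creating a post from a separate
# script with the `requests` library. The base URL and image filename are
# assumptions; the helper is only defined here and never called at import time.
def _example_create_post(base_url: str = "http://localhost:8000") -> dict:
    with open("cat.jpg", "rb") as image_file:
        resp = requests.post(
            f"{base_url}/posts/",
            data={"text": "a cat sitting on a windowsill"},
            files={"image": ("cat.jpg", image_file, "image/jpeg")},
        )
    resp.raise_for_status()
    return resp.json()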
@app.get("/search/")
async def search_posts(
text: Optional[str] = None,
image: Optional[UploadFile] = File(None)
):
try:
query_embedding = []
if text:
logger.info("Генерация эмбеддинга для текста (поиск)")
text_embedding = generate_text_embedding(text)
query_embedding.extend(text_embedding)
if image:
logger.info("Генерация эмбеддинга для изображения (поиск)")
image_bytes = await image.read()
processed_image = process_image(image_bytes)
image_embedding = generate_image_embedding(processed_image)
query_embedding.extend(image_embedding)
logger.info("Выполнение поиска в Qdrant")
search_results = qdrant_client.search(
collection_name="posts",
query_vector=query_embedding,
limit=10
)
logger.info("Поиск завершён. Найдено результатов: %d", len(search_results))
return [result.payload for result in search_results]
except Exception as e:
logger.exception("Ошибка при поиске постов")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/posts/", response_model=List[PostResponse])
async def get_all_posts():
db = SessionLocal()
try:
posts = db.query(Post).all()
return posts
finally:
db.close()

# New endpoint: return a "tree" of posts ordered by similarity to a user-supplied vector
@app.post("/posts/tree", response_model=List[PostResponse])
async def get_posts_tree(query: VectorQuery):
    # Search Qdrant with a large limit so all posts come back ordered by similarity
    search_results = qdrant_client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query.vector,
        limit=10000  # Set the limit according to the expected number of posts
    )
    logger.debug("search_results: %s", search_results)
    # Collect the post IDs in the order Qdrant returned them (nearest first)
    post_ids = [result.payload.get("post_id") for result in search_results]
    logger.debug("post_ids: %s", post_ids)
    db = SessionLocal()
    try:
        # Fetch the matching posts from the database
        posts = db.query(Post).filter(Post.id.in_(post_ids)).all()
        # Map post_id -> post so the Qdrant ordering can be restored
        posts_dict = {post.id: post for post in posts}
        # Restore the similarity order using the post_ids list
        ordered_posts = [posts_dict[pid] for pid in post_ids if pid in posts_dict]
        logger.debug("ordered_posts: %s", ordered_posts)
        return ordered_posts
    finally:
        db.close()
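
# A hedged usage sketch for /posts/tree: build a query vector by concatenating a
# text embedding and an image embedding (to match VECTOR_SIZE) and post it to the
# endpoint. Base URL and image path are assumptions; the helper is never called here.
def _example_posts_tree(base_url: str = "http://localhost:8000") -> list:
    text_vec = generate_text_embedding("sunset over the sea")
    with open("sunset.jpg", "rb") as f:
        image_vec = generate_image_embedding(process_image(f.read()))
    resp = requests.post(f"{base_url}/posts/tree", json={"vector": text_vec + image_vec})
    resp.raise_for_status()
    return resp.json()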

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)